-- |
-- Module      : ScrubbedBlock
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A block that is always pinned in memory and is automatically erased
-- (zero-filled) by a finalizer once it is no longer referenced.  This follows
-- the same pattern as @ScrubbedBytes@ from the @memory@ package, but for blocks.
--
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module ScrubbedBlock
    ( ScrubbedBlock, foldZipWith, ScrubbedBlock.length
    , new, thaw, unsafeFreeze
    ) where

import Data.Primitive.PrimArray as Block

import Control.Exception (assert)
import Control.Monad.ST

import Data.Word

import Base
import Block (Block, MutableBlock)
import qualified Block

import GHC.Base (IO(IO), Int(I#), setByteArray#)
import GHC.Exts (mkWeak#)

-- | A 'Block' whose memory is pinned and is zero-filled by a finalizer once
-- the block becomes unreachable.
--
-- The data constructor is not exported.  Within this module, values are only
-- built by 'unsafeFreeze', which is what attaches the scrubbing finalizer.
--
-- NOTE(review): the derived 'Eq' and 'Show' instances delegate to 'Block'.
-- 'Show' therefore prints the contents, and 'Eq' is presumably not
-- constant-time; confirm this is acceptable if the block holds secrets.
newtype ScrubbedBlock ty = ScrubbedBlock (Block ty)
    deriving (Eq, Show)

-- | Fold over two scrubbed blocks together with an accumulating function.
--
-- This is a thin wrapper around 'Block.foldZipWith' on the underlying blocks.
-- Its behaviour for blocks of different lengths is therefore whatever
-- 'Block.foldZipWith' does.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> ScrubbedBlock a -> ScrubbedBlock b -> c
foldZipWith f c (ScrubbedBlock a) (ScrubbedBlock b) = Block.foldZipWith f c a b
{-# INLINE foldZipWith #-}

-- | Number of elements in a scrubbed block (delegates to 'Block.length').
length :: PrimType ty => ScrubbedBlock ty -> CountOf ty
length (ScrubbedBlock b) = Block.length b

-- | Allocate a mutable block holding the given number of elements.
--
-- The block is always pinned, because 'unsafeFreeze' asserts that its input
-- is pinned.  No finalizer is attached here.  The memory is only scheduled
-- for scrubbing once the block is frozen with 'unsafeFreeze'.
new :: (PrimType ty, PrimMonad prim) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new = Block.newPinned

-- | Obtain a mutable block from a scrubbed one through 'Block.thawPinned'.
--
-- The result is pinned, so it can be frozen again with 'unsafeFreeze'.
--
-- NOTE(review): presumably 'Block.thawPinned' copies the contents.  If so, the
-- copy is not scrubbed unless it is itself passed to 'unsafeFreeze'.
thaw :: PrimMonad m => ScrubbedBlock ty -> m (MutableBlock ty (PrimState m))
thaw (ScrubbedBlock b) = Block.thawPinned b

-- | Freeze a mutable block with 'Block.unsafeFreeze' and register the
-- finalizer that zero-fills its memory once it becomes unreachable.
--
-- The input must be pinned, as produced by 'new' or 'thaw'.  This is only
-- checked with an assertion.  NOTE(review): as with any unsafe freeze, the
-- mutable block should presumably not be written to afterwards.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (ScrubbedBlock ty)
unsafeFreeze mb = Block.unsafeFreeze mb >>= scrubbed


{- internal -}

-- | @assertPinned blk x@ returns @x@ after asserting that @blk@ is pinned.
--
-- This relies on 'assert', so the check is skipped when assertions are
-- disabled (for example with @-O@ or @-fignore-asserts@).
assertPinned :: Block ty -> a -> a
assertPinned blk = assert (Block.isPrimArrayPinned blk)

-- | Wrap a frozen, pinned block as a 'ScrubbedBlock'.
--
-- This first registers the finalizer that zero-fills the block's memory when
-- the block is collected.  Pinning is checked with an assertion only.
scrubbed :: PrimMonad prim => Block ty -> prim (ScrubbedBlock ty)
scrubbed b = assertPinned b $ unsafePrimFromIO $ do
    scheduleBlockScrubbing b
    return (ScrubbedBlock b)

-- | Attach a finalizer to the block's storage that zero-fills every byte.
--
-- The block is reinterpreted as 'Word8' with 'Block.unsafeCast'.  This way
-- 'scrub' works on the full byte length, whatever the element type @ty@.
--
-- NOTE(review): the reason for NOINLINE is not documented here; confirm it
-- before removing the pragma.
scheduleBlockScrubbing :: Block ty -> IO ()
scheduleBlockScrubbing b = addBlockFinalizer b (scrub (Block.unsafeCast b))
{-# NOINLINE scheduleBlockScrubbing #-}

-- | Zero-fill every byte of a block, mutating it in place via 'Block.unsafeThaw'.
--
-- This is only used as the finalizer action registered by
-- 'scheduleBlockScrubbing'.  It runs once the block is no longer reachable, so
-- writing to the supposedly immutable block is not observable.
scrub :: Block Word8 -> IO ()
scrub b = Block.unsafeThaw b >>= erase len
  where CountOf len = Block.length b

-- | Attach an 'IO' finalizer to the block's underlying @ByteArray#@, using a
-- weak pointer keyed on that array.
--
-- The resulting weak pointer is discarded, just as
-- 'System.Mem.Weak.addFinalizer' does.  GHC tracks weak pointers itself, so
-- the finalizer still runs once the key becomes unreachable.  The finalizer
-- may reference the key without keeping it alive.
--
-- NOTE(review): GHC does not guarantee that finalizers run at all (for example
-- at program exit), so scrubbing is best-effort.
addBlockFinalizer :: Block ty -> IO () -> IO ()
addBlockFinalizer (Block.PrimArray barr) (IO finalizer) = IO $ \s0 ->
    case mkWeak# barr () finalizer s0 of
        (# s1, _weak #) -> (# s1, () #)

-- | @erase len mba@ sets bytes @[0, len)@ of @mba@ to zero using
-- @setByteArray#@.
--
-- There is no bounds check: @len@ must not exceed the size of @mba@ in bytes.
-- The only caller, 'scrub', passes the array's own byte length.
erase :: Int -> MutableBlock Word8 RealWorld -> IO ()
erase (I# len) (Block.MutablePrimArray mbarr) = IO $ \s0 ->
    case setByteArray# mbarr 0# len 0# s0 of
        s1 -> (# s1, () #)