{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Block
( Block, MutableBlock, blockIndex, blockRead, blockWrite
, foldZipWith, Block.length, mutableContents
, Block.new, Block.newPinned, Block.thaw, thawPinned
, Block.unsafeCast, unsafeCastMut, Block.unsafeFreeze, Block.unsafeThaw
#ifdef ML_KEM_TESTING
, Block.toList
#endif
) where
import Control.Monad.Primitive
import Data.Primitive.ByteArray
import Data.Primitive.PrimArray
import Control.Exception (assert)
import Foreign.Ptr (Ptr)
import Base
-- | Immutable array of unboxed elements of type @ty@. This is a plain
-- synonym for 'PrimArray', so everything in "Data.Primitive.PrimArray"
-- applies to it directly. Being eta-reduced, 'Block' can also be used
-- unsaturated.
type Block = PrimArray
-- | Mutable counterpart of 'Block'. Note the parameter order is swapped
-- relative to 'MutablePrimArray': the element type comes first and the
-- state token @s@ second.
type MutableBlock ty s = MutablePrimArray s ty
-- | Element at the given offset of an immutable block.
--
-- In testing builds (testing CPP flag enabled) the offset is passed to
-- 'checkBounds' together with the block length. In normal builds the
-- access goes straight to 'indexPrimArray' with no bounds check, so the
-- caller must guarantee the offset is in range.
blockIndex :: PrimType ty => Block ty -> Offset ty -> ty

-- | Read the element at the given offset of a mutable block.
-- Bounds-checked via 'checkBounds' only in testing builds.
blockRead
    :: (PrimMonad prim, PrimType ty)
    => MutableBlock ty (PrimState prim) -> Offset ty -> prim ty

-- | Write an element at the given offset of a mutable block.
-- Bounds-checked via 'checkBounds' only in testing builds.
blockWrite
    :: (PrimMonad prim, PrimType ty)
    => MutableBlock ty (PrimState prim) -> Offset ty -> ty -> prim ()

#ifdef ML_KEM_TESTING
blockIndex b off@(Offset i) =
    checkBounds (Block.length b) off $ indexPrimArray b i
blockRead mb off@(Offset i) = getSizeofMutablePrimArray mb >>= \sz ->
    checkBounds (CountOf sz) off $ readPrimArray mb i
blockWrite mb off@(Offset i) a = getSizeofMutablePrimArray mb >>= \sz ->
    checkBounds (CountOf sz) off $ writePrimArray mb i a
#else
blockIndex b (Offset i) = indexPrimArray b i
blockRead mb (Offset i) = readPrimArray mb i
blockWrite mb (Offset i) = writePrimArray mb i
#endif
-- | Strict left fold over two blocks in lockstep. Starting from the given
-- accumulator, it combines the accumulator with the elements found at each
-- index, from index 0 upwards.
--
-- Both blocks are expected to have the same length. This is checked with
-- 'assert', but GHC drops assertions under optimisation (@-O@ implies
-- @-fignore-asserts@). The loop therefore stops at the shorter of the two
-- lengths, so it never indexes past the end of either block even when the
-- assertion has been compiled out.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> Block a -> Block b -> c
foldZipWith f c a b = assert (sa == sb) $ loop c 0
  where
    sa = sizeofPrimArray a
    sb = sizeofPrimArray b
    -- Only indices valid in both blocks are visited.
    n = min sa sb
    -- The bang keeps the accumulator evaluated so no thunk chain builds up.
    loop !acc i
        | i == n = acc
        | otherwise =
            let va = indexPrimArray a i
                vb = indexPrimArray b i
            in loop (f acc va vb) (i + 1)
{-# INLINE foldZipWith #-}
-- | Number of elements stored in the block.
length :: PrimType ty => Block ty -> CountOf ty
length = CountOf . sizeofPrimArray
-- | Address of the first element of a mutable block.
--
-- The pointer is only meaningful for pinned blocks, whose storage the
-- garbage collector will not move, such as those from 'newPinned' or
-- 'thawPinned'. The caller must also keep the block itself alive for as
-- long as the pointer is in use.
mutableContents :: MutableBlock ty s -> Ptr ty
mutableContents = mutablePrimArrayContents
-- | Allocate a mutable block with room for the given number of elements.
-- The storage is not explicitly pinned and its contents are left
-- uninitialised.
new :: (PrimMonad prim, PrimType ty)
    => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new (CountOf n) = newPrimArray n
-- | Allocate a pinned mutable block with room for the given number of
-- elements. Its contents are left uninitialised. Pinned storage is not
-- moved by the garbage collector, which makes the block usable with
-- 'mutableContents'.
newPinned :: (PrimMonad prim, PrimType ty)
          => CountOf ty -> prim (MutableBlock ty (PrimState prim))
newPinned (CountOf n) = newPinnedPrimArray n
-- | Copy an immutable block into a fresh mutable block.
--
-- The copy is done on the whole underlying 'ByteArray', so it does not
-- need a 'PrimType' constraint on the element type. The source block is
-- not modified.
thaw :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
thaw (PrimArray barr) = stToPrim $ do
    -- The ST state here is PrimState prim, so the checked 'stToPrim'
    -- suffices; no unsafe state coercion is needed.
    MutableByteArray mbarr <- thawByteArray ba 0 (sizeofByteArray ba)
    return (MutablePrimArray mbarr)
  where
    ba = ByteArray barr
-- | Copy an immutable block into a fresh pinned mutable block, so the
-- result can be used with 'mutableContents'.
--
-- Like 'thaw', the copy is byte-wise over the whole underlying
-- 'ByteArray'. The source block is not modified.
thawPinned :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
thawPinned (PrimArray barr) = stToPrim $ do
    mb@(MutableByteArray mbarr) <- newPinnedByteArray n
    copyByteArray mb 0 ba 0 n
    return (MutablePrimArray mbarr)
  where
    ba = ByteArray barr
    -- Size in bytes, not elements.
    n = sizeofByteArray ba
#ifdef ML_KEM_TESTING
-- | All elements of the block, in index order. Only available in testing
-- builds.
toList :: PrimType ty => Block ty -> [ty]
toList blk = foldrPrimArray (:) [] blk
#endif
-- | Reinterpret the bytes of a block as elements of another type, without
-- copying. The underlying byte array is reused unchanged. As a result, the
-- element count seen through the new type depends on that type's element
-- size.
unsafeCast :: Block a -> Block b
unsafeCast (PrimArray bytes) = PrimArray bytes
-- | Mutable counterpart of 'unsafeCast'. It reinterprets the same
-- underlying mutable byte array as elements of another type, without
-- copying. Writes through either view are visible through the other.
unsafeCastMut :: MutableBlock a m -> MutableBlock b m
unsafeCastMut (MutablePrimArray mbytes) = MutablePrimArray mbytes
-- | View a mutable block as immutable, without copying. The mutable block
-- must not be written to afterwards, since the immutable view shares its
-- storage.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (Block ty)
unsafeFreeze = unsafeFreezePrimArray
-- | View an immutable block as mutable, without copying. Writes through the
-- result are visible through the original block, so the original must no
-- longer be relied on as an immutable value. Use 'thaw' for an independent
-- copy.
unsafeThaw :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
unsafeThaw = unsafeThawPrimArray