-- |
-- Module      : Block
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- An array of primitive (unboxed) elements.  This module exposes the
-- t'PrimArray' implementation from primitive but through an API similar to
-- basement @Block@.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Block
    ( Block, MutableBlock, blockIndex, blockRead, blockWrite
    , foldZipWith, Block.length, mutableContents
    , Block.new, Block.newPinned, Block.thaw, thawPinned
    , Block.unsafeCast, unsafeCastMut, Block.unsafeFreeze, Block.unsafeThaw
#ifdef ML_KEM_TESTING
    , Block.toList
#endif
    ) where

import Control.Monad.Primitive

import Data.Primitive.ByteArray
import Data.Primitive.PrimArray

import Control.Exception (assert)

import Foreign.Ptr (Ptr)

import Base

-- | Immutable block: a plain alias for 'PrimArray', so primitive's array
-- functions apply to it directly.
type Block = PrimArray
-- | Mutable block.  The type arguments are swapped relative to
-- 'MutablePrimArray': element type first, state token second (presumably to
-- mirror basement's @MutableBlock ty st@ — see the module header).
type MutableBlock ty s = MutablePrimArray s ty

-- | Element at the given offset of an immutable block.
blockIndex :: PrimType ty => Block ty -> Offset ty -> ty
-- | Read the element at the given offset of a mutable block.
blockRead :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> prim ty
-- | Overwrite the element at the given offset of a mutable block.
blockWrite :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> ty -> prim ()
#ifdef ML_KEM_TESTING
-- Test builds: each access is wrapped in 'checkBounds' (from Base; presumably
-- rejects offsets outside the current element count) before reaching memory.
blockIndex blk off@(Offset ix) =
    checkBounds (Block.length blk) off (indexPrimArray blk ix)
blockRead mblk off@(Offset ix) = do
    count <- getSizeofMutablePrimArray mblk
    checkBounds (CountOf count) off (readPrimArray mblk ix)
blockWrite mblk off@(Offset ix) val = do
    count <- getSizeofMutablePrimArray mblk
    checkBounds (CountOf count) off (writePrimArray mblk ix val)
#else
-- Regular builds: unwrap the offset and index the primitive array unchecked.
blockIndex blk (Offset ix) = indexPrimArray blk ix
blockRead mblk (Offset ix) = readPrimArray mblk ix
blockWrite mblk (Offset ix) val = writePrimArray mblk ix val
#endif

-- | Strict left fold over two blocks walked in lockstep.
--
-- @foldZipWith f z a b@ computes
-- @f (... (f (f z a0 b0) a1 b1) ...) ai bi@, forcing the accumulator to WHNF
-- at every step.
--
-- Both blocks are expected to hold the same number of elements.  That is
-- checked with 'assert', which disappears when assertions are ignored
-- (e.g. under @-O@).  The traversal is therefore also bounded by the shorter
-- block, so a length mismatch can never index past the end of either array.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> Block a -> Block b -> c
foldZipWith f c a b = assert (sa == sb) $ loop c 0
  where
    sa = sizeofPrimArray a
    sb = sizeofPrimArray b
    -- Count of indices valid in both arrays; equals sa when the
    -- precondition holds.
    n = min sa sb

    loop !acc i
        | i == n = acc
        | otherwise =
            let va = indexPrimArray a i
                vb = indexPrimArray b i
             in loop (f acc va vb) (i + 1)
{-# INLINE foldZipWith #-}

-- | Number of elements stored in a block.
length :: PrimType ty => Block ty -> CountOf ty
length blk = CountOf (sizeofPrimArray blk)

-- | Raw pointer to the data of a mutable block.
--
-- Only meaningful for pinned blocks (see 'newPinned' and 'thawPinned').
-- NOTE(review): presumably the block must be kept alive while the pointer is
-- in use; confirm at call sites.
mutableContents :: MutableBlock ty s -> Ptr ty
mutableContents mblk = mutablePrimArrayContents mblk

-- | Allocate a fresh mutable block holding the given number of elements.
new :: (PrimMonad prim, PrimType ty) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new count = case count of
    CountOf elems -> newPrimArray elems

-- | Like 'new', but the block is allocated pinned (via 'newPinnedPrimArray'),
-- which is what 'mutableContents' requires.
newPinned :: (PrimMonad prim, PrimType ty) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
newPinned count = case count of
    CountOf elems -> newPinnedPrimArray elems

-- | Copy an immutable block into a fresh, independent mutable block.
--
-- The copy is made with 'thawByteArray'; use 'thawPinned' when the copy has
-- to be pinned.
thaw :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
thaw (PrimArray !barr) = stToPrim $
    -- As an optimization, combine both steps in a known monad and avoid a
    -- round trip between byte length and element count: the raw bytes are
    -- copied as-is and re-wrapped as a MutablePrimArray.  The ST action
    -- already runs in state 'PrimState prim', so the safe 'stToPrim' suffices
    -- (no unchecked state-token coercion needed).
    thawByteArray ba 0 (sizeofByteArray ba) >>= \(MutableByteArray mbarr) ->
        return (MutablePrimArray mbarr)
  where
    ba = ByteArray barr

-- | Copy an immutable block into a fresh, pinned mutable block, suitable for
-- 'mutableContents'.
--
-- Works directly on the underlying bytes: allocate a pinned byte array of
-- the same byte size, copy everything, then re-wrap it as a
-- MutablePrimArray.  The ST action already runs in state 'PrimState prim',
-- so the safe 'stToPrim' is used rather than an unchecked conversion.
thawPinned :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
thawPinned (PrimArray !barr) = stToPrim $ do
    let ba = ByteArray barr
        n = sizeofByteArray ba
    mb@(MutableByteArray mbarr) <- newPinnedByteArray n
    copyByteArray mb 0 ba 0 n
    return (MutablePrimArray mbarr)

#ifdef ML_KEM_TESTING
-- | All elements of a block, in index order (test builds only).
toList :: PrimType ty => Block ty -> [ty]
toList blk = primArrayToList blk
#endif

-- | Reinterpret an immutable block's bytes as elements of another type,
-- without copying.  No size or alignment check is performed.
unsafeCast :: Block a -> Block b
unsafeCast (PrimArray raw) = PrimArray raw

-- | Reinterpret a mutable block's bytes as elements of another type,
-- without copying.  No size or alignment check is performed.
unsafeCastMut :: MutableBlock a m -> MutableBlock b m
unsafeCastMut (MutablePrimArray raw) = MutablePrimArray raw

-- | Freeze a mutable block via 'unsafeFreezePrimArray'.
-- NOTE(review): the result presumably shares storage with the input, so the
-- mutable block should not be written afterwards.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (Block ty)
unsafeFreeze mblk = unsafeFreezePrimArray mblk

-- | Thaw an immutable block via 'unsafeThawPrimArray'.
-- NOTE(review): the result presumably shares storage with the input, so the
-- original block should no longer be relied upon once this is written to.
unsafeThaw :: PrimMonad prim => Block ty -> prim (MutableBlock ty (PrimState prim))
unsafeThaw blk = unsafeThawPrimArray blk