-- |
-- Module      : ByteArrayST
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Byte array primitives in the @ST@ monad instead of @IO@
--
{-# LANGUAGE RankNTypes #-}
module ByteArrayST
    ( unsafeCreate, withByteArray
    , peek, peekElemOff, pokeElemOff, pokeByteOff
    ) where

import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B

import Control.Monad.ST
import Control.Monad.ST.Unsafe

import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)
import qualified Foreign.Storable as S

-- | Allocate a byte array of @sz@ bytes and initialise it by running the
-- 'ST' action @f@ on a pointer to its contents.
--
-- This is 'B.unsafeCreate' with the initialiser expressed in 'ST' rather
-- than 'IO'. Because @f@ is quantified over every state thread @s@, it
-- cannot capture state from an enclosing 'ST' computation. It is run at
-- 'RealWorld' through 'stToIO'.
--
-- The pointer should not be used after @f@ returns.
-- NOTE(review): presumably @f@ must write every byte it needs defined,
-- because the buffer handed out by 'B.unsafeCreate' is likely not zeroed.
-- Confirm this against the byte-array library documentation.
unsafeCreate :: ByteArray ba => Int -> (forall s. Ptr p -> ST s ()) -> ba
unsafeCreate sz f = B.unsafeCreate sz (stToIO . f)
{-# INLINE unsafeCreate #-}

-- | Run an 'ST' action with a pointer to the contents of the byte array @b@.
--
-- This is 'B.withByteArray' lifted from 'IO' to 'ST'. The callback @f@ is
-- converted to 'IO' with 'unsafeSTToIO', and the whole call is brought
-- back into 'ST' with 'unsafeIOToST'.
--
-- The pointer is only meant to be used inside @f@.
-- NOTE(review): presumably access through the pointer should be read-only.
-- Writing through it would mutate a value that may be shared elsewhere.
-- Confirm with callers.
withByteArray :: ByteArrayAccess ba => ba -> (Ptr p -> ST s a) -> ST s a
withByteArray b f = unsafeIOToST $ B.withByteArray b (unsafeSTToIO . f)
{-# INLINE withByteArray #-}

-- | Read a value from the given pointer ('S.peek' lifted into 'ST').
--
-- The caller must ensure the pointer is valid, for example one obtained
-- inside 'withByteArray' or 'unsafeCreate'.
peek :: Storable a => Ptr a -> ST s a
peek = unsafeIOToST . S.peek
{-# INLINE peek #-}

-- | Read the element at index @i@ (counted in elements of type @a@, not
-- bytes) relative to the pointer. This is 'S.peekElemOff' lifted into 'ST'.
peekElemOff :: Storable a => Ptr a -> Int -> ST s a
peekElemOff a = unsafeIOToST . S.peekElemOff a
{-# INLINE peekElemOff #-}

-- | Write a value at index @off@ (counted in elements of type @a@, not
-- bytes) relative to the pointer. This is 'S.pokeElemOff' lifted into 'ST'.
pokeElemOff :: Storable a => Ptr a -> Int -> a -> ST s ()
pokeElemOff a off = unsafeIOToST . S.pokeElemOff a off
{-# INLINE pokeElemOff #-}

-- | Write a value at a byte offset @off@ from the pointer. This is
-- 'S.pokeByteOff' lifted into 'ST'.
--
-- Unlike 'S.pokeByteOff', the pointer's element type is tied to the type
-- of the value being written (@Ptr a@ rather than @Ptr b@). This lets
-- callers fix the value's type through the pointer. Generalising the
-- signature could therefore introduce ambiguous types at existing call
-- sites.
pokeByteOff :: Storable a => Ptr a -> Int -> a -> ST s ()
pokeByteOff a off = unsafeIOToST . S.pokeByteOff a off
{-# INLINE pokeByteOff #-}