{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
-- | Type-level security markings ('Sec' / 'Pub') and the storage operations
-- whose block and byte-array representations are chosen by those markings.
module Marking
    ( SecurityMarking (..)
    , Classified (..)
    , Leak (..)
    , index
    , Marking.toNormalForm
    , unsafeCast
#ifdef ML_KEM_TESTING
    , Marking.toList
#endif
    ) where
import Control.DeepSeq (NFData(..))
import Control.Monad.ST
import Data.ByteArray (Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B
import Data.Kind
import Foreign.Ptr (Ptr)
import Unsafe.Coerce
import Base
import Block (Block, MutableBlock, blockIndex)
import ScrubbedBlock (ScrubbedBlock)
import qualified Block
import qualified ByteArrayST as ST
import qualified ScrubbedBlock
-- | Type-level tag describing how sensitive a piece of data is.
data SecurityMarking
    = Sec -- ^ Secret data.
    | Pub -- ^ Public data.
-- | Types indexed by a 'SecurityMarking' whose secret form can be
-- declassified into the public form.
class Leak t where
    -- | Declassify a value: the same contents, re-tagged from 'Sec' to 'Pub'.
    --
    -- The default reinterprets the value with 'unsafeCoerce' and performs no
    -- runtime conversion.
    --
    -- NOTE(review): this is only sound when the runtime representation of
    -- @t m@ does not depend on @m@. An instance whose 'Sec' and 'Pub' forms
    -- differ at runtime must override 'leak'.
    leak :: t Sec -> t Pub
    leak = unsafeCoerce
-- | Storage operations parameterised by a type-level 'SecurityMarking'.
--
-- Each marking chooses its own immutable block type ('SecureBlock') and
-- byte-array type ('SecureBytes'). Both associated types are injective, so
-- the marking can be recovered from the storage type alone.
class Classified (marking :: SecurityMarking) where
    -- | Immutable block type used for data carrying this marking.
    type SecureBlock marking = (block :: Type -> Type) | block -> marking

    -- | Create a mutable block for the given element count.
    -- @marking@ occurs only in the proxy argument, so the proxy is what
    -- selects the instance.
    new
        :: (PrimType ty, PrimMonad prim)
        => proxy marking
        -> CountOf ty
        -> prim (MutableBlock ty (PrimState prim))

    -- | Turn an immutable 'SecureBlock' into a 'MutableBlock'.
    thaw
        :: PrimMonad m
        => SecureBlock marking ty
        -> m (MutableBlock ty (PrimState m))

    -- | Turn a 'MutableBlock' into this marking's immutable block type.
    -- NOTE(review): as the name suggests, presumably no copy is made, so the
    -- mutable block must not be written afterwards. Confirm per instance.
    unsafeFreeze
        :: PrimMonad prim
        => MutableBlock ty (PrimState prim)
        -> prim (SecureBlock marking ty)
#ifdef ML_KEM_TESTING
    -- | Testing only: equality of two blocks.
    eq
        :: (Eq ty, PrimType ty)
        => SecureBlock marking ty
        -> SecureBlock marking ty
        -> Bool
    -- | Testing only: 'Show'-style rendering of a block.
    showsPrec
        :: (PrimType ty, Show ty)
        => Int
        -> SecureBlock marking ty
        -> ShowS
    -- | Testing only: number of elements in a block.
    lengthBlock :: PrimType ty => SecureBlock marking ty -> CountOf ty
#endif

    -- | Byte-array type used for data carrying this marking.
    type SecureBytes marking = bytes | bytes -> marking

    -- | Build a byte array by running an 'ST' initializer on a pointer to its
    -- contents. NOTE(review): the 'Int' is presumably the size in bytes.
    unsafeCreate :: Int -> (forall s. Ptr a -> ST s ()) -> SecureBytes marking

    -- | Length of a byte array of this marking.
    lengthBytes :: SecureBytes marking -> Int

    -- | Copy the byte array's contents to the given pointer. The caller is
    -- responsible for the destination being large enough.
    copyByteArrayToPtr :: SecureBytes marking -> Ptr a -> IO ()
-- | Public data is stored in plain 'Block' and 'Bytes'. Every operation
-- delegates directly to the corresponding "Block", "ByteArrayST" or
-- "Data.ByteArray" function.
instance Classified Pub where
    type SecureBlock Pub = Block

    -- The proxy only fixes the marking at the type level; it is never inspected.
    new _ = Block.new
    thaw = Block.thaw
    unsafeFreeze = Block.unsafeFreeze
#ifdef ML_KEM_TESTING
    eq = (==)
    -- Qualified to use the standard 'Show' method rather than the
    -- same-named method declared by 'Classified'.
    showsPrec = Prelude.showsPrec
    lengthBlock = Block.length
#endif

    type SecureBytes Pub = Bytes

    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes = B.length
    copyByteArrayToPtr = B.copyByteArrayToPtr
-- | Secret data is stored in 'ScrubbedBlock' and 'ScrubbedBytes'.
-- Allocation, thawing and freezing go through "ScrubbedBlock" instead of
-- "Block". NOTE(review): presumably this is so the backing memory is
-- scrubbed when released; confirm in "ScrubbedBlock".
instance Classified Sec where
    type SecureBlock Sec = ScrubbedBlock

    -- The proxy only fixes the marking at the type level; it is never inspected.
    new _ = ScrubbedBlock.new
    thaw = ScrubbedBlock.thaw
    unsafeFreeze = ScrubbedBlock.unsafeFreeze
#ifdef ML_KEM_TESTING
    eq = (==)
    -- Qualified to use the standard 'Show' method rather than the
    -- same-named method declared by 'Classified'.
    showsPrec = Prelude.showsPrec
    lengthBlock = ScrubbedBlock.length
#endif

    type SecureBytes Sec = ScrubbedBytes

    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes = B.length
    copyByteArrayToPtr = B.copyByteArrayToPtr
-- | View a 'SecureBlock' of any marking as a plain 'Block'.
-- Module-private: it is not in the export list.
--
-- NOTE(review): 'unsafeCoerce' is only sound if every 'SecureBlock' type
-- shares the runtime representation of 'Block'. For 'ScrubbedBlock' this
-- presumably means it is a newtype over 'Block'; confirm in "ScrubbedBlock".
unwrap :: SecureBlock marking a -> Block a
unwrap = unsafeCoerce
-- | Re-tag a plain 'Block' as a 'SecureBlock' of any marking; the inverse of
-- 'unwrap'. Module-private: it is not in the export list.
--
-- NOTE(review): like 'unwrap', this relies on 'unsafeCoerce' and is only
-- sound if every 'SecureBlock' type shares the runtime representation of
-- 'Block'.
wrap :: Block b -> SecureBlock marking b
wrap = unsafeCoerce
-- | Read the element at the given offset of a block of either marking.
--
-- The read goes straight to the underlying 'Block' via 'blockIndex'.
-- NOTE(review): whether the offset is bounds-checked depends on 'blockIndex'
-- and is not established here.
index :: PrimType ty => SecureBlock marking ty -> Offset ty -> ty
index = blockIndex . unwrap
#ifdef ML_KEM_TESTING
-- | Testing helper: the elements of a block of either marking, as a list.
toList :: PrimType ty => SecureBlock marking ty -> [ty]
toList secureBlock = Block.toList (unwrap secureBlock)
#endif
-- | Force a block of either marking to normal form, using the 'NFData'
-- instance of the underlying 'Block'.
toNormalForm :: SecureBlock marking ty -> ()
toNormalForm = rnf . unwrap
-- | Reinterpret the element type of a block while keeping its marking.
--
-- It unwraps to a plain 'Block', applies 'Block.unsafeCast', and re-wraps
-- the result under the same marking.
-- NOTE(review): this function adds no size or alignment checks of its own.
-- The caller is responsible for the new element type matching the
-- underlying data.
unsafeCast :: SecureBlock marking a -> SecureBlock marking b
unsafeCast = wrap . Block.unsafeCast . unwrap