{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
module Crypto.PubKey.ML_KEM
( EncapsulationKey, DecapsulationKey, Ciphertext, SharedSecret
, generate, generateOpen, generateWith, encapsulate, encapsulateWith
, decapsulate
, ParamSet, ML_KEM_512, ML_KEM_768, ML_KEM_1024
, Decode(..), Encode(..)
, toPublic, checkKeyPair
) where
import Crypto.Random
import Data.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes)
import qualified Data.ByteArray as B
import Internal
-- | Type-level tag selecting the ML-KEM-512 parameter set. Its 'ParamSet'
-- instance in this module fixes the rank @K@ to 2.
data ML_KEM_512 = ML_KEM_512
    deriving (Show)
-- | Type-level tag selecting the ML-KEM-768 parameter set. Its 'ParamSet'
-- instance in this module fixes the rank @K@ to 3.
data ML_KEM_768 = ML_KEM_768
    deriving (Show)
-- | Type-level tag selecting the ML-KEM-1024 parameter set. Its 'ParamSet'
-- instance in this module fixes the rank @K@ to 4.
data ML_KEM_1024 = ML_KEM_1024
    deriving (Show)
-- | ML-KEM-512: rank k = 2.
--
-- NOTE(review): the 'Params' arguments (3, 2, 10, 4) look like
-- (eta1, eta2, d_u, d_v). They match FIPS 203 Table 2 for ML-KEM-512, but the
-- field order should be confirmed against the definition of 'Params' in
-- 'Internal'.
instance ParamSet ML_KEM_512 where
    type K ML_KEM_512 = 2
    -- The proxy only carries the type; its value is never inspected.
    getParams = const (Params 3 2 10 4)
-- | ML-KEM-768: rank k = 3.
--
-- NOTE(review): the 'Params' arguments (2, 2, 10, 4) look like
-- (eta1, eta2, d_u, d_v). They match FIPS 203 Table 2 for ML-KEM-768, but the
-- field order should be confirmed against the definition of 'Params' in
-- 'Internal'.
instance ParamSet ML_KEM_768 where
    type K ML_KEM_768 = 3
    -- The proxy only carries the type; its value is never inspected.
    getParams = const (Params 2 2 10 4)
-- | ML-KEM-1024: rank k = 4.
--
-- NOTE(review): the 'Params' arguments (2, 2, 11, 5) look like
-- (eta1, eta2, d_u, d_v). They match FIPS 203 Table 2 for ML-KEM-1024, but the
-- field order should be confirmed against the definition of 'Params' in
-- 'Internal'.
instance ParamSet ML_KEM_1024 where
    type K ML_KEM_1024 = 4
    -- The proxy only carries the type; its value is never inspected.
    getParams = const (Params 2 2 11 5)
-- | Generate a fresh key pair for the parameter set selected by the proxy.
--
-- A single 64-byte 'ScrubbedBytes' seed is drawn from the 'MonadRandom'
-- source. Its first 32 bytes, obtained with 'B.takeView', are passed to
-- 'Internal.keyGen' as @d@. Its last 32 bytes, copied out with 'B.drop' into
-- another 'ScrubbedBytes', are passed as @z@.
generate :: (ParamSet a, MonadRandom m)
         => proxy a -> m (EncapsulationKey a, DecapsulationKey a)
generate p = fmap fromSeed (getRandomBytes 64)
  where
    -- Split the 64-byte seed into the two 32-byte keyGen inputs.
    fromSeed seed =
        let d = B.takeView seed 32
            z = B.drop 32 seed :: ScrubbedBytes
         in Internal.keyGen p d z
-- | Generate a fresh key pair and also return the seed material.
--
-- Two independent 32-byte values are drawn from the 'MonadRandom' source,
-- @d@ first and then @z@, each in a byte-array type chosen by the caller.
-- @z@ is copied into 'ScrubbedBytes' before being passed to
-- 'Internal.keyGen'. Both @d@ and @z@ are returned unchanged next to the keys,
-- so the caller is responsible for protecting them.
generateOpen :: (ParamSet a, ByteArray d, ByteArray z, MonadRandom m)
             => proxy a -> m (EncapsulationKey a, DecapsulationKey a, d, z)
generateOpen p =
    getRandomBytes 32 >>= \d ->
    getRandomBytes 32 >>= \z ->
        let (ek, dk) = Internal.keyGen p d (B.convert z)
         in return (ek, dk, d, z)
-- | Deterministic key generation from caller-supplied seeds.
--
-- Returns 'Nothing' unless both @d@ and @z@ are exactly 32 bytes long.
-- Otherwise @z@ is copied into 'ScrubbedBytes' and the key pair produced by
-- 'Internal.keyGen' is returned in 'Just'.
generateWith :: (ParamSet a, ByteArrayAccess d, ByteArrayAccess z)
             => proxy a -> d -> z -> Maybe (EncapsulationKey a, DecapsulationKey a)
generateWith p d z
    | seedsOk = Just (Internal.keyGen p d (B.convert z))
    | otherwise = Nothing
  where
    seedsOk = B.length d == 32 && B.length z == 32
-- | Encapsulate against @ek@ with a fresh random message.
--
-- A 32-byte message is drawn from the 'MonadRandom' source into
-- 'ScrubbedBytes' and handed to 'Internal.encaps'. The result is the
-- (shared secret, ciphertext) pair that 'Internal.encaps' produces.
encapsulate :: (ParamSet a, MonadRandom m)
            => EncapsulationKey a -> m (SharedSecret a, Ciphertext a)
encapsulate ek = fmap encapsWithMsg (getRandomBytes 32)
  where
    -- The annotation pins the random buffer to 'ScrubbedBytes'.
    encapsWithMsg msg = Internal.encaps ek (msg :: ScrubbedBytes)
-- | Deterministic encapsulation with a caller-supplied message.
--
-- Returns 'Nothing' unless @m@ is exactly 32 bytes long. Otherwise the
-- (shared secret, ciphertext) pair from 'Internal.encaps' is returned in
-- 'Just'.
encapsulateWith :: (ParamSet a, ByteArrayAccess m)
                => EncapsulationKey a -> m -> Maybe (SharedSecret a, Ciphertext a)
encapsulateWith ek m =
    if B.length m == 32
        then Just (Internal.encaps ek m)
        else Nothing
-- | Recover the shared secret for ciphertext @ct@ using decapsulation key
-- @dk@. This is a thin public wrapper over 'Internal.decaps'.
--
-- NOTE(review): how invalid ciphertexts are handled (e.g. FIPS 203 implicit
-- rejection) is decided entirely by 'Internal.decaps'; confirm there.
decapsulate :: ParamSet a => DecapsulationKey a -> Ciphertext a -> SharedSecret a
decapsulate dk ct = Internal.decaps dk ct
-- | Pair-wise consistency test for a key pair.
--
-- Draws a random 32-byte 'ScrubbedBytes' message and encapsulates it under
-- @ek@. The resulting ciphertext is decapsulated with @dk@, and the function
-- reports whether the recovered shared secret equals the encapsulated one,
-- using the '==' instance of 'SharedSecret'.
--
-- NOTE(review): this checks only the encaps/decaps round trip. It does not
-- itself perform the other encapsulation-/decapsulation-key checks listed in
-- FIPS 203 section 7; confirm whether 'Internal' performs them.
checkKeyPair :: (ParamSet a, MonadRandom m)
             => (EncapsulationKey a, DecapsulationKey a) -> m Bool
checkKeyPair (ek, dk) = fmap roundTrips (getRandomBytes 32)
  where
    roundTrips msg =
        let (secret, ct) = Internal.encaps ek (msg :: ScrubbedBytes)
         in Internal.decaps dk ct == secret