{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.TLS.Handshake.Common13 (
    makeFinished,
    checkFinished,
    makeServerKeyShare,
    makeClientKeyShare,
    fromServerKeyShare,
    makeCertVerify,
    checkCertVerify,
    makePSKBinder,
    replacePSKBinder,
    sendChangeCipherSpec13,
    makeCertRequest,
    createTLS13TicketInfo,
    ageToObfuscatedAge,
    isAgeValid,
    getAge,
    checkFreshness,
    getCurrentTimeFromBase,
    getSessionData13,
    isHashSignatureValid13,
    safeNonNegative32,
    RecvHandshake13M,
    runRecvHandshake13,
    recvHandshake13,
    recvHandshake13hash,
    CipherChoice (..),
    makeCipherChoice,
    initEarlySecret,
    calculateEarlySecret,
    calculateHandshakeSecret,
    calculateApplicationSecret,
    calculateResumptionSecret,
    derivePSK,
    checkClientKeyShareKeyLength,
    checkServerKeyShareKeyLength,
    setRTT,
    computeConfirm,
    updateTranscriptHash13,
    setServerHelloParameters13,
    finishHandshake13,
) where

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.ByteArray (convert)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Data.UnixTime
import Foreign.C.Types (CTime (..))
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import qualified Network.TLS.Crypto.IES as IES

import Network.TLS.Compression
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate (extractCAname)
import Network.TLS.Handshake.Common (unexpected)
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Handshake.TranscriptHash
import Network.TLS.IO
import Network.TLS.IO.Encode
import Network.TLS.Imports
import Network.TLS.KeySchedule
import Network.TLS.MAC
import Network.TLS.Packet13
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
import Network.TLS.Wire

----------------------------------------------------------------

-- | Build our own TLS 1.3 @Finished@ message.
--
-- The verify data is computed by 'makeVerifyData' from @baseKey@ over the
-- current transcript hash, recorded in the TLS state as the verify data we
-- sent, and wrapped in 'Finished13'.
makeFinished :: MonadIO m => Context -> Hash -> Secret -> m Handshake13
makeFinished ctx usedHash baseKey = do
    th <- transcriptHash ctx "makeFinished"
    let verifyData = VerifyData $ makeVerifyData usedHash baseKey th
    liftIO . usingState_ ctx $ setVerifyDataForSend verifyData
    return (Finished13 verifyData)

-- | Verify the peer's TLS 1.3 @Finished@ message.
--
-- Recomputes the expected verify data with 'makeVerifyData' from @baseKey@
-- and the supplied transcript hash, then:
--
-- * throws a 'DecodeError' protocol error ("broken Finished") if the
--   received verify data has the wrong length;
-- * fails via 'decryptError' if the contents differ;
-- * otherwise records the received verify data in the TLS state.
--
-- The contents are compared with 'BA.constEq'. Its running time does not
-- depend on how many leading bytes of the MAC happen to match, unlike the
-- short-circuiting 'ByteString' '=='.
checkFinished
    :: MonadIO m
    => Context -> Hash -> Secret -> TranscriptHash -> VerifyData -> m ()
checkFinished ctx usedHash baseKey th vd@(VerifyData verifyData) = do
    let verifyData' = makeVerifyData usedHash baseKey th
    when (B.length verifyData /= B.length verifyData') $
        throwCore $
            Error_Protocol "broken Finished" DecodeError
    -- Constant-time MAC comparison (lengths are already known to match).
    unless (verifyData' `BA.constEq` verifyData) $
        decryptError "finished verification failed"
    liftIO $ usingState_ ctx $ setVerifyDataForRecv vd

-- | Compute TLS 1.3 verify data.
--
-- The result is an HMAC over the transcript hash. It is keyed with the
-- @"finished"@ key, which is HKDF-expanded from @baseKey@ to the digest size
-- of @usedHash@.
makeVerifyData :: Hash -> Secret -> TranscriptHash -> ByteString
makeVerifyData usedHash baseKey (TranscriptHash th) = hmac usedHash key th
  where
    key :: ByteString
    key = hkdfExpandLabel usedHash baseKey "finished" "" (hashDigestSize usedHash)

----------------------------------------------------------------

-- | Generate an ephemeral key pair for @grp@ and advertise its public half
-- in a 'KeyShareEntry'.
--
-- The private half is returned paired with its group. This matches the
-- @[(Group, IES.GroupPrivate)]@ lookup table consumed by
-- 'fromServerKeyShare'.
makeClientKeyShare
    :: Context -> Group -> IO ((Group, IES.GroupPrivate), KeyShareEntry)
makeClientKeyShare ctx grp = do
    (cpri, cpub) <- generateGroup ctx grp
    let entry = KeyShareEntry grp (IES.groupEncodePublicA cpub)
    pure ((grp, cpri), entry)

-- | Answer a client's key share.
--
-- Decodes the client's public value for its group, encapsulates against it,
-- and returns the shared secret together with the server's 'KeyShareEntry'
-- for the same group.
--
-- Throws an 'IllegalParameter' protocol error in two cases: the client's
-- public value cannot be decoded (the message is the shown 'CryptoError'),
-- or encapsulation yields nothing.
makeServerKeyShare :: Context -> KeyShareEntry -> IO (Secret, KeyShareEntry)
makeServerKeyShare ctx (KeyShareEntry grp wcpub) =
    either decodeFailed encapsulate $ IES.groupDecodePublicA grp wcpub
  where
    decodeFailed e = throwCore $ Error_Protocol (show e) IllegalParameter

    encapsulate cpub =
        encapsulateGroup ctx cpub >>= \case
            Nothing ->
                throwCore $ Error_Protocol msgInvalidPublic IllegalParameter
            Just (spub, share) ->
                pure (share, KeyShareEntry grp (IES.groupEncodePublicB spub))

    msgInvalidPublic :: String
    msgInvalidPublic = "invalid client " ++ show grp ++ " public key"

-- | Derive the shared secret from the server's key share.
--
-- Uses whichever of our private keys was generated for the server's group.
--
-- Throws an 'IllegalParameter' protocol error in any of these cases: the
-- server's public value cannot be decoded; we hold no private key for its
-- group; or decapsulation yields nothing.
fromServerKeyShare
    :: KeyShareEntry -> [(Group, IES.GroupPrivate)] -> IO Secret
fromServerKeyShare (KeyShareEntry grp wspub) grpCpris =
    case IES.groupDecodePublicB grp wspub of
        Left e -> throwCore $ Error_Protocol (show e) IllegalParameter
        Right spub ->
            -- Both a missing private key and a failed decapsulation map to
            -- the same error.
            maybe (throwCore err) pure $
                lookup grp grpCpris >>= IES.groupDecapsulate spub
  where
    err :: TLSError
    err =
        Error_Protocol
            "cannot generate a shared secret on (EC)DH"
            IllegalParameter

----------------------------------------------------------------

-- | Context strings mixed into the TLS 1.3 CertificateVerify signature
-- target (see 'makeTarget'), one per signing role.
serverContextString, clientContextString :: ByteString
serverContextString = "TLS 1.3, server CertificateVerify"
clientContextString = "TLS 1.3, client CertificateVerify"

-- | Build our TLS 1.3 @CertificateVerify@ message.
--
-- Signs (via 'sign') the target built by 'makeTarget' from the transcript
-- hash. The context string is chosen by our own role: the client string
-- when we are the client, the server string otherwise.
makeCertVerify
    :: MonadIO m
    => Context
    -> PubKey
    -> HashAndSignatureAlgorithm
    -> TranscriptHash
    -> m Handshake13
makeCertVerify ctx pub hs (TranscriptHash hashValue) = do
    role <- liftIO $ usingState_ ctx getRole
    let ctxStr = case role of
            ClientRole -> clientContextString
            _ -> serverContextString
    signature <- sign ctx pub hs (makeTarget ctxStr hashValue)
    return $ CertVerify13 (DigitallySigned hs signature)

-- | Check the peer's TLS 1.3 @CertificateVerify@ signature over the given
-- transcript hash.
--
-- Returns 'False' immediately when @pub@ is not compatible with @hs@ under
-- 'signatureCompatible13'. Otherwise:
--
-- * the target is built with the context string of the /peer's/ role;
-- * @hs@ is checked with 'checkHashSignatureValid13' and
--   'checkSupportedHashSignature';
-- * the result of 'verifyPublic' is returned.
checkCertVerify
    :: MonadIO m
    => Context
    -> PubKey
    -> HashAndSignatureAlgorithm
    -> Signature
    -> ByteString
    -> m Bool
checkCertVerify ctx pub hs signature hashValue
    | not (pub `signatureCompatible13` hs) = return False
    | otherwise = liftIO $ do
        role <- usingState_ ctx getRole
        -- The signature was produced by the peer, so use the opposite context.
        let peerCtxStr
                | role == ClientRole = serverContextString
                | otherwise = clientContextString
            params = signatureParams pub hs
        checkHashSignatureValid13 hs
        checkSupportedHashSignature ctx hs
        verifyPublic ctx params (makeTarget peerCtxStr hashValue) signature

-- | Assemble the byte string covered by a TLS 1.3 CertificateVerify
-- signature.
--
-- The layout is: 64 space (0x20) bytes, then the context string, then a
-- single zero byte, then the transcript hash.
makeTarget :: ByteString -> ByteString -> ByteString
makeTarget contextString hashValue =
    B.concat
        [ B.replicate 64 0x20
        , contextString
        , B.singleton 0
        , hashValue
        ]

-- | Sign @target@ by delegating to 'signPrivate'.
--
-- 'signPrivate' receives our current role and the signature parameters
-- derived from @pub@ and @hs@ by 'signatureParams'.
sign
    :: MonadIO m
    => Context
    -> PubKey
    -> HashAndSignatureAlgorithm
    -> ByteString
    -> m Signature
sign ctx pub hs target =
    liftIO $
        usingState_ ctx getRole >>= \role ->
            signPrivate ctx role (signatureParams pub hs) target

----------------------------------------------------------------

-- | Compute a PSK binder value.
--
-- The binder key is derived from the early secret with the @"res binder"@
-- label over the hash of the empty string. The binder is the verify data
-- ('makeVerifyData') under that key for the hash of the encoded ClientHello
-- with its last @truncLen@ bytes removed.
makePSKBinder
    :: BaseSecret EarlySecret
    -> Hash
    -> Int
    -> ByteString
    -- ^ Encoded client hello
    -> ByteString
makePSKBinder (BaseSecret sec) usedHash truncLen ech =
    makeVerifyData usedHash binderKey partialHash
  where
    emptyHash :: TranscriptHash
    emptyHash = TranscriptHash $ hash usedHash ""

    binderKey :: Secret
    binderKey = deriveSecret usedHash sec "res binder" emptyHash

    -- Everything but the trailing truncLen bytes of the encoded hello.
    truncated :: ByteString
    truncated = B.take (B.length ech - truncLen) ech

    partialHash :: TranscriptHash
    partialHash = TranscriptHash $ hash usedHash truncated

-- | Replace the trailing binders of an encoded @pre_shared_key@ extension
-- body.
--
-- The new binders are encoded as an opaque16 vector of opaque8 entries (see
-- instance Extension PreSharedKey). Exactly as many trailing bytes of
-- @pskz@ are dropped as the new encoding occupies, so the old binders are
-- assumed to have the same encoded length. The new encoding is then
-- appended.
replacePSKBinder :: ByteString -> [ByteString] -> ByteString
replacePSKBinder pskz bds = identities <> encodedBinders
  where
    encodedBinders :: ByteString
    encodedBinders = runPut . putOpaque16 . runPut $ mapM_ putOpaque8 bds

    identities :: ByteString
    identities = B.take (B.length pskz - B.length encodedBinders) pskz

----------------------------------------------------------------

-- | Queue a TLS 1.3 ChangeCipherSpec record unless one has already been
-- sent.
--
-- The CCS13-sent flag in the handshake state is read and, if unset, set
-- within the same 'usingHState' call. The record is loaded only when the
-- flag was previously unset.
sendChangeCipherSpec13 :: Monoid b => Context -> PacketFlightM b ()
sendChangeCipherSpec13 ctx = do
    alreadySent <-
        usingHState ctx $
            getCCS13Sent >>= \sent -> sent <$ unless sent (setCCS13Sent True)
    unless alreadySent $ loadPacket13 ctx ChangeCipherSpec13

----------------------------------------------------------------

-- | Build a TLS 1.3 CertificateRequest carrying @certReqCtx@.
--
-- Extensions, in this order:
--
-- * signature_algorithms (0x0d): always present, listing our supported
--   hash/signature pairs;
-- * compress_certificate (0x1b): zlib only, present when @zlib@ is set;
-- * certificate_authorities (0x2f): present when the server has CA
--   certificates configured.
makeCertRequest
    :: ServerParams -> Context -> CertReqContext -> Bool -> Handshake13
makeCertRequest sparams ctx certReqCtx zlib =
    CertRequest13 certReqCtx $ catMaybes [signatureAlgExt, compCertExt, caExt]
  where
    signatureAlgExt :: Maybe ExtensionRaw
    signatureAlgExt =
        Just . toExtensionRaw . SignatureAlgorithms $
            supportedHashSignatures (ctxSupported ctx)

    compCertExt :: Maybe ExtensionRaw
    compCertExt
        | zlib = Just . toExtensionRaw $ CompressCertificate [CCA_Zlib]
        | otherwise = Nothing

    caDns :: [DistinguishedName]
    caDns = extractCAname <$> serverCACertificates sparams

    caExt :: Maybe ExtensionRaw
    caExt
        | null caDns = Nothing
        | otherwise = Just . toExtensionRaw $ CertificateAuthorities caDns

----------------------------------------------------------------

-- | Create the bookkeeping record for a TLS 1.3 session ticket.
--
-- @txrxTime@ is set to the current time from 'getCurrentTimeFromBase'. The
-- age-add value depends on the 'Either' argument:
--
-- * 'Left': 4 fresh random bytes from the context's RNG, read big-endian;
-- * 'Right': the supplied value is used as-is.
createTLS13TicketInfo
    :: Second -> Either Context Second -> Maybe Millisecond -> IO TLS13TicketInfo
createTLS13TicketInfo life ecw mrtt = do
    -- Left:  serverSendTime
    -- Right: clientReceiveTime
    now <- getCurrentTimeFromBase
    add <- either freshAgeAdd return ecw
    return
        TLS13TicketInfo
            { lifetime = life
            , ageAdd = add
            , txrxTime = now
            , estimatedRTT = mrtt
            }
  where
    -- Interpret the random bytes as a big-endian unsigned integer.
    freshAgeAdd ctx = B.foldl' pushByte 0 <$> getStateRNG ctx 4
    pushByte acc w = acc * 256 + fromIntegral w

-- | Obfuscate a ticket age by adding the ticket's @ageAdd@.
--
-- The addition uses ordinary 'Second' arithmetic, so it wraps around on
-- overflow.
ageToObfuscatedAge :: Second -> TLS13TicketInfo -> Second
ageToObfuscatedAge age tinfo = age + ageAdd tinfo

-- | Undo 'ageToObfuscatedAge' by subtracting the ticket's @ageAdd@.
obfuscatedAgeToAge :: Second -> TLS13TicketInfo -> Second
obfuscatedAgeToAge obfage tinfo = obfage - ageAdd tinfo

-- | Is a ticket age (in milliseconds) within the ticket's @lifetime@ (in
-- seconds)?
--
-- The comparison is done in 'Integer'. Computing @lifetime * 1000@ in
-- 'Second' itself would wrap around for large lifetimes and could make an
-- expired ticket look valid.
isAgeValid :: Second -> TLS13TicketInfo -> Bool
isAgeValid age TLS13TicketInfo{..} = toInteger age <= toInteger lifetime * 1000

-- | Time elapsed since the ticket was received (@txrxTime@).
--
-- The value is in milliseconds, although it is returned as a 'Second'.
--
-- The wall clock may step backwards, and a very old ticket may have an age
-- that does not fit in 'Second'. The result is therefore clamped to
-- @[0, maxBound]@. Without the clamp, the subtraction or the narrowing
-- 'fromIntegral' would wrap, and a huge age could come out small.
getAge :: TLS13TicketInfo -> IO Second
getAge TLS13TicketInfo{..} = do
    let clientReceiveTime = txrxTime
    clientSendTime <- getCurrentTimeFromBase
    let elapsed
            | clientSendTime <= clientReceiveTime = 0
            | otherwise = clientSendTime - clientReceiveTime
        limit = fromIntegral (maxBound :: Second)
    return $ fromIntegral (min elapsed limit) -- milliseconds

-- | Server-side freshness check for a ticket presented with an obfuscated
-- age.
--
-- The expected arrival time is @txrxTime + rtt + age@, where @age@ is the
-- de-obfuscated age. The ticket is fresh when both hold:
--
-- * the actual receive time lies within a tolerance of that expected time,
--   where the tolerance is the larger of 2000 ms and the RTT;
-- * the age passes 'isAgeValid'.
--
-- If no RTT estimate was recorded, the ticket is reported as not fresh. The
-- previous 'fromJust' crashed in that case.
checkFreshness :: TLS13TicketInfo -> Second -> IO Bool
checkFreshness tinfo@TLS13TicketInfo{..} obfAge = case estimatedRTT of
    Nothing -> return False
    Just rtt -> do
        serverReceiveTime <- getCurrentTimeFromBase
        let expectedArrivalTime = serverSendTime + rtt + fromIntegral age
            freshness
                | expectedArrivalTime > serverReceiveTime =
                    expectedArrivalTime - serverReceiveTime
                | otherwise = serverReceiveTime - expectedArrivalTime
            -- Some implementations round age up to second.
            -- We take max of 2000 and rtt in the case where rtt is too small.
            tolerance = max 2000 rtt
            isFresh = freshness < tolerance
        return $ isAlive && isFresh
  where
    serverSendTime = txrxTime
    age = obfuscatedAgeToAge obfAge tinfo
    isAlive = isAgeValid age tinfo

-- | Current wall-clock time in milliseconds since 2017-01-01 00:00:00 UTC,
-- as computed by 'millisecondsFromBase'.
getCurrentTimeFromBase :: IO Millisecond
getCurrentTimeFromBase = fmap millisecondsFromBase getUnixTime

-- | Convert a 'UnixTime' to milliseconds elapsed since the base epoch.
--
-- The base is 2017-01-01 00:00:00 UTC, i.e. Unix time 1483228800, which is
-- what @parseUnixTimeGMT webDateFormat "Sun, 01 Jan 2017 00:00:00 GMT"@
-- yields. The microsecond part is truncated to whole milliseconds.
--
-- NOTE(review): a time before the base epoch gives a negative intermediate
-- that 'fromIntegral' converts to 'Millisecond'. Presumably this never
-- happens in practice; confirm it is acceptable if 'Millisecond' is unsigned.
millisecondsFromBase :: UnixTime -> Millisecond
millisecondsFromBase (UnixTime (CTime s) us) = secPart + usPart
  where
    base = 1483228800
    secPart = fromIntegral ((s - base) * 1000)
    usPart = fromIntegral (us `div` 1000)

----------------------------------------------------------------

-- | Snapshot the negotiated session parameters into a 'SessionData' for a
-- TLS 1.3 ticket.
--
-- The version, ALPN protocol and client SNI come from the TLS state, and
-- the group from the handshake state. Compression is fixed to 0, no flags
-- are set, and @tinfo@ is stored as the ticket info.
getSessionData13
    :: Context -> Cipher -> TLS13TicketInfo -> Int -> ByteString -> IO SessionData
getSessionData13 ctx usedCipher tinfo maxSize psk = do
    let st = usingState_ ctx
    ver <- st getVersion
    malpn <- st getNegotiatedProtocol
    sni <- st getClientSNI
    mgrp <- usingHState ctx getSupportedGroup
    pure
        SessionData
            { sessionVersion = ver
            , sessionCipher = cipherID usedCipher
            , sessionCompression = 0
            , sessionClientSNI = sni
            , sessionSecret = psk
            , sessionGroup = mgrp
            , sessionTicketInfo = Just tinfo
            , sessionALPN = malpn
            , sessionMaxEarlyDataSize = maxSize
            , sessionFlags = []
            }

----------------------------------------------------------------

-- | Clamp a value into the range of a non-negative 32-bit word.
--
-- Values @<= 0@ become 0. For types wider than 32 bits, values are capped at
-- @maxBound :: Word32@. Types of 32 bits or fewer are otherwise returned
-- unchanged.
--
-- Word32 is used in TLS 1.3 protocol.
-- Int is used for API for Haskell TLS because it is natural.
-- If Int is 64 bits, users can specify bigger number than Word32.
-- If Int is 32 bits, 2^31 or larger may be converted into minus numbers.
safeNonNegative32 :: (Num a, Ord a, FiniteBits a) => a -> a
safeNonNegative32 x
    | x <= 0 = 0
    | finiteBitSize x > 32 = min x upper
    | otherwise = x
  where
    upper = fromIntegral (maxBound :: Word32)

----------------------------------------------------------------

-- | Monad for receiving TLS 1.3 handshake messages.
--
-- The 'StateT' state is a buffer of 'Handshake13R' entries: handshake
-- messages, each paired with the byte chunks 'recvPacket13' returned for
-- it, that arrived in an already-received packet but have not been handed
-- out yet.  All instances come from the underlying 'StateT' through
-- GeneralizedNewtypeDeriving.
newtype RecvHandshake13M m a
    = RecvHandshake13M (StateT [Handshake13R] m a)
    deriving (Functor, Applicative, Monad, MonadIO)

-- | Receive the next TLS 1.3 handshake message and pass it to the
-- continuation.  The raw byte chunks that come with the message are
-- discarded here.
recvHandshake13
    :: MonadIO m
    => Context
    -> (Handshake13 -> RecvHandshake13M m a)
    -> RecvHandshake13M m a
recvHandshake13 ctx continue = do
    (msg, _chunks) <- getHandshake13 ctx
    continue msg

-- | Like 'recvHandshake13', but first snapshots the transcript hash under
-- the given label.  The snapshot is taken before the next message is
-- received (and thus before that message is hashed in).  Both the
-- snapshot and the message are passed to the continuation.
recvHandshake13hash
    :: MonadIO m
    => Context
    -> String
    -> (TranscriptHash -> Handshake13 -> RecvHandshake13M m a)
    -> RecvHandshake13M m a
recvHandshake13hash ctx label continue = do
    hashBefore <- transcriptHash ctx label
    (msg, _chunks) <- getHandshake13 ctx
    continue hashBefore msg

-- | Fetch the next TLS 1.3 handshake message together with its byte chunks.
--
-- Messages left over from a previously received packet (buffered in the
-- 'RecvHandshake13M' state) are handed out first.  Otherwise a new packet
-- is read with 'recvPacket13'.  Whatever is returned is first fed into the
-- transcript with 'updateTranscriptHash13'.
--
-- Packet handling while receiving:
--
-- * the first ChangeCipherSpec record is skipped (and recorded with
--   'setCCS13Recv'); a second one is an @UnexpectedMessage@ protocol error;
-- * an alert raises 'Error_TCP_Terminate';
-- * any other packet, including a handshake packet whose message and chunk
--   lists do not line up, is reported via 'unexpected';
-- * a receive error is rethrown with 'throwCore'.
getHandshake13
    :: MonadIO m => Context -> RecvHandshake13M m Handshake13R
getHandshake13 ctx = RecvHandshake13M $ do
    pending <- get
    case pending of
        hb : rest -> consume hb rest
        [] -> recvLoop
  where
    -- Hash the message into the transcript, keep the remaining messages
    -- buffered for later calls, and hand the current one back.
    consume hb rest = do
        liftIO $ updateTranscriptHash13 ctx hb
        put rest
        return hb

    recvLoop = do
        epkt <- liftIO $ recvPacket13 ctx
        case epkt of
            Right (Handshake13 [] _) -> error "invalid recvPacket13 result"
            Right (Handshake13 (h : hs) (b : bs)) -> consume (h, b) (zip hs bs)
            Right ChangeCipherSpec13 -> do
                alreadyReceived <- liftIO $ usingHState ctx getCCS13Recv
                if alreadyReceived
                    then
                        throwCore $
                            Error_Protocol "multiple CCS in TLS 1.3" UnexpectedMessage
                    else do
                        liftIO $ usingHState ctx $ setCCS13Recv True
                        recvLoop
            Right (Alert13 _) -> throwCore Error_TCP_Terminate
            Right x -> unexpected (show x) (Just "handshake 13")
            Left err -> throwCore err

-- | Run a 'RecvHandshake13M' computation, starting with an empty message
-- buffer.
--
-- If handshake messages are still buffered when the computation finishes,
-- the peer sent more than was expected, and 'unexpected' is raised with
-- "spurious handshake 13".
runRecvHandshake13 :: MonadIO m => RecvHandshake13M m a -> m a
runRecvHandshake13 (RecvHandshake13M action) = do
    (result, leftover) <- runStateT action []
    if null leftover
        then return result
        else unexpected "spurious handshake 13" Nothing

----------------------------------------------------------------

-- | Reject hash/signature combinations that are not allowed in TLS 1.3.
--
-- Some combinations have been deprecated in TLS 1.3 and should not be
-- used.  An algorithm that 'isHashSignatureValid13' rejects raises an
-- @IllegalParameter@ protocol error naming the offending algorithm.
checkHashSignatureValid13 :: HashAndSignatureAlgorithm -> IO ()
checkHashSignatureValid13 hs
    | isHashSignatureValid13 hs = return ()
    | otherwise = throwCore $ Error_Protocol reason IllegalParameter
  where
    reason = "invalid TLS13 hash and signature algorithm: " ++ show hs

-- | Whether the algorithm is one of 'signatureSchemesForTLS13'.
isHashSignatureValid13 :: HashAndSignatureAlgorithm -> Bool
isHashSignatureValid13 hs = elem hs signatureSchemesForTLS13

{-
isHashSignatureValid13 (HashIntrinsic, s) =
    s
        `elem` [ SignatureRSApssRSAeSHA256
               , SignatureRSApssRSAeSHA384
               , SignatureRSApssRSAeSHA512
               , SignatureEd25519
               , SignatureEd448
               , SignatureRSApsspssSHA256
               , SignatureRSApsspssSHA384
               , SignatureRSApsspssSHA512
               ]
isHashSignatureValid13 (h, SignatureECDSA) =
    h `elem` [HashSHA256, HashSHA384, HashSHA512]
isHashSignatureValid13 _ = False
-}

----------------------------------------------------------------

-- | Compute the early secret and the client early traffic secret.
--
-- With @Left psk@, the early secret is @HKDF-Extract(zero, psk)@.  With
-- @Right@, the given early secret is reused.  The client early traffic
-- secret is @Derive-Secret(early, "c e traffic", H(ClientHello))@.  The
-- ClientHello hash is computed over the chunks stored by
-- 'getClientHello'.  The traffic secret is passed to 'logKey'.
--
-- A ClientHello must already be recorded in the handshake state;
-- otherwise 'fromJust' fails.
calculateEarlySecret
    :: Context
    -> CipherChoice
    -> Either ByteString (BaseSecret EarlySecret)
    -> IO (SecretPair EarlySecret)
calculateEarlySecret ctx choice maux = do
    (_ch, chChunks) <- fromJust <$> usingHState ctx getClientHello
    let chHash = TranscriptHash $ hashChunks usedHash chChunks
        earlySecret = either fromPSK (\(BaseSecret s) -> s) maux
        cets :: ClientTrafficSecret EarlySecret
        cets =
            ClientTrafficSecret $
                deriveSecret usedHash earlySecret "c e traffic" chHash
    logKey ctx cets
    return $ SecretPair (BaseSecret earlySecret) cets
  where
    usedHash = cHash choice
    fromPSK psk = hkdfExtract usedHash (cZero choice) (convert psk)

-- | Compute the early secret without touching any context.
--
-- The result is @HKDF-Extract(zero, psk)@, where the PSK defaults to the
-- cipher's zero value ('cZero') when none is supplied.
initEarlySecret :: CipherChoice -> Maybe ByteString -> BaseSecret EarlySecret
initEarlySecret choice mpsk =
    BaseSecret $ hkdfExtract usedHash zero ikm
  where
    usedHash = cHash choice
    zero = cZero choice
    ikm = maybe zero convert mpsk

-- | Compute the handshake secret and both handshake traffic secrets.
--
-- The handshake secret is
-- @HKDF-Extract(Derive-Secret(early, "derived", H("")), ecdhe)@.  The
-- client and server traffic secrets are derived from it with the labels
-- "c hs traffic" and "s hs traffic" over the "CH..SH" transcript hash.
-- Both traffic secrets are passed to 'logKey' (server first).
calculateHandshakeSecret
    :: Context
    -> CipherChoice
    -> BaseSecret EarlySecret
    -> Secret
    -> IO (SecretTriple HandshakeSecret)
calculateHandshakeSecret ctx choice (BaseSecret earlySecret) ecdhe = do
    hChSh <- transcriptHash ctx "CH..SH"
    let emptyHash = TranscriptHash $ hash usedHash ""
        derived = deriveSecret usedHash earlySecret "derived" emptyHash
        handshakeSecret = hkdfExtract usedHash derived ecdhe
        expand lbl = deriveSecret usedHash handshakeSecret lbl hChSh
        chts =
            ClientTrafficSecret (expand "c hs traffic")
                :: ClientTrafficSecret HandshakeSecret
        shts =
            ServerTrafficSecret (expand "s hs traffic")
                :: ServerTrafficSecret HandshakeSecret
    logKey ctx shts
    logKey ctx chts
    return $ SecretTriple (BaseSecret handshakeSecret) chts shts
  where
    usedHash = cHash choice

-- | Compute the application (main) secret and the first application
-- traffic secrets.
--
-- The application secret is
-- @HKDF-Extract(Derive-Secret(handshake, "derived", H("")), zero)@.
-- The "c ap traffic", "s ap traffic" and "exp master" secrets are derived
-- from it over the supplied transcript hash.  The exporter secret is
-- stored with 'setTLS13ExporterSecret'.  The server and then the client
-- traffic secret are passed to 'logKey'.
calculateApplicationSecret
    :: Context
    -> CipherChoice
    -> BaseSecret HandshakeSecret
    -> TranscriptHash
    -> IO (SecretTriple ApplicationSecret)
calculateApplicationSecret ctx choice (BaseSecret hsSecret) hChSf = do
    let emptyHash = TranscriptHash $ hash usedHash ""
        derived = deriveSecret usedHash hsSecret "derived" emptyHash
        applicationSecret = hkdfExtract usedHash derived (cZero choice)
        expand lbl = deriveSecret usedHash applicationSecret lbl hChSf
        cts0 =
            ClientTrafficSecret (expand "c ap traffic")
                :: ClientTrafficSecret ApplicationSecret
        sts0 =
            ServerTrafficSecret (expand "s ap traffic")
                :: ServerTrafficSecret ApplicationSecret
    usingState_ ctx $ setTLS13ExporterSecret (expand "exp master")
    logKey ctx sts0
    logKey ctx cts0
    return $ SecretTriple (BaseSecret applicationSecret) cts0 sts0
  where
    usedHash = cHash choice

-- | Compute the resumption master secret:
-- @Derive-Secret(application, "res master", H(CH..CF))@.
calculateResumptionSecret
    :: Context
    -> CipherChoice
    -> BaseSecret ApplicationSecret
    -> IO (BaseSecret ResumptionSecret)
calculateResumptionSecret ctx choice (BaseSecret appSecret) =
    toResumption <$> transcriptHash ctx "CH..CF"
  where
    toResumption = BaseSecret . deriveSecret (cHash choice) appSecret "res master"

-- | Derive a ticket PSK from the resumption secret and ticket nonce.
--
-- Computes @HKDF-Expand-Label(resumption, "resumption", nonce, Hash.length)@.
derivePSK
    :: CipherChoice -> BaseSecret ResumptionSecret -> TicketNonce -> ByteString
derivePSK choice (BaseSecret sec) (TicketNonce nonce) =
    let h = cHash choice
     in hkdfExpandLabel h sec "resumption" nonce (hashDigestSize h)

----------------------------------------------------------------

-- | Whether a client key share has the length expected for its group
-- (see 'clientKeyShareKeyLength', which errors on unknown groups).
checkClientKeyShareKeyLength :: KeyShareEntry -> Bool
checkClientKeyShareKeyLength ks =
    clientKeyShareKeyLength (keyShareEntryGroup ks)
        == B.length (keyShareEntryKeyExchange ks)

{- FOURMOLU_DISABLE -}
-- | Expected length in bytes of a ClientHello key_share for each group.
--
-- Hybrid groups are the sum of their two components, e.g.
-- X25519MLKEM768 = 1184 + 32.  Groups not listed here hit 'error'.
clientKeyShareKeyLength :: Group -> Int
clientKeyShareKeyLength = \case
    P256           -> 65    -- 32 * 2 + 1
    P384           -> 97    -- 48 * 2 + 1
    P521           -> 133   -- 66 * 2 + 1
    X25519         -> 32
    X448           -> 56
    FFDHE2048      -> 256
    FFDHE3072      -> 384
    FFDHE4096      -> 512
    FFDHE6144      -> 768
    FFDHE8192      -> 1024
    MLKEM512       -> 800
    MLKEM768       -> 1184
    MLKEM1024      -> 1568
    X25519MLKEM768 -> 1216  -- 1184 + 32
    P256MLKEM768   -> 1249  -- 1184 + 65
    P384MLKEM1024  -> 1665  -- 1568 + 97
    _              -> error "clientKeyShareKeyLength"
{- FOURMOLU_ENABLE -}

-- | Whether a server key share has the length expected for its group
-- (see 'serverKeyShareKeyLength', which errors on unknown groups).
checkServerKeyShareKeyLength :: KeyShareEntry -> Bool
checkServerKeyShareKeyLength ks =
    serverKeyShareKeyLength (keyShareEntryGroup ks)
        == B.length (keyShareEntryKeyExchange ks)

{- FOURMOLU_DISABLE -}
-- | Expected length in bytes of a ServerHello key_share for each group.
--
-- For the ML-KEM based groups these differ from the client sizes, since
-- the server sends a different value than the client.  Hybrid groups are
-- the sum of their two components, e.g. X25519MLKEM768 = 1088 + 32.
-- Groups not listed here hit 'error', whose message names this function.
serverKeyShareKeyLength :: Group -> Int
serverKeyShareKeyLength = \case
    P256           -> 65    -- 32 * 2 + 1
    P384           -> 97    -- 48 * 2 + 1
    P521           -> 133   -- 66 * 2 + 1
    X25519         -> 32
    X448           -> 56
    FFDHE2048      -> 256
    FFDHE3072      -> 384
    FFDHE4096      -> 512
    FFDHE6144      -> 768
    FFDHE8192      -> 1024
    MLKEM512       -> 768
    MLKEM768       -> 1088
    MLKEM1024      -> 1568
    X25519MLKEM768 -> 1120  -- 1088 + 32
    P256MLKEM768   -> 1153  -- 1088 + 65
    P384MLKEM1024  -> 1665  -- 1568 + 97
    _              -> error "serverKeyShareKeyLength"
{- FOURMOLU_ENABLE -}

-- | Record the round-trip time between sending the ClientHello (at
-- @chSentTime@, a 'getCurrentTimeFromBase' reading) and now, in
-- @tls13stRTT@.
--
-- A zero measurement falls back to 10 ms.  A negative one also falls back
-- to 10 ms, which can happen if the clock steps backwards between the two
-- readings.  Without this, an unsigned 'Millisecond' subtraction would
-- wrap around to an enormous RTT.
setRTT :: Context -> Millisecond -> IO ()
setRTT ctx chSentTime = do
    shRecvTime <- getCurrentTimeFromBase
    let rtt
            | shRecvTime > chSentTime = shRecvTime - chSentTime
            | otherwise = 10
    modifyTLS13State ctx $ \st -> st{tls13stRTT = rtt}

-- | Compute the 8-byte ECH acceptance confirmation for the given label.
--
-- The PRK is @HKDF-Extract("", ClientHello.random)@, using the random of
-- the ClientHello stored in the handshake state.  It is expanded with
-- @HKDF-Expand-Label(prk, label, H_ech, 8)@, where @H_ech@ is the
-- transcript hash ("ECH acceptance") that includes the encoded ServerHello.
--
-- A ClientHello must already be recorded; otherwise 'fromJust' fails.
computeConfirm
    :: (MonadFail m, MonadIO m)
    => Context -> Hash -> ServerHello -> ByteString -> m ByteString
computeConfirm ctx usedHash sh label = do
    (CH{chRandom = clientRandom}, _chunks) <-
        fromJust <$> liftIO (usingHState ctx getClientHello)
    let shBytes = encodeHandshake13 (ServerHello13 sh)
    TranscriptHash echConf <- transcriptHashWith ctx "ECH acceptance" shBytes
    let prk = hkdfExtract usedHash "" (unClientRandom clientRandom)
    return $ hkdfExpandLabel usedHash (convert prk) label echConf 8

----------------------------------------------------------------

-- | Switch the transcript hash to the negotiated cipher's hash and record
-- the cipher as pending.
--
-- The first call stores the cipher (with null compression).  A later call,
-- e.g. after a HelloRetryRequest, succeeds only if the cipher is the same.
-- A different cipher yields @Left@ with an @IllegalParameter@ protocol
-- error instead of throwing.
setServerHelloParameters13
    :: Context -> Cipher -> Bool -> IO (Either TLSError ())
setServerHelloParameters13 ctx cipher isHRR = do
    transitTranscriptHash ctx "transit" (cipherHash cipher) isHRR
    usingHState ctx $ do
        hst <- get
        case hstPendingCipher hst of
            Just oldcipher
                | cipher == oldcipher -> return (Right ())
                | otherwise -> return (Left cipherChanged)
            Nothing -> do
                put
                    hst
                        { hstPendingCipher = Just cipher
                        , hstPendingCompression = nullCompression
                        }
                return (Right ())
  where
    cipherChanged =
        Error_Protocol "TLS 1.3 cipher changed after hello retry" IllegalParameter

-- | TLS13 handshake wrap up & clean up.
--
-- Contrary to @finishHandshake12@, this does not handle the session, which
-- is managed separately for TLS 1.3.  Byte counters are not reset because
-- renegotiation is not allowed.
--
-- The handshake state is replaced by a fresh one built with
-- 'newEmptyHandshake' from the previous client version and client random.
-- A few attributes are carried over, because TLS13 handshake modes, session
-- tickets and post-handshake authentication still need them:
-- server random, main secret, supported group, transcript hash state,
-- handshake mode, RTT0 status, resumption secret and ECH acceptance.
-- If no handshake state is present, it stays absent.
--
-- The TLS 1.3 key share and pre-shared key held in the TLS state are then
-- cleared, and the context is marked 'Established'.
finishHandshake13 :: Context -> IO ()
finishHandshake13 ctx = do
    -- forget most handshake data
    modifyMVar_ (ctxHandshakeState ctx) (return . fmap retain)
    -- forget handshake data stored in TLS state
    usingState_ ctx $ do
        setTLS13KeyShare Nothing
        setTLS13PreSharedKey Nothing
    -- mark the secure connection up and running.
    setEstablished ctx Established
  where
    -- Start from an empty handshake state and copy back only the fields
    -- that remain relevant after the TLS 1.3 handshake has completed.
    retain hshake =
        (newEmptyHandshake (hstClientVersion hshake) (hstClientRandom hshake))
            { hstServerRandom = hstServerRandom hshake
            , hstMainSecret = hstMainSecret hshake
            , hstSupportedGroup = hstSupportedGroup hshake
            , hstTransHashState = hstTransHashState hshake
            , hstTLS13HandshakeMode = hstTLS13HandshakeMode hshake
            , hstTLS13RTT0Status = hstTLS13RTT0Status hshake
            , hstTLS13ResumptionSecret = hstTLS13ResumptionSecret hshake
            , hstTLS13ECHAccepted = hstTLS13ECHAccepted hshake
            }