-- |
-- Module      : Matrix
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A matrix here is simply a vector of vectors.  The module also implements
-- two utility functions 'mulw' and 'muly' that multiply a matrix and a vector.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Matrix
    ( create, mulw, muly
#ifdef ML_KEM_TESTING
    , transpose
#endif
    ) where

import Base
import Math
import Vector (Vector)
import qualified Vector

-- | Build a matrix (a vector of @m@ inner vectors of @n@ elements each) from a
-- generator function.
--
-- The generator receives the inner offset first and the outer offset second.
-- The element at inner offset @i@ of the inner vector at outer offset @j@ is
-- @f i j@. Reading the outer vector as rows, @f i j@ is the entry in row @j@,
-- column @i@.
create :: (KnownNat m, KnownNat n) => (Offset ty -> Offset (Vector n ty) -> ty) -> Vector m (Vector n ty)
create f = Vector.create $ \j -> Vector.create (`f` j)
{-# INLINE create #-}

-- | Read a single matrix entry.
--
-- @index a i j@ is element @i@ of the inner vector at outer offset @j@, i.e.
-- row @j@, column @i@. The argument order is inner offset first, then outer
-- offset, the same order in which 'create' passes offsets to its generator.
index :: Vector m (Vector n ty) -> Offset ty -> Offset (Vector n ty) -> ty
index a i j = Vector.index (Vector.index a j) i
{-# INLINE index #-}

-- | @mulw a u b@ builds a vector of @n@ elements.
--
-- Element @i@ of the result is computed with 'Vector.biMulFoldIndexWith'
-- over @u@:
--
-- * the fold starts from element @i@ of @b@;
-- * each element @vu@ of @u@ is paired with the matrix entry in column @i@
--   of the row selected by the fold's offset.
--
-- NOTE(review): assuming 'BiMulAdd' accumulates @acc + x * y@, this is the
-- transposed-matrix/vector product plus @b@. Confirm against the 'BiMulAdd'
-- instances in "Math".
--
-- The vectors @u@ and @b@ are forced to WHNF up front (bang patterns).
mulw :: (KnownNat n, BiMulAdd b a) => Vector m (Vector n b) -> Vector m a -> Vector n a -> Vector n a
mulw a !u !b = Vector.create $ \(Offset i) ->
    Vector.biMulFoldIndexWith
        -- The callback's offset is presumably the position of @vu@ within @u@
        -- (i.e. the row). The raw Ints are rewrapped at the offset types that
        -- 'index' expects for the column and row respectively.
        (\(Offset j) vu -> (index a (Offset i) (Offset j), vu))
        (Vector.index b (Offset i))
        u
{-# INLINE mulw #-}

-- | @muly a u@ takes the 'Vector.dot' product of every inner vector (row) of
-- @a@ with @u@. The result therefore has one element per row of @a@.
--
-- When 'Vector.dot' is the usual sum of element-wise products, this is the
-- ordinary matrix/vector product @a · u@. The vector @u@ is forced to WHNF
-- once, before mapping over the rows.
muly :: BiMulAdd b a => Vector m (Vector n b) -> Vector n a -> Vector m a
muly a !u = fmap (`Vector.dot` u) a
{-# INLINE muly #-}

#ifdef ML_KEM_TESTING
-- | Transpose a matrix: entry (row @r@, column @c@) of the result is entry
-- (row @c@, column @r@) of the input.
--
-- The offsets are unwrapped to raw Ints and rewrapped because their phantom
-- types differ between the input and output shapes.
transpose :: (KnownNat m, KnownNat n) => Vector m (Vector n ty) -> Vector n (Vector m ty)
transpose a = create swapped
  where
    -- 'create' supplies the result's column offset first, then its row.
    swapped (Offset col) (Offset row) = index a (Offset row) (Offset col)
#endif