-- |
-- Module      : Fusion
-- License     : BSD-3-Clause
-- Copyright   : (c) 2026 Olivier Chéron
--
-- Infrastructure to decrease intermediate allocations and prefer in-place
-- mutation when possible
--
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Fusion
    ( Fusion(..), MapF(..)
    , Context, runContext, newContext, thawContext, mapContext, modifyContext
    , foldContext, seqContext
    ) where

import Control.Monad ( forM_, (>=>) )
import Control.Monad.ST

-- | Class of values that can be mutated in the 'ST' monad.
--
-- An instance associates the immutable type @a@ with a mutable counterpart
-- @'Mut' a s@ living in state thread @s@, and provides the conversions
-- between the two.
class Fusion a where
    -- | Mutable counterpart of @a@ in state thread @s@.  The family is
    -- injective: the mutable type determines @a@.
    type Mut a s = mut | mut -> a
    -- | Create a mutable value without any input (initial content is
    -- defined by the instance).
    newF :: ST s (Mut a s)
    -- | Obtain a mutable value from an immutable one.
    thawF :: a -> ST s (Mut a s)
    -- | Turn a mutable value back into an immutable one.
    -- NOTE(review): the @unsafe@ prefix suggests the mutable value must not
    -- be used after this call -- confirm against the instances.
    unsafeFreezeF :: Mut a s -> ST s a

-- | A transformation step in the fusion pipeline, with two implementations
-- provided: one that operates on an existing mutation context, and one that
-- initiates a new context from the input.
data MapF a b = MapF
    { mapUpdate :: forall s. Mut a s -> ST s (Mut b s)
      -- ^ Run the step on an input that is already mutable.
    , mapInit :: forall s. a -> ST s (Mut b s)
      -- ^ Run the step starting from an immutable input.
    }

-- MapF is almost a category except for the 'Fusion' constraint on objects
--
-- idMapF :: Fusion a => MapF a a
-- idMapF = MapF { mapUpdate = pure, mapInit = thawF }

-- | Compose two steps, right to left: @composeMapF m2 m1@ runs @m1@ first and
-- feeds its mutable result to @m2@.
--
-- In both fields the second step goes through 'mapUpdate', because its input
-- is the mutable output of the first step.  Only the first step differs: it
-- uses 'mapUpdate' when the pipeline input is mutable and 'mapInit' when the
-- pipeline input is immutable.
composeMapF :: MapF b c -> MapF a b -> MapF a c
composeMapF m2 m1 = MapF
    { mapUpdate = mapUpdate m1 >=> mapUpdate m2
    , mapInit = mapInit m1 >=> mapUpdate m2
    }

-- | Fusion context: an 'ST' action, polymorphic in the state thread @s@, that
-- produces the mutable value @'Mut' a s@.  The module exports the type
-- without its constructor, so contexts are built and consumed through the
-- functions defined here.
newtype Context a = Context (forall s. ST s (Mut a s))

-- | Context whose mutable value is created with 'newF', i.e. without any
-- immutable input.
newContext :: Fusion a => Context a
newContext = Context newF

-- | Context whose mutable value is obtained from the immutable value @a@
-- with 'thawF'.
--
-- Inlining is delayed until phase 0 so that the rewrite rules of this module
-- that mention 'thawContext' can match in the earlier phases.
thawContext :: Fusion a => a -> Context a
thawContext a = Context $ thawF a
{-# INLINE [0] thawContext #-}

-- | Extend a context with an in-place action: @f@ is run on the mutable
-- value for its effects only, and the same mutable value is passed on.
modifyContext :: (forall s. Mut a s -> ST s ()) -> Context a -> Context a
modifyContext f = bindContext $ \ma -> f ma >> return ma

-- | Extend a context with a transformation step, using the 'mapUpdate'
-- implementation of the step on the mutable value of the context.
--
-- Inlining is delayed until phase 0 so that the rewrite rules of this module
-- that mention 'mapContext' can match in the earlier phases.
mapContext :: MapF a b -> Context a -> Context b
mapContext m = bindContext (mapUpdate m)
{-# INLINE [0] mapContext #-}

-- | Start a context directly from the immutable value @a@, using the
-- 'mapInit' implementation of the step.  Not exported: it is the right-hand
-- side of the "mapContext/thawContext" rewrite rule.
initContext :: MapF a b -> a -> Context b
initContext m a = Context $ mapInit m a

-- | Sequence an 'ST' continuation after a context: the action of the context
-- runs first and its mutable value is passed to @f@, whose result becomes
-- the mutable value of the new context.
bindContext :: (forall s. Mut a s -> ST s (Mut b s)) -> Context a -> Context b
bindContext f (Context ctx) = Context $ ctx >>= f

-- | Extend a context by running the in-place action @f@ once for each
-- element of @bs@, in traversal order, always on the same mutable value.
foldContext :: Foldable t => (forall s. b -> Mut a s -> ST s ()) -> Context a -> t b -> Context a
foldContext f c bs = modifyContext (\ma -> forM_ bs $ \b -> f b ma) c

-- | Run a context to completion: execute its 'ST' action with 'runST' and
-- convert the resulting mutable value to an immutable one with
-- 'unsafeFreezeF'.
--
-- Inlining is delayed until phase 0 so that the "thawContext/runContext"
-- rewrite rule can match in the earlier phases.
runContext :: Fusion a => Context a -> a
runContext (Context ctx) = runST (ctx >>= unsafeFreezeF)
{-# INLINE [0] runContext #-}

-- | 'seq' specialised to contexts: evaluate the first argument to weak head
-- normal form before returning the context.
--
-- It is a separate name with delayed inlining so that the
-- "mapContext/seqContext" rewrite rule can match on it.
seqContext :: a -> Context b -> Context b
seqContext = seq
{-# INLINE [0] seqContext #-}


-- Fusion rules
--
-- Phase control: 'thawContext', 'mapContext', 'runContext' and 'seqContext'
-- are all marked INLINE [0], so they are not inlined before simplifier
-- phase 0 and remain visible to the rules below in the earlier phases.  A
-- rule tagged [~0] is active only before phase 0; the rule tagged [1] is
-- active from phase 1 onwards.
--
-- "thawContext/runContext" is the canonical optimization that eliminates an
-- allocation + value copy.  Instead, it sequences two transformations on the
-- same mutation context.
--
-- "mapContext/seqContext" moves strictness annotations upstream so that they
-- do not prevent other rules from firing.
--
-- "mapContext/mapContext" is not strictly needed: the function is ultimately
-- inlined to the same code.  But we keep it so that simplifications fire early
-- and do not wait for the final phase.
--
-- "mapContext/thawContext" is the rule that invokes mapInit instead of copying
-- the input and calling mapUpdate.

{-# RULES
"thawContext/runContext" [~0] forall c. thawContext (runContext c) = c
"mapContext/seqContext" [~0] forall a m c. mapContext m (seqContext a c) = seqContext a (mapContext m c)
"mapContext/mapContext" [~0] forall m1 m2 c. mapContext m2 (mapContext m1 c) = mapContext (composeMapF m2 m1) c
"mapContext/thawContext" [1] forall m a. mapContext m (thawContext a) = initContext m a
  #-}