{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, Trustworthy #-}

module Text.EditDistance.SquareSTUArray (
        levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
    ) where

import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities

import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST


-- | Levenshtein edit distance between two strings under the given 'EditCosts'.
--
-- Measures both strings with 'length' and delegates to
-- 'levenshteinDistanceWithLengths'. @costs@ is forced before any work is done.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 = levenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len = length str1
    str2_len = length str2

-- | Levenshtein edit distance for callers that already know the string lengths.
--
-- Argument order is: costs, length of the first string, length of the second
-- string, first string, second string. The costs and both lengths are forced
-- before the 'ST' computation is run.
--
-- NOTE(review): the lengths are used to size the working arrays and the array
-- accessors have @unsafe@ in their names, so the lengths are presumably
-- required to equal @length str1@ and @length str2@ exactly -- confirm against
-- "Text.EditDistance.ArrayUtilities".
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 =
    runST (levenshteinDistanceST costs str1_len str2_len str1 str2)

-- | 'ST' worker for the Levenshtein distance.
--
-- Fills a full @(str1_len + 1) x (str2_len + 1)@ cost table indexed by
-- @(i, j)@, where @i@ counts characters of @str1@ (columns) and @j@ counts
-- characters of @str2@ (rows), and returns the entry at
-- @(str1_len, str2_len)@.
--
-- NOTE(review): the string arrays are indexed from 1 here (see @read_str2 j@
-- with @j@ starting at 1); this presumably matches the bounds chosen by
-- 'stringToArray' -- confirm.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs. Say we index it by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array
    read_cost <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- The origin is read by 'standardCosts' (as the diagonal neighbour of
    -- (1, 1)) and is the result for two empty strings. 'newArray_' leaves the
    -- contents unspecified, so seed it explicitly instead of relying on the
    -- allocator happening to zero the array.
    write_cost (0, 0) 0

    -- Fill out the first row (j = 0): entry (i, 0) is the accumulated cost of
    -- deleting the first i characters of str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        write_cost (i, 0) deletion_cost'
        return (i + 1, deletion_cost')

    -- Fill the remaining rows (j >= 1), threading the accumulated insertion
    -- cost of the first j characters of str2 through the fold.
    _ <- (\f -> foldM f 0 [1..str2_len]) $ \insertion_cost (!j) -> do
        row_char <- read_str2 j

        -- Initialize the first element of the row (i = 0)
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        write_cost (0, j) insertion_cost'

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            cost <- standardCosts costs read_cost row_char col_char (i, j)
            write_cost (i, j) cost

        return insertion_cost'

    -- Return an actual answer
    read_cost (str1_len, str2_len)


-- | Restricted Damerau-Levenshtein edit distance between two strings under
-- the given 'EditCosts'.
--
-- Measures both strings with 'length' and delegates to
-- 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance costs str1 str2 = restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len = length str1
    str2_len = length str2

-- | Restricted Damerau-Levenshtein edit distance for callers that already
-- know the string lengths.
--
-- Argument order is: costs, length of the first string, length of the second
-- string, first string, second string.
--
-- NOTE(review): the lengths are used to size the working arrays and the array
-- accessors have @unsafe@ in their names, so the lengths are presumably
-- required to equal @length str1@ and @length str2@ exactly -- confirm against
-- "Text.EditDistance.ArrayUtilities".
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2 =
    runST (restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2)

-- | 'ST' worker for the restricted Damerau-Levenshtein distance.
--
-- Fills a full @(str1_len + 1) x (str2_len + 1)@ cost table indexed by
-- @(i, j)@, where @i@ counts characters of @str1@ (columns) and @j@ counts
-- characters of @str2@ (rows), and returns the entry at
-- @(str1_len, str2_len)@. In addition to the three moves considered by
-- 'standardCosts', cells with @i >= 2@ and @j >= 2@ also consider swapping two
-- adjacent characters, which looks two rows and two columns back. That is why
-- row 1 and column 1 are filled separately from the general case.
--
-- NOTE(review): the string arrays are indexed from 1 here (see @read_str2 1@);
-- this presumably matches the bounds chosen by 'stringToArray' -- confirm.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs str1_len str2_len str1 str2 = do
    -- Create string arrays
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs. Say we index it by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array
    read_cost <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- The origin is read by 'standardCosts' (as the diagonal neighbour of
    -- (1, 1)) and is the result for two empty strings. 'newArray_' leaves the
    -- contents unspecified, so seed it explicitly instead of relying on the
    -- allocator happening to zero the array.
    write_cost (0, 0) 0

    -- Fill out the first row (j = 0): entry (i, 0) is the accumulated cost of
    -- deleting the first i characters of str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        write_cost (i, 0) deletion_cost'
        return (i + 1, deletion_cost')

    -- Fill out the second row (j = 1). No transposition is possible yet, so
    -- only the standard moves are considered.
    when (str2_len > 0) $ do
        initial_row_char <- read_str2 1

        -- Initialize the first element of the second row (i = 0)
        write_cost (0, 1) (insertionCost costs initial_row_char)

        -- Initialize the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            cost <- standardCosts costs read_cost initial_row_char col_char (i, 1)
            write_cost (i, 1) cost

    -- Fill the remaining rows (j >= 2)
    loopM_ 2 str2_len (\(!j) -> do
        row_char <- read_str2 j
        prev_row_char <- read_str2 (j - 1)

        -- Initialize the first element of the row (i = 0). This is the cost of
        -- inserting the first j characters of str2, so it must be accumulated
        -- from the row above: multiplying this character's cost by j is only
        -- correct when every character has the same insertion cost.
        prev_insertion_cost <- read_cost (0, j - 1)
        write_cost (0, j) (prev_insertion_cost + insertionCost costs row_char)

        -- Initialize the second element of the row (i = 1); a transposition
        -- would need column i - 2, which does not exist here.
        when (str1_len > 0) $ do
            col_char <- read_str1 1

            cost <- standardCosts costs read_cost row_char col_char (1, j)
            write_cost (1, j) cost

        -- Fill the remaining elements of the row (i >= 2)
        loopM_ 2 str1_len (\(!i) -> do
            col_char <- read_str1 i
            prev_col_char <- read_str1 (i - 1)

            standard_cost <- standardCosts costs read_cost row_char col_char (i, j)
            -- If the last two characters of each prefix are each other's
            -- mirror image, swapping them is a candidate move priced from the
            -- cell two steps back on the diagonal.
            cost <- if prev_row_char == col_char && prev_col_char == row_char
                    then do transpose_cost <- fmap (+ (transpositionCost costs col_char row_char)) $ read_cost (i - 2, j - 2)
                            return (standard_cost `min` transpose_cost)
                    else return standard_cost
            write_cost (i, j) cost))

    -- Return an actual answer
    read_cost (str1_len, str2_len)


{-# INLINE standardCosts #-}
-- | Cheapest of the three basic edit moves for cell @(i, j)@ of the cost table.
--
-- @read_cost@ looks up an already-computed cell. The candidates are:
--
-- * deletion of @col_char@, priced from the cell at @(i - 1, j)@;
-- * insertion of @row_char@, priced from the cell at @(i, j - 1)@;
-- * substitution of @col_char@ by @row_char@, priced from the cell at
--   @(i - 1, j - 1)@ and free when the two characters are equal.
--
-- The caller must have filled those three neighbouring cells already.
standardCosts :: EditCosts -> ((Int, Int) -> ST s Int) -> Char -> Char -> (Int, Int) -> ST s Int
standardCosts !costs read_cost !row_char !col_char (!i, !j) = do
    deletion_cost  <- fmap (+ (deletionCost costs col_char))  $ read_cost (i - 1, j)
    insertion_cost <- fmap (+ (insertionCost costs row_char)) $ read_cost (i, j - 1)
    subst_cost     <- fmap (+ if row_char == col_char
                                then 0
                                else (substitutionCost costs col_char row_char))
                           (read_cost (i - 1, j - 1))
    return $ deletion_cost `min` insertion_cost `min` subst_cost