module SemanticsFunctions where

import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Prelude hiding ((>>=), pure)
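-- We hide Prelude's (>>=) and pure because this file defines its own
-- versions below.  In the future, we will learn that they are methods of
-- standard type classes (Monad and Applicative).  For now, I don't want to
-- talk about those classes, so I pretend they are specific to my interpreter.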

-- Result type of interpreter: Possibility of errors.
-- More general/abstract thinking: Type of interpreter, model of possible effects.
--
-- Derives Eq (like Value and Expr) so that interpreter results can be
-- compared directly, e.g. mainInterp e == Success (VN 120).
data ExprInterp a = Error ErrorType | Success a
    deriving (Eq, Show)

-- The kinds of failure the interpreter reports: TypeError when a value is
-- used at the wrong kind (e.g. adding a boolean, applying a non-function),
-- VarNotFound when a variable is not bound in the environment.
data ErrorType = TypeError | VarNotFound
    deriving (Eq, Show)

-- Wrap a plain answer as a successful interpreter result.  For ExprInterp
-- this is nothing more than the Success constructor; other interpreters
-- may need to do more work here.
pure :: a -> ExprInterp a
pure = Success

-- Abort the computation with the given error.  For ExprInterp this is
-- nothing more than the Error constructor; other interpreters may need to
-- do more work here.
raise :: ErrorType -> ExprInterp a
raise = Error

-- Sequential composition of two interpreter stages: Check success of
-- 1st stage, and if success, pass answer to 2nd stage.  If the 1st stage
-- fails, its error is propagated unchanged and the 2nd stage never runs.
--
-- Declared infixl 1, the same fixity as the standard Prelude (>>=), so that
-- e.g. "m >>= f . g" parses as "m >>= (f . g)".  Without a declaration the
-- operator would default to infixl 9, and mixing it with (.) (infixr 9)
-- would be a parse error.
infixl 1 >>=
(>>=) :: ExprInterp a -> (a -> ExprInterp b) -> ExprInterp b
Error e >>= _ = Error e
Success a >>= k = k a


-- | Abstract syntax of the interpreted language.
data Expr
    = Num Integer                -- ^ integer literal
    | Bln Bool                   -- ^ boolean literal
    | Var String                 -- ^ variable reference
    | Prim2 Op2 Expr Expr        -- ^ @Prim2 op leftOperand rightOperand@
    | Cond Expr Expr Expr        -- ^ @Cond test thenBranch elseBranch@
    | Let [(String, Expr)] Expr  -- ^ @Let [(name, rhs), ...] body@
    | Lambda String Expr         -- ^ @Lambda argName body@
    | App Expr Expr              -- ^ @App function argument@
    | Rec String String Expr     -- ^ @Rec funcName argName funcBody@
    deriving (Eq, Show)

-- | Binary primitive operators usable in 'Prim2'.
data Op2
    = Eq
    | Plus
    | Mul
    deriving (Eq, Show)

-- | Everything an expression can evaluate to.
data Value
    = VN Integer                                  -- ^ an integer
    | VB Bool                                     -- ^ a boolean
    | VClosure (Map String Value) String Expr
      -- ^ a function: captured environment, parameter name, body
    | VRecClosure (Map String Value) String String Expr
      -- ^ a recursive function: captured environment, the function's own
      -- name, parameter name, body
    deriving (Eq, Show)

-- Interpret a top-level expression, starting with no variables in scope.
mainInterp :: Expr -> ExprInterp Value
mainInterp = (`interp` Map.empty)

-- Helper: accept only an integer value (the VN case) and hand back the
-- Integer inside it; any other kind of value is a TypeError.
intOrDie :: Value -> ExprInterp Integer
intOrDie v = case v of
    VN n -> pure n
    _    -> raise TypeError

-- Evaluate an expression in an environment mapping variable names to values.
--
-- Failures are reported through ExprInterp rather than by crashing:
--   * TypeError   when a value has the wrong kind for how it is used: a
--                 non-integer operand of a primitive, a non-boolean Cond
--                 test, or applying something that is not a closure;
--   * VarNotFound when a Var is not bound in the environment.
--
-- Functions are statically scoped: a closure remembers the environment it
-- was created in, and its body is evaluated there, not in the caller's.
interp :: Expr -> Map String Value -> ExprInterp Value

interp (Num i) _ = pure (VN i)

interp (Bln b) _ = pure (VB b)

-- Every binary primitive evaluates both operands left to right, requires
-- both to be integers, and lets applyOp2 build the result.  Consequently Eq
-- only compares integers; comparing booleans is a TypeError.
interp (Prim2 op e1 e2) env =
    interp e1 env
    >>= \a -> intOrDie a
    >>= \i -> interp e2 env
    >>= \b -> intOrDie b
    >>= \j -> pure (applyOp2 op i j)

interp (Cond test eThen eElse) env =
    interp test env
    >>= \a -> case a of
      VB True -> interp eThen env
      VB False -> interp eElse env
      _ -> raise TypeError

interp (Var v) env = case Map.lookup v env of
  Just a -> pure a
  Nothing -> raise VarNotFound

-- Bindings are sequential: each right-hand side is evaluated in an
-- environment that already contains the earlier bindings.
interp (Let eqns evalMe) env =
    extend eqns env
    >>= \env' -> interp evalMe env'
    -- Example:
    --    let x=2+3; y=x+4 in x+y
    -- -> x+y   (with x=5, y=9 in the larger environment env')
    -- "extend eqns env" builds env'
  where
    -- Evaluate the bindings left to right, inserting each result into the
    -- accumulated environment before evaluating the next right-hand side.
    extend [] acc = pure acc
    extend ((v, rhs) : rest) acc =
        interp rhs acc
        >>= \a -> extend rest (Map.insert v a acc)

-- A lambda captures the environment it is evaluated in.
interp (Lambda v body) env = pure (VClosure env v body)

interp (App f e) env =
    interp f env
    >>= \c -> case c of
      VClosure fEnv v body ->
          interp e env
          >>= \eVal ->
          let bEnv = Map.insert v eVal fEnv  -- fEnv, not env
          in interp body bEnv
          -- E.g.,
          --    (\y -> 10+y) 17
          -- -> 10 + y      (but with y=17 in environment)
          --
      VRecClosure fEnv fName v body ->
          -- Bind the function's own name to the closure itself so the body
          -- can call it recursively; the argument is inserted last, so it
          -- wins if it has the same name as the function.
          interp e env
          >>= \eVal ->
          let bEnv = Map.insert v eVal (Map.insert fName c fEnv)
          in interp body bEnv
      _ -> raise TypeError

-- Like Lambda, but the closure also records its own name for recursion.
interp (Rec f v fbody) env = pure (VRecClosure env f v fbody)

-- Compute the value of a binary primitive from its two integer operands.
applyOp2 :: Op2 -> Integer -> Integer -> Value
applyOp2 Plus i j = VN (i + j)
applyOp2 Mul i j = VN (i * j)
applyOp2 Eq i j = VB (i == j)


-- let { x=10; f = \y->x+y ; } in
-- let { x=5; } in
-- f 0
--
-- Demonstrates static scoping: f's body sees the x=10 that was in scope
-- when f was defined, not the inner x=5, so mainInterp exampleScoping
-- evaluates to Success (VN 10).
exampleScoping :: Expr
exampleScoping =
    Let [ ("x", Num 10)
        , ("f", Lambda "y" (Prim2 Plus (Var "x") (Var "y")))
        ]
        (Let [("x", Num 5)]
             (App (Var "f") (Num 0)))

-- let diagonal = \x -> x x in diagonal diagonal
-- i.e. (\x -> x x) (\x -> x x)
--
-- Interpreting this never terminates: applying diagonal to itself
-- evaluates to the very same application again.
nonStop :: Expr
nonStop = Let [("diagonal", Lambda "x" (App (Var "x") (Var "x")))]
          (App (Var "diagonal") (Var "diagonal"))

-- Factorial using the diagonal technique: mkFac receives itself as f, so
-- "f f" rebuilds the recursive function without any recursion construct.
-- let mkFac = \f -> \n -> if n=0 then 1 else n * (f f) (n-1)
-- in mkFac mkFac k
--
-- E.g. mainInterp (fac 5) evaluates to Success (VN 120).  For negative k
-- the base case n=0 is never reached, so interpretation does not terminate.
fac :: Integer -> Expr
fac k = Let [ ( "mkFac"
              , Lambda "f" (
                    Lambda "n" (
                          Cond
                            (Prim2 Eq (Var "n") (Num 0))
                            (Num 1)
                            (Prim2 Mul
                              (Var "n")
                              (App (App (Var "f") (Var "f"))
                                (Prim2 Plus (Var "n") (Num (-1))))))))
            ]
        (App (App (Var "mkFac") (Var "mkFac")) (Num k))

-- Factorial using the provided recursion construct (Rec).
-- (rec f n -> if n=0 then 1 else n * f (n-1)) k
--
-- E.g. mainInterp (fac2 5) evaluates to Success (VN 120).  For negative k
-- the base case n=0 is never reached, so interpretation does not terminate.
fac2 :: Integer -> Expr
fac2 k = App (Rec "f" "n"
               (Cond
                 (Prim2 Eq (Var "n") (Num 0))
                 (Num 1)
                 (Prim2 Mul
                   (Var "n")
                   (App (Var "f")
                     (Prim2 Plus (Var "n") (Num (-1)))))))
             (Num k)