# Monads for functional programming in Scala

Philip Wadler's paper *Monads for functional programming* develops the application of monads using, as a running example, an evaluator for a simple language that performs integer division. It gives three variations: Exceptions, State and Output.

Here we implement the paper's examples in Scala. The source code is on GitHub.

We first define the language used by the evaluator.

``````scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term
``````

We also include two examples. The first, called `answer`, computes the number 42. The second, `error`, represents a division by zero.

``````scala
object Evaluator {
  val answer = Div(Div(Con(1972), Con(2)), Con(23))
  val error = Div(Con(1), Con(0))
}
``````
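Before adding any monad, it helps to see the evaluator the paper starts from: a plain recursive function. The sketch below repeats the `Term` definitions so it compiles on its own; the object name `PlainEval` is ours, not from the source. Scala's `Int` division truncates, so `986 / 23` (about 42.87) yields 42, and evaluating `error` throws an `ArithmeticException`, which is exactly the failure the Exceptions variation addresses.

```scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

// Hypothetical baseline: the evaluator with no monad at all.
object PlainEval {
  val answer: Term = Div(Div(Con(1972), Con(2)), Con(23))
  val error: Term = Div(Con(1), Con(0))

  // Direct recursion: evaluate both operands, then divide.
  // Division by zero is not handled and throws at runtime.
  def eval(term: Term): Int = term match {
    case Con(c)    => c
    case Div(t, u) => eval(t) / eval(u)
  }
}
```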

Then, we define a `Monad` in terms of `unit` and `flatMap`, which correspond to the paper's `unit` and `⋆` (bind) operations.

``````scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}
``````
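The trait leaves two things implicit: `map` can be derived from `unit` and `flatMap`, and a lawful instance must satisfy the paper's three monad laws (left unit, right unit, associative). The sketch below redefines the trait with a derived `map`, adds a `List` instance and spot-checks the laws on sample inputs; `map`, `listMonad` and `lawsHold` are illustrative additions, not part of the source, and a passing check is evidence rather than a proof.

```scala
// Redefines the article's Monad trait, adding a derived `map` (our addition).
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  // map follows from unit and flatMap: apply f, then wrap the result.
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))
}

object MonadLawsDemo {
  // Illustrative instance for List; not part of the original source.
  val listMonad: Monad[List] = new Monad[List] {
    def unit[A](a: A): List[A] = List(a)
    def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
  }

  // Spot-checks the three monad laws on the given inputs.
  def lawsHold[F[_], A, B, C](M: Monad[F], a: A, fa: F[A], f: A => F[B], g: B => F[C]): Boolean = {
    val leftUnit = M.flatMap(M.unit(a))(f) == f(a)
    val rightUnit = M.flatMap(fa)((x: A) => M.unit(x)) == fa
    val associative =
      M.flatMap(M.flatMap(fa)(f))(g) == M.flatMap(fa)((x: A) => M.flatMap(f(x))(g))
    leftUnit && rightUnit && associative
  }
}
```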

Finally, we implement a base case of the evaluator relying on the identity monad, `Id`.

``````scala
object VariationZeroBasicEval {

  // type M a = a
  type Id[A] = A

  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def unit[A](a: A): Id[A] = a
    def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
  }

  // Generic evaluator: works for any F that has a Monad instance.
  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) => M.flatMap(eval[F](t)(M)) { a =>
      M.flatMap(eval[F](u)(M)) { b =>
        M.unit(a / b)
      }
    }
  }

  def run(): Unit = {
    val result: Id[Int] = eval[Id](Evaluator.answer)
    println(result) // 42
  }

}
``````
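Because `eval` is polymorphic in `F`, the same function runs under any `Monad` instance. The sketch below repeats the definitions so it compiles on its own and adds a `Monad[Option]` instance plus the helpers `runId` and `runOption`; none of these are in the source. Plugging in `Option` alone does not rescue `Evaluator.error`: `eval` never produces `None`, so `1 / 0` still throws. Handling that case requires changing `eval` itself, which is what the variations below do.

```scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

object GenericEvalDemo {
  type Id[A] = A

  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def unit[A](a: A): Id[A] = a
    def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
  }

  // Hypothetical second instance, used only to show that eval is generic.
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // Same shape as the article's generic evaluator.
  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) =>
      M.flatMap(eval[F](t)(M)) { a =>
        M.flatMap(eval[F](u)(M)) { b =>
          M.unit(a / b)
        }
      }
  }

  val answer: Term = Div(Div(Con(1972), Con(2)), Con(23))

  def runId: Int = eval[Id](answer)
  def runOption: Option[Int] = eval[Option](answer)
}
```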

In the next three examples we'll modify our base implementation to handle Exceptions, State and Output.

Each modification follows the same pattern: define a data type, give its monad instance, and write an `eval` function that takes a `Term` and an implicit monad instance for that data type. Then run the evaluator on the examples defined in the `Evaluator` object.

## Exceptions

The first example handles division by zero. We piggyback on Scala's `Either` to define our new data type.

The only change to `eval` is how it handles a division whose denominator is zero: in that case it yields a `Left`, which plays the role of `Raise` in the paper.

``````scala
/** We raise an `Exception` if we find a division by zero. Otherwise we return a
  * value. */
object VariationOneExceptions {

  // We piggyback on Either.
  // Left and Right take the role of Raise and Return.
  // data M a = Raise Exception | Return a
  // type Exception = String
  type Exceptional[A] = Either[String, A]

  trait MonadException[F[_]] extends Monad[F] {
    def raise[A](ex: String): F[A]
  }

  implicit val exceptionalMonad: MonadException[Exceptional] =
    new MonadException[Exceptional] {
      def unit[A](a: A): Exceptional[A] = Right(a)
      def flatMap[A, B](fa: Exceptional[A])(f: A => Exceptional[B]): Exceptional[B] =
        fa match {
          case Left(e)  => Left(e) // propagate the error unchanged
          case Right(a) => f(a)    // continue with the value
        }
      def raise[A](ex: String): Exceptional[A] = Left(ex)
    }

  def eval(term: Term)(implicit M: MonadException[Exceptional]): Exceptional[Int] =
    term match {
      case Con(c) => M.unit(c)
      case Div(t, u) => M.flatMap(eval(t)(M)) { a =>
        M.flatMap(eval(u)(M)) { b =>
          // We replace `M.unit(a / b)` with:
          if (b == 0)
            M.raise("divide by zero")
          else
            M.unit(a / b)
        }
      }
    }

  def run(): Unit = {
    val resultSuccess: Exceptional[Int] = eval(Evaluator.answer)
    println(resultSuccess) // Right(42)
    val resultFailure: Exceptional[Int] = eval(Evaluator.error)
    println(resultFailure) // Left(divide by zero)
  }

}
``````
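The paper defines its own data type instead of reusing a library one. For readers who prefer to stay closer to the original, here is a sketch with a hand-rolled ADT; the names `M`, `Raise` and `Return` follow the paper's Haskell, while the object name `CustomExceptionEval` is ours. It behaves like the `Either`-based version above.

```scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

// Hypothetical ADT mirroring the paper's `data M a = Raise Exception | Return a`.
sealed trait M[+A]
final case class Raise(ex: String) extends M[Nothing]
final case class Return[A](a: A) extends M[A]

object CustomExceptionEval {
  def unit[A](a: A): M[A] = Return(a)

  def flatMap[A, B](m: M[A])(k: A => M[B]): M[B] = m match {
    case Raise(e)  => Raise(e) // propagate the exception unchanged
    case Return(a) => k(a)     // continue with the value
  }

  def raise[A](ex: String): M[A] = Raise(ex)

  def eval(term: Term): M[Int] = term match {
    case Con(c) => unit(c)
    case Div(t, u) =>
      flatMap(eval(t)) { a =>
        flatMap(eval(u)) { b =>
          if (b == 0) raise("divide by zero") else unit(a / b)
        }
      }
  }
}
```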

## State

In the second example we use the `State` monad to count how many divisions are executed during evaluation. We add the operator `tick` to our monad, which increments the count of division operations.

Again the changes to `eval` are minimal: all we need is one more `flatMap`, over `tick`, before yielding the result of the division.

``````scala
/** A computation accepts an initial state and returns a value paired with the
  * final state. In our example the state is a count of the number of divisions
  * executed. */
object VariationTwoState {

  // type M a = State -> (a, State)
  // type State = Int
  case class State[S, A](runS: S => (A, S))
  type StateInt[A] = State[Int, A]

  trait MonadState[F[_]] extends Monad[F] {
    def tick(): F[Unit]
  }

  implicit val stateMonad: MonadState[StateInt] = new MonadState[StateInt] {
    def unit[A](a: A): StateInt[A] = State(s => (a, s))
    def flatMap[A, B](fa: StateInt[A])(f: A => StateInt[B]): StateInt[B] = State { s =>
      val (a, s2) = fa.runS(s)
      f(a).runS(s2)
    }
    def tick(): StateInt[Unit] = State(s => ((), s + 1))
  }

  def eval(term: Term)(implicit M: MonadState[StateInt]): StateInt[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) => M.flatMap(eval(t)(M)) { a =>
      M.flatMap(eval(u)(M)) { b =>
        M.flatMap(M.tick())(_ => M.unit(a / b))
      }
    }
  }

  def run(): Unit = {
    val result: StateInt[Int] = eval(Evaluator.answer)
    println(result.runS(0)) // (42,2)
  }

}
``````
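To see what `flatMap` hides, compare with the state-passing evaluator the paper writes out directly for this variation: every recursive call must receive the current count and hand the updated count on, and getting the threading wrong is easy. A minimal sketch (the object name `ExplicitStateEval` is ours):

```scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

// Hypothetical monad-free version: the counter is threaded by hand.
object ExplicitStateEval {
  type State = Int

  // Takes the current count, returns the value and the updated count.
  def eval(term: Term, x: State): (Int, State) = term match {
    case Con(c) => (c, x)
    case Div(t, u) =>
      val (a, y) = eval(t, x) // left operand sees the incoming state
      val (b, z) = eval(u, y) // right operand sees the state left by the left one
      (a / b, z + 1)          // count one more division
  }
}
```

Starting from a count of 0, evaluating `answer` this way gives the same `(42,2)` as the monadic version.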

## Output

The third and final example generates, along with the result of the computation, a string describing each step of the evaluation.

The data type definition resembles the Haskell one: a pair whose first element is the textual trace of the evaluation steps and whose second element is the result of the computation.

The monadic implementation introduces the operator `out` which is similar to the `tick` operator from the previous example if you squint hard enough.

The modifications to `eval` are again trivial. For `Con` we `flatMap` over the result of `out` and then follow with `unit`. For `Div` we do the same, but the output we emit is the concatenation of the outputs of both subterms followed by the line for the division itself.
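The `Div` case in the code that follows takes apart the pairs returned by the recursive calls directly, which works because `OutputM` is a plain tuple. The paper instead sequences every step through the monad (`eval t ⋆ λa. eval u ⋆ λb. out (...) ⋆ λ(). unit (a ÷ b)`), so `eval` never depends on how the output is represented. A self-contained sketch of that variant, reusing the article's `showterm` and `line` (the object name `MonadicOutputEval` is ours):

```scala
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

object MonadicOutputEval {
  type Output = String
  type OutputM[A] = (Output, A)

  def unit[A](a: A): OutputM[A] = ("", a)

  // Output of the first computation, then output of the continuation.
  def flatMap[A, B](fa: OutputM[A])(f: A => OutputM[B]): OutputM[B] = {
    val (x, a) = fa
    val (y, b) = f(a)
    (x + y, b)
  }

  def out(o: Output): OutputM[Unit] = (o, ())

  def showterm(term: Term): String = term match {
    case Con(c)    => s"Con $c"
    case Div(t, u) => "Div (" + showterm(t) + " " + showterm(u) + ")"
  }

  def line(term: Term, a: Int): Output =
    "eval (" + showterm(term) + ") <= " + a.toString + "\n"

  // Every step goes through flatMap; no tuple is destructured in eval.
  def eval(term: Term): OutputM[Int] = term match {
    case Con(a) => flatMap(out(line(Con(a), a)))(_ => unit(a))
    case Div(t, u) =>
      flatMap(eval(t)) { a =>
        flatMap(eval(u)) { b =>
          flatMap(out(line(Div(t, u), a / b)))(_ => unit(a / b))
        }
      }
  }
}
```

For `answer` this produces five trace lines, ending with `eval (Div (Div (Con 1972 Con 2) Con 23)) <= 42`, the same text the destructuring version builds.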

``````scala
/** A computation consists of the output generated paired with the value
  * returned. */
object VariationThreeOutput {

  // Pretty close to the Haskell version:
  // type M a = (Output, a)
  // type Output = String
  type Output = String
  type OutputM[A] = (Output, A)

  trait MonadOutput[F[_]] extends Monad[F] {
    def out(out: Output): F[Unit]
  }

  implicit val outputMonad: MonadOutput[OutputM] = new MonadOutput[OutputM] {
    def unit[A](a: A): OutputM[A] = ("", a)
    def flatMap[A, B](fa: OutputM[A])(f: A => OutputM[B]): OutputM[B] = {
      val (x, a) = fa
      val (y, b) = f(a)
      (x + y, b)
    }
    def out(out: Output): OutputM[Unit] = (out, ())
  }

  def showterm(term: Term): String = term match {
    case Con(c) => s"Con $c"
    case Div(t, u) => "Div (" + showterm(t) + " " + showterm(u) + ")"
  }

  def line(term: Term, a: Int): Output =
    "eval (" + showterm(term) + ") <= " + a.toString + "\n"

  def eval(term: Term)(implicit M: MonadOutput[OutputM]): OutputM[Int] = term match {
    case Con(a) =>
      M.flatMap(M.out(line(Con(a), a)))(_ => M.unit(a))
    case Div(t, u) =>
      // Destructure the output of both subterms, then emit it with this step's line.
      val (x, a) = eval(t)(M)
      val (y, b) = eval(u)(M)
      val r = a / b
      M.flatMap(M.out(x + y + line(Div(t, u), r)))(_ => M.unit(r))
  }

  def run(): Unit = {