daniberg.com


Monads for functional programming in Scala

The paper Monads for functional programming, by Philip Wadler, illustrates the use of monads with an evaluator for a simple language that performs divisions on integers. Three variations are given: Exceptions, State and Output.

Here we implement the examples of the paper in Scala. The source code is on GitHub.

We first define the language used by the evaluator.

sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

We also include two examples. The first, answer, computes the number 42: with integer division, 1972 / 2 = 986 and 986 / 23 = 42. The second, error, represents a division by zero.

object Evaluator {
  val answer = Div(Div(Con(1972), Con(2)), Con(23))
  val error = Div(Con(1), Con(0))
}
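For comparison, here is a sketch of the plain, non-monadic evaluator the paper starts from, using the same Term definitions. The object name PlainEval is ours, not from the original source. It spells out the arithmetic behind answer and shows that error simply crashes with an ArithmeticException, which is the gap the Exceptions variation fills.

```scala
// Term definitions repeated from the post so this block compiles on its own.
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

// Hypothetical object name; a direct evaluator with no monad at all.
object PlainEval {
  val answer: Term = Div(Div(Con(1972), Con(2)), Con(23))
  val error: Term = Div(Con(1), Con(0))

  // Int division truncates: 1972 / 2 = 986, then 986 / 23 = 42.
  def eval(term: Term): Int = term match {
    case Con(c)    => c
    case Div(t, u) => eval(t) / eval(u)
  }

  def main(args: Array[String]): Unit = {
    println(eval(answer)) // 42
    // eval(error) throws java.lang.ArithmeticException: / by zero
  }
}
```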

Then we define a Monad in terms of unit and flatMap, which play the roles of the paper's unit and ★ (bind) operations.

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}
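The trait does not enforce it, but any instance is expected to satisfy the three monad laws: left identity, right identity and associativity. The sketch below spot-checks them at a few sample values for a hypothetical Option instance; Instances, optionMonad, MonadLaws and half are names introduced here for illustration. Equality is plain ==, so this only works for types with structural equality, and passing checks at sample points are evidence, not a proof.

```scala
// The Monad trait from the post, repeated so this block compiles on its own.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

object Instances {
  // Hypothetical instance, used only to illustrate the laws.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa match {
      case Some(a) => f(a)
      case None    => None
    }
  }
}

object MonadLaws {
  // Left identity: flatMap(unit(a))(f) == f(a)
  def leftIdentity[F[_], A, B](M: Monad[F])(a: A, f: A => F[B]): Boolean =
    M.flatMap(M.unit(a))(f) == f(a)

  // Right identity: flatMap(fa)(unit) == fa
  def rightIdentity[F[_], A](M: Monad[F])(fa: F[A]): Boolean =
    M.flatMap(fa)(a => M.unit(a)) == fa

  // Associativity: how the flatMap calls are nested does not matter
  def associativity[F[_], A, B, C](M: Monad[F])(fa: F[A], f: A => F[B], g: B => F[C]): Boolean =
    M.flatMap(M.flatMap(fa)(f))(g) == M.flatMap(fa)(a => M.flatMap(f(a))(g))

  def main(args: Array[String]): Unit = {
    val M = Instances.optionMonad
    // Halves even numbers and fails (None) on odd ones.
    val half: Int => Option[Int] = n => if (n % 2 == 0) Some(n / 2) else None
    println(leftIdentity(M)(4, half))                // true
    println(rightIdentity(M)(Option(3)))             // true
    println(associativity(M)(Option(8), half, half)) // true
  }
}
```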

Finally, we implement a base case of the evaluator relying on the Id monad, where a computation is just its plain value.

object VariationZeroBasicEval {

  // type M a = a
  type Id[A] = A

  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def unit[A](a: A): Id[A] = a
    def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
  }

  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) => M.flatMap(eval(t)(M)) { a =>
      M.flatMap(eval(u)(M)) { b =>
        M.unit(a / b)
      }
    }
  }

  def run(): Unit = {
    val result: Id[Int] = eval(Evaluator.answer)
    println(result) // 42
  }

}
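Because this eval only uses unit and flatMap, the same function works with any Monad instance. The sketch below instantiates it with Option; GenericEvalDemo, its Option instance and the runAnswer/runError helpers are hypothetical, not part of the original source. answer yields Some(42), but error still throws, because nothing in this eval ever produces None: handling failure needs an extra operation such as raise, which is what the Exceptions variation adds.

```scala
// Term and Monad as in the post, repeated so this block compiles on its own.
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

object GenericEvalDemo {
  // Hypothetical Option instance; not part of the original source.
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // Same structure as VariationZeroBasicEval.eval: only unit and flatMap are used.
  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) =>
      M.flatMap(eval[F](t)) { a =>
        M.flatMap(eval[F](u)) { b =>
          M.unit(a / b) // nothing here can produce None, so b == 0 still throws
        }
      }
  }

  val answer: Term = Div(Div(Con(1972), Con(2)), Con(23))
  val error: Term = Div(Con(1), Con(0))

  def runAnswer(): Option[Int] = eval[Option](answer)
  def runError(): Option[Int] = eval[Option](error)

  def main(args: Array[String]): Unit = {
    println(runAnswer()) // Some(42)
    // runError() throws java.lang.ArithmeticException: / by zero
  }
}
```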

In the next three examples we'll modify our base implementation to handle Exceptions, State and Output.

The modifications follow the same pattern: define a data type, then its monad implementation, then an eval function that takes a term of our division language and an implicit monad instance for that data type. Finally, run the evaluator on the examples defined in the Evaluator object.

Exceptions

The first example handles division by zero. We piggyback on Scala's Either data type to define our new data type.

The only change we have to make in the eval implementation is how we handle a division whose denominator is zero. In that case we return a Left, which plays the role of Raise in the paper.

/** We raise an `Exception` if we find a division by zero. Otherwise we return a
  * value. */
object VariationOneExceptions {

  // We piggyback on Either.
  // Left and Right take the role of Raise and Return.
  // data M a = Raise Exception | Return a
  // type Exception = String
  type Exceptional[A] = Either[String, A]

  trait MonadException[F[_]] extends Monad[F] {
    def raise[A](ex: String): F[A]
  }

  implicit val exceptionalMonad: MonadException[Exceptional] = new MonadException[Exceptional] {
    def unit[A](a: A): Exceptional[A] = Right(a)
    def flatMap[A, B]
      (fa: Exceptional[A])
      (f: A => Exceptional[B]): Exceptional[B] = fa match {
        case Left(e) => Left(e)
        case Right(a) => f(a)
    }
    def raise[A](ex: String): Exceptional[A] = Left(ex)
  }

  def eval
    (term: Term)
    (implicit M: MonadException[Exceptional]): Exceptional[Int] = term match
  {
    case Con(c) => M.unit(c)
    case Div(t, u) => M.flatMap(eval(t)(M)) { a =>
      M.flatMap(eval(u)(M)) { b =>
        // We replace: `M.unit(a / b)` with
        if (b == 0)
          M.raise("divide by zero")
        else
          M.unit(a / b)
      }
    }
  }

  def run(): Unit = {
    val resultSuccess: Exceptional[Int] = eval(Evaluator.answer)
    println(resultSuccess) // Right(42)
    val resultFailure: Exceptional[Int] = eval(Evaluator.error)
    println(resultFailure) // Left(divide by zero)
  }

}
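Since Scala 2.12, Either is right-biased and has its own map and flatMap, so the Exceptions variation can also be written without a custom monad instance, using a for-comprehension directly on Either. A minimal sketch (the object name EitherForComprehension is ours); it uses the standard library's Either.cond in place of the post's raise:

```scala
// Term definitions repeated from the post so this block compiles on its own.
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

object EitherForComprehension {
  def eval(term: Term): Either[String, Int] = term match {
    case Con(c) => Right(c)
    case Div(t, u) =>
      for {
        a <- eval(t) // a Left here short-circuits the rest
        b <- eval(u)
        // Right(a / b) if b != 0, else Left("divide by zero");
        // both arguments are by-name, so a / b is not evaluated when b == 0.
        r <- Either.cond(b != 0, a / b, "divide by zero")
      } yield r
  }

  def main(args: Array[String]): Unit = {
    println(eval(Div(Div(Con(1972), Con(2)), Con(23)))) // Right(42)
    println(eval(Div(Con(1), Con(0))))                  // Left(divide by zero)
  }
}
```

A Left produced by an inner division propagates through every enclosing flatMap unchanged, just like Raise in the paper.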

State

In the second example we use the State monad to count how many divisions are executed during evaluation. We add the operator tick to our monad implementation; it is responsible for incrementing the count of division operations.

Again, the changes to the eval implementation are minimal. This time all we need is one more flatMap over the tick operator before yielding the result of the division.

/** A computation accepts an initial state and returns a value. In our example
  * the state is a count of the number of divisions executed. */
object VariationTwoState {

  // type M a = State -> (a, State)
  // type State = Int
  case class State[S, A](runS: S => (A, S))
  type StateInt[A] = State[Int, A]

  trait MonadState[F[_]] extends Monad[F] {
    def tick(): F[Unit]
  }

  implicit val stateMonad: MonadState[StateInt] = new MonadState[StateInt] {
    def unit[A](a: A): StateInt[A] = State(s => (a, s))
    def flatMap[A, B](fa: StateInt[A])(f: A => StateInt[B]): StateInt[B] = State({ s =>
      val (a, s2) = fa.runS(s)
      f(a).runS(s2)
    })
    def tick(): StateInt[Unit] = State(s => ((), s + 1))
  }

  def eval(term: Term)(implicit M: MonadState[StateInt]): StateInt[Int] = term match {
    case Con(c) => M.unit(c)
    case Div(t, u) => M.flatMap(eval(t)(M)) { a =>
      M.flatMap(eval(u)(M)) { b =>
        M.flatMap(M.tick())(_ => M.unit(a / b))
      }
    }
  }

  def run(): Unit = {
    val result: StateInt[Int] = eval(Evaluator.answer)
    println(result.runS(0)) // (42,2)
  }

}
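Scala's for-comprehensions desugar to flatMap and map methods on the value itself, so giving State its own map and flatMap lets the State variation be written in that style. A sketch, assuming State is redefined with these two methods (the post's State has neither) and using a hypothetical object name, StateForComprehension:

```scala
// Term definitions repeated from the post so this block compiles on its own.
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

// Same representation as the post, plus methods used by for-comprehensions.
case class State[S, A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] = State { (s: S) =>
    val (a, s2) = runS(s)
    (f(a), s2)
  }
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { (s: S) =>
    val (a, s2) = runS(s)
    f(a).runS(s2) // thread the updated state into the next step
  }
}

object StateForComprehension {
  // Increments the division counter; carries no useful value.
  val tick: State[Int, Unit] = State((s: Int) => ((), s + 1))

  def eval(term: Term): State[Int, Int] = term match {
    case Con(c) => State((s: Int) => (c, s))
    case Div(t, u) =>
      for {
        a <- eval(t)
        b <- eval(u)
        _ <- tick
      } yield a / b
  }

  def main(args: Array[String]): Unit = {
    val answer = Div(Div(Con(1972), Con(2)), Con(23))
    println(eval(answer).runS(0)) // (42,2)
  }
}
```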

Output

The third and final example produces, along with the result of the computation, a string trace of each step of the evaluation of our program.

The data type definition resembles the Haskell syntax. We define our data type as a pair whose first element is the string representation of the evaluation steps and whose second element is the final result of the computation.

The monadic implementation introduces the operator out, which is similar to the tick operator from the previous example if you squint hard enough.

The modifications to the eval function are again trivial. In the Con case we flatMap over the result of the out operator and then call unit. In the Div case we do the same, but the output passed to out is the concatenation of the outputs of both subterms followed by the line for the division itself.

/** A computation consists of the output generated paired with the value
  * returned. */
object VariationThreeOutput {

  // Pretty close to the Haskell version:
  // type M a = (Output, a)
  // type Output = String
  type Output = String
  type OutputM[A] = (Output, A)

  trait MonadOutput[F[_]] extends Monad[F] {
    def out(out: Output): F[Unit]
  }

  implicit val outputMonad: MonadOutput[OutputM] = new MonadOutput[OutputM] {
    def unit[A](a: A): OutputM[A] = ("", a)
    def flatMap[A, B]
      (fa: OutputM[A])
      (f: A => OutputM[B]): OutputM[B] =
    {
      val (x, a) = fa
      val (y, b) = f(a)
      (x + y, b)
    }
    def out(out: Output): OutputM[Unit] = (out, ())
  }

  def showterm(term: Term): String = term match {
    case Con(c) => s"Con $c"
    case Div(t, u) => "Div (" + showterm(t) + " " + showterm(u) + ")"
  }

  def line(term: Term, a: Int): Output = {
    "eval (" + showterm(term) + ") <= " + a.toString + "\n"
  }

  def eval
    (term: Term)
    (implicit M: MonadOutput[OutputM]): OutputM[Int] =
  {
    term match {
      case Con(a) =>
        M.flatMap(M.out(line(Con(a), a)))(_ => M.unit(a))
      case Div(t, u) =>
        val (x, a) = eval(t)(M)
        val (y, b) = eval(u)(M)
        val r = a / b
        M.flatMap(M.out(x + y + line(Div(t, u), r)))(_ => M.unit(r))
    }
  }

  def run(): Unit = {
    val result: OutputM[Int] = eval(Evaluator.answer)
    println(result._1)
    // eval (Con 1972) <= 1972
    // eval (Con 2) <= 2
    // eval (Div (Con 1972 Con 2)) <= 986
    // eval (Con 23) <= 23
    // eval (Div (Div (Con 1972 Con 2) Con 23)) <= 42
  }

}
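The Div case above destructures the pairs by hand, which mirrors the paper's non-monadic Output variation. The paper's monadic version instead chains eval t, eval u and out with bind, letting flatMap do the concatenation. Here is a sketch of that style; to keep it short, unit, flatMap and out are plain functions in a single object (MonadicOutput, a name we introduce) rather than an implicit MonadOutput instance. It produces the same trace as run() above.

```scala
// Term definitions repeated from the post so this block compiles on its own.
sealed trait Term
case class Con(c: Int) extends Term
case class Div(l: Term, r: Term) extends Term

object MonadicOutput {
  type Output = String
  type OutputM[A] = (Output, A)

  def unit[A](a: A): OutputM[A] = ("", a)

  // Concatenates the output of fa with the output of f(a).
  def flatMap[A, B](fa: OutputM[A])(f: A => OutputM[B]): OutputM[B] = {
    val (x, a) = fa
    val (y, b) = f(a)
    (x + y, b)
  }

  def out(o: Output): OutputM[Unit] = (o, ())

  def showterm(term: Term): String = term match {
    case Con(c)    => s"Con $c"
    case Div(t, u) => "Div (" + showterm(t) + " " + showterm(u) + ")"
  }

  def line(term: Term, a: Int): Output =
    "eval (" + showterm(term) + ") <= " + a.toString + "\n"

  def eval(term: Term): OutputM[Int] = term match {
    case Con(a) => flatMap(out(line(Con(a), a)))(_ => unit(a))
    case Div(t, u) =>
      // No manual destructuring: flatMap accumulates the subterm outputs.
      flatMap(eval(t)) { a =>
        flatMap(eval(u)) { b =>
          flatMap(out(line(Div(t, u), a / b)))(_ => unit(a / b))
        }
      }
  }

  def main(args: Array[String]): Unit = {
    val (trace, value) = eval(Div(Div(Con(1972), Con(2)), Con(23)))
    print(trace)
    println(value) // 42
  }
}
```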

©2023 daniberg.com