In this post we use Scala to define the monad laws, based on the paper "Monads for functional programming".

The monad laws are stated as three properties: left unit, right unit and associativity.

The Haskell / lambda calculus syntax used in the paper states the laws clearly, so we repeat each one here, followed by a pseudo-Scala definition.

  1. Left unit
unit a ⋆ λb.n = n[a/b]
flatMap(unit(a'))(b' => n) = n[a'/b']

In the Scala definition we rename a and b to a' and b'. This makes it easier to follow the derivation of the addition example introduced later in the text.

  2. Right unit
m ⋆ λa.unit a = m
flatMap(m)(a => unit(a)) = m

  3. Associativity
m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
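
As a quick sanity check, the three laws can be evaluated on concrete values. The sketch below assumes a Monad trait with unit and flatMap like the one from the first part of this post, and uses a hypothetical Option instance named M. The expressions n and o are written as functions so that the bound variable of each law is explicit. It only spot-checks the laws on a few inputs; it does not prove them.

```scala
object MonadLawsCheck {
  // Assumed to mirror the Monad trait from the first part of the post.
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  // Hypothetical Option instance, used only for illustration.
  val M: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // n and o play the role of the expressions in the laws.
  val n: Int => Option[Int] = b => if (b > 0) Some(b * 2) else None
  val o: Int => Option[String] = b => Some("v" + b)

  // unit a ⋆ λb.n = n[a/b]
  def leftUnit(a: Int): Boolean =
    M.flatMap(M.unit(a))(n) == n(a)

  // m ⋆ λa.unit a = m
  def rightUnit(m: Option[Int]): Boolean =
    M.flatMap(m)(a => M.unit(a)) == m

  // m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
  def associativity(m: Option[Int]): Boolean =
    M.flatMap(m)(a => M.flatMap(n(a))(o)) == M.flatMap(M.flatMap(m)(n))(o)
}
```

Each function returns true for inputs such as 3, -1, Some(2) and None, covering both the success and the failure branch of n.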

Associativity of addition

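We use these laws to prove that addition of terms is associative: the way Add nodes are nested does not change the result of evaluation.

First we define a small language for describing additions, together with a slightly modified version of the evaluator introduced in the first part of this post.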

sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

type Id[A] = A

implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) => M.flatMap(eval(t)(M))(a =>
    M.flatMap(eval(u)(M))(b => M.unit(a + b)))
}

Next we show that Add(t, Add(u, v)) and Add(Add(t, u), v) compute the same result.

We use the left unit and associativity laws to simplify the first expression.

We also remove the M references in order to make the syntax clearer.

eval(Add(t, Add(u, v)))

flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))

flatMap(eval(t))(a =>
  flatMap(flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))

// associativity (twice), then left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(a + (b + c)))))

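The substitutions can be tricky to follow on a first read. Associativity first moves x => unit(a + x) inwards, which leaves flatMap(unit(b + c))(x => unit(a + x)) in the innermost position. The left unit substitution applied to that expression is explained in more detail: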

n = unit(a + x)
a' = (b + c)
b' = x
n[a'/b'] = unit(a + (b + c))
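
The rewriting steps for the first expression can also be spelled out one at a time. The sketch below is an illustration under stated assumptions: Option stands in for an arbitrary monad, unit and flatMap are plain functions to match the M-free syntax used here, and the parameters et, eu and ev are hypothetical stand-ins for eval(t), eval(u) and eval(v).

```scala
object LeftUnitSteps {
  // Option is an assumption made for illustration; any lawful monad would do.
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)

  // The expression after both eval calls on Add have been expanded.
  def expanded(et: Option[Int], eu: Option[Int], ev: Option[Int]): Option[Int] =
    flatMap(et)(a =>
      flatMap(flatMap(eu)(b =>
        flatMap(ev)(c => unit(b + c))))(x => unit(a + x)))

  // Associativity, applied twice, moves (x => unit(a + x)) inwards.
  def reassociated(et: Option[Int], eu: Option[Int], ev: Option[Int]): Option[Int] =
    flatMap(et)(a =>
      flatMap(eu)(b =>
        flatMap(ev)(c =>
          flatMap(unit(b + c))(x => unit(a + x)))))

  // Left unit with n = unit(a + x), a' = b + c and b' = x.
  def simplified(et: Option[Int], eu: Option[Int], ev: Option[Int]): Option[Int] =
    flatMap(et)(a =>
      flatMap(eu)(b =>
        flatMap(ev)(c => unit(a + (b + c)))))
}
```

All three functions return the same value for the same arguments, for example Some(6) for Some(1), Some(2), Some(3), and None as soon as one argument is None.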

The same approach is used in the second expression:

eval(Add(Add(t, u), v))

flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))

flatMap(flatMap(eval(t))(a =>
  flatMap(eval(u))(b => unit(a + b))))(x =>
    flatMap(eval(v))(c => unit(x + c)))

// associativity (twice), then left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit((a + b) + c))))

Again, we show the left unit substitution, this time applied to flatMap(unit(a + b))(x => flatMap(eval(v))(c => unit(x + c))):

n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))

Finally we have

// expr 1
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))

// expr 2
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))

and the result follows by the associativity of integer addition.

Toy example

We can use a toy example to exercise the code:

val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

val res1 = eval(term1)
println(res1) // 6

val res2 = eval(term2)
println(res2) // 6

Monad definition revisited

Next, the paper introduces the operations map and join. Both can be implemented in terms of flatMap and unit, so we update our Monad definition accordingly:

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))

  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}
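
To see the derived operations at work, here is a minimal sketch that repeats the trait and adds a hypothetical List instance (not part of the original post). The instance only supplies unit and flatMap; map and join come from the default implementations.

```scala
object DerivedOps {
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

    // Derived from the two primitives above.
    def map[A, B](fa: F[A])(f: A => B): F[B] =
      flatMap(fa)(a => unit(f(a)))

    def join[A](ffa: F[F[A]]): F[A] =
      flatMap(ffa)(fa => fa)
  }

  // Hypothetical List instance: only unit and flatMap are provided.
  val listMonad: Monad[List] = new Monad[List] {
    def unit[A](a: A): List[A] = List(a)
    def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
  }

  val mapped: List[Int] = listMonad.map(List(1, 2, 3))(_ + 1)         // List(2, 3, 4)
  val joined: List[Int] = listMonad.join(List(List(1), List(2, 3)))   // List(1, 2, 3)
}
```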

An alternative to the previous definition is to use unit, join, and map as the basic operators.

// `unit`, `map`, and `join` as primitives.
trait Monad2[F[_]] {

  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]

  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
    join(map(fa)(a => f(a)))
}
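
The alternative formulation can be tried out the same way. The sketch below assumes a hypothetical Option instance of Monad2 and a helper named half, both introduced only for illustration. The instance supplies unit, map and join, and flatMap is obtained from the default implementation.

```scala
object Monad2Example {
  trait Monad2[F[_]] {
    def unit[A](a: A): F[A]
    def map[A, B](fa: F[A])(f: A => B): F[B]
    def join[A](ffa: F[F[A]]): F[A]

    // Derived: map produces an F[F[B]], join flattens it.
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
      join(map(fa)(a => f(a)))
  }

  // Hypothetical Option instance built from the three primitives.
  val optionMonad2: Monad2[Option] = new Monad2[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
    def join[A](ffa: Option[Option[A]]): Option[A] = ffa.flatten
  }

  // Illustrative helper: halves even numbers, fails on odd ones.
  def half(i: Int): Option[Int] = if (i % 2 == 0) Some(i / 2) else None

  val ok: Option[Int] = optionMonad2.flatMap(Option(4))(half)      // Some(2)
  val failed: Option[Int] = optionMonad2.flatMap(Option(3))(half)  // None
}
```

The derived flatMap agrees with the standard library's Option.flatMap on these inputs, which is what we would expect if both sets of primitives describe the same monad.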