In this post we use Scala to define the monad laws, based on Philip Wadler's paper Monads for functional programming.

The monad laws are described in terms of `left unit`, `right unit`, and `associativity`.

The Haskell / lambda calculus notation used in the paper states the laws clearly, so we repeat each law here, followed by a pseudo-Scala definition.

Left unit:

unit a ⋆ λb.n = n[a/b]

flatMap(unit(a'))(b' => n) = n[a'/b']

In the Scala version we rename the variables to `a'` and `b'` so that they do not clash with the variables `a` and `b` used in the addition example introduced later. This makes that derivation easier to follow.
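The substitution `n[a'/b']` becomes concrete when the left unit law is instantiated for Scala's `Option`. This is a minimal sketch, written as a Scala script with top-level statements. It assumes `unit` wraps a value in `Some` and `flatMap` delegates to `Option.flatMap`. Both helpers are illustrative, as are `aPrime`, `bPrime`, and the particular term `n`. Scala identifiers cannot contain `'`, so the term `n` is modelled as a function of `b'`. Substituting `a'` for `b'` then amounts to applying that function to `a'`.

```scala
// Illustrative helpers for Option (not part of the post's Monad trait)
def unit[A](a: A): Option[A] = Some(a)
def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)

val aPrime = 3
// The term n mentions the bound variable b'; as a Scala function it is b' => n
val n: Int => Option[Int] = bPrime => if (bPrime > 0) Some(bPrime * 2) else None

val lhs = flatMap(unit(aPrime))(n) // flatMap(unit(a'))(b' => n)
val rhs = n(aPrime)                // n[a'/b']: substitute a' for b' in n
println(lhs)        // Some(6)
println(lhs == rhs) // true
```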

Right unit:

m ⋆ λa.unit a = m

flatMap(m)(a => unit(a)) = m
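As an illustration, the right unit law can be spot-checked on a few concrete values. This sketch assumes the `List` monad, with `unit` building a singleton list and `flatMap` delegating to `List.flatMap`. The sample lists are arbitrary. A passing check on a handful of values illustrates the law but does not prove it.

```scala
// Illustrative helpers for List (not part of the post's Monad trait)
def unit[A](a: A): List[A] = List(a)
def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)

// Arbitrary sample values for m, including the empty list
val samples: List[List[Int]] = List(Nil, List(1), List(1, 2, 3))

// flatMap(m)(a => unit(a)) = m for every sample m
val rightUnitHolds = samples.forall(m => flatMap(m)(a => unit(a)) == m)
println(rightUnitHolds) // true
```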

Associativity:

m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o

flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
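The same kind of spot check works for associativity. This sketch assumes the `List` monad, where the order of results makes the two bracketings easy to compare. The functions `n` and `o` are arbitrary choices for illustration. The law only makes sense when `o` does not mention `a`, because on the right-hand side `b => o` lies outside the scope of `a`. In Scala the compiler enforces this through ordinary scoping.

```scala
// Illustrative helpers for List (not part of the post's Monad trait)
def unit[A](a: A): List[A] = List(a)
def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)

val m: List[Int] = List(1, 2)
// n depends on a; o depends only on b, as the law requires
val n: Int => List[Int] = a => List(a, a * 10)
val o: Int => List[Int] = b => List(b + 1)

val lhs = flatMap(m)(a => flatMap(n(a))(b => o(b)))
val rhs = flatMap(flatMap(m)(a => n(a)))(b => o(b))
println(lhs)        // List(2, 11, 3, 21)
println(lhs == rhs) // true
```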

As an application of these laws, we show that the evaluator respects the associativity of addition, in any monad that satisfies the laws.

First we define a language to describe additions and a slightly modified version of the evaluator introduced in the first part of this post.

```scala
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

// The identity monad: unit returns its argument and flatMap is plain function application
type Id[A] = A

implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) => M.flatMap(eval(t)(M))(a => M.flatMap(eval(u)(M))(b => M.unit(a + b)))
}
```

Next we show that `Add(t, Add(u, v))` and `Add(Add(t, u), v)` compute the same result.

Starting from the definition of `eval`, we use the associativity and left unit laws to simplify the first expression. We also drop the `M` references to make the syntax clearer.

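```scala
eval(Add(t, Add(u, v)))
flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))
flatMap(eval(t))(a => flatMap(flatMap(eval(u))(b => flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))
// associativity
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(flatMap(eval(v))(c => unit(b + c)))(x => unit(a + x))))
// associativity
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => flatMap(unit(b + c))(x => unit(a + x)))))
// left unit
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))
```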

The substitutions can be tricky to follow on a first read, so here is the left unit substitution in more detail:

```
n = unit(a + x)
a' = (b + c)
b' = x
n[a'/b'] = unit(a + (b + c))
```
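As a sanity check on the first derivation, this sketch instantiates three of its forms with the `List` monad. The first is the form obtained by unfolding `eval(Add(u, v))`. The second is the form after re-associating the `flatMap`s. The third is the final form after the left unit step. `evalT`, `evalU`, and `evalV` are hypothetical stand-ins for `eval(t)`, `eval(u)`, and `eval(v)`, and any lists would do. Agreement on these values is evidence, not a proof. The proof is the equational argument itself.

```scala
// Illustrative helpers for List (not part of the post's Monad trait)
def unit[A](a: A): List[A] = List(a)
def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)

// Hypothetical stand-ins for eval(t), eval(u), eval(v)
val evalT = List(1, 2)
val evalU = List(10, 20)
val evalV = List(100)

val steps = List(
  // after unfolding eval(Add(u, v))
  flatMap(evalT)(a => flatMap(flatMap(evalU)(b => flatMap(evalV)(c => unit(b + c))))(x => unit(a + x))),
  // after applying associativity twice
  flatMap(evalT)(a => flatMap(evalU)(b => flatMap(evalV)(c => flatMap(unit(b + c))(x => unit(a + x))))),
  // after the left unit step
  flatMap(evalT)(a => flatMap(evalU)(b => flatMap(evalV)(c => unit(a + (b + c)))))
)
println(steps.head)                // List(111, 121, 112, 122)
println(steps.distinct.size == 1)  // true: every form yields the same list
```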

We apply the same approach to the second expression:

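```scala
eval(Add(Add(t, u), v))
flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))
flatMap(flatMap(eval(t))(a => flatMap(eval(u))(b => unit(a + b))))(x => flatMap(eval(v))(c => unit(x + c)))
// associativity
flatMap(eval(t))(a => flatMap(flatMap(eval(u))(b => unit(a + b)))(x => flatMap(eval(v))(c => unit(x + c))))
// associativity
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(unit(a + b))(x => flatMap(eval(v))(c => unit(x + c)))))
// left unit
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```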

Again, here is the left unit substitution in detail:

```
n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))
```
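The second derivation can also be checked on concrete values. This sketch instantiates three of its forms with the `List` monad: after unfolding `eval(Add(t, u))`, after re-associating the `flatMap`s, and after the left unit step. `evalT`, `evalU`, and `evalV` are hypothetical stand-ins for `eval(t)`, `eval(u)`, and `eval(v)`. Matching results illustrate the steps but do not replace the equational proof.

```scala
// Illustrative helpers for List (not part of the post's Monad trait)
def unit[A](a: A): List[A] = List(a)
def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)

// Hypothetical stand-ins for eval(t), eval(u), eval(v)
val evalT = List(1, 2)
val evalU = List(10, 20)
val evalV = List(100)

val steps = List(
  // after unfolding eval(Add(t, u))
  flatMap(flatMap(evalT)(a => flatMap(evalU)(b => unit(a + b))))(x => flatMap(evalV)(c => unit(x + c))),
  // after applying associativity twice
  flatMap(evalT)(a => flatMap(evalU)(b => flatMap(unit(a + b))(x => flatMap(evalV)(c => unit(x + c))))),
  // after the left unit step
  flatMap(evalT)(a => flatMap(evalU)(b => flatMap(evalV)(c => unit((a + b) + c))))
)
println(steps.head)                // List(111, 121, 112, 122)
println(steps.distinct.size == 1)  // true: every form yields the same list
```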

Finally, we have:

```scala
// expr 1
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))
// expr 2
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```

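and the two expressions are equal because integer addition is associative, `a + (b + c) == (a + b) + c`. This holds for Scala's `Int` even when the arithmetic overflows.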

We can use a toy example to exercise the code:

```scala
val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

// Select the identity monad explicitly
val res1 = eval[Id](term1)
println(res1) // 6

val res2 = eval[Id](term2)
println(res2) // 6
```
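Because `eval` is generic in `F`, the same terms can be evaluated in a monad other than `Id`. This sketch repeats the `Term` definitions, a two-method `Monad` trait, and `eval` so that it runs on its own as a script. It then adds a hypothetical `optionMonad` instance, which is not something the post defines. Nothing here can fail, so both terms evaluate to `Some(6)`.

```scala
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical instance: Option as the monad
val optionMonad: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

// Same evaluator as in the post
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  case Con(a)    => M.unit(a)
  case Add(t, u) => M.flatMap(eval(t)(M))(a => M.flatMap(eval(u)(M))(b => M.unit(a + b)))
}

val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

// Pass the instance explicitly, so F is inferred as Option
val res1 = eval(term1)(optionMonad)
val res2 = eval(term2)(optionMonad)
println(res1) // Some(6)
println(res2) // Some(6)
```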

Next, the paper introduces the operations `map` and `join`. Both can be implemented in terms of `flatMap` and `unit`, so we update our Monad definition accordingly:

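```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  // map applies f to the value and rewraps the result with unit
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))

  // join removes one layer of nesting by binding the inner F[A] unchanged
  def join[A](ffa: F[F[A]]): F[A] = flatMap(ffa)(fa => fa)
}
```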
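With `map` and `join` defined in the trait, a concrete instance only needs to supply `unit` and `flatMap`. This minimal sketch uses a hypothetical `listMonad` instance. For lists, the derived `join` concatenates a list of lists.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))
  def join[A](ffa: F[F[A]]): F[A] = flatMap(ffa)(fa => fa)
}

// Hypothetical List instance: only unit and flatMap are written by hand
val listMonad: Monad[List] = new Monad[List] {
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

val mapped = listMonad.map(List(1, 2, 3))(_ * 2)
val joined = listMonad.join(List(List(1, 2), Nil, List(3)))
println(mapped) // List(2, 4, 6)
println(joined) // List(1, 2, 3)
```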

An alternative to the previous definition is to use `unit`, `join`, and `map` as the basic operators.

```scala
// `unit`, `map`, and `join` as primitives.
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]

  // map f over fa to get F[F[B]], then flatten one layer with join
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = join(map(fa)(a => f(a)))
}
```
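With `Monad2`, an instance supplies `map` and `join` and inherits `flatMap`. This sketch uses a hypothetical `listMonad2` instance whose `join` is implemented with `List.flatten`. On one example, it compares the derived `flatMap` with the standard library's `List.flatMap`.

```scala
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = join(map(fa)(a => f(a)))
}

// Hypothetical List instance built from unit, map, and join
val listMonad2: Monad2[List] = new Monad2[List] {
  def unit[A](a: A): List[A] = List(a)
  def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
  def join[A](ffa: List[List[A]]): List[A] = ffa.flatten
}

val f: Int => List[Int] = n => List(n, -n)
val derived = listMonad2.flatMap(List(1, 2))(f) // computed as join(map(...))
println(derived)                          // List(1, -1, 2, -2)
println(derived == List(1, 2).flatMap(f)) // true
```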

©2023 daniberg.com