In this post we use Scala to define the monad laws, following the paper "Monads for functional programming".
The monad laws are described in terms of left unit, right unit, and associativity.
The paper's Haskell / lambda calculus syntax states the laws clearly, so we repeat each law here, followed by a pseudo-Scala definition.
Left unit. Here `n[a/b]` denotes `n` with `a` substituted for the free occurrences of `b`:

```
unit a ⋆ λb.n = n[a/b]
flatMap(unit(a'))(b' => n) = n[a'/b']
```
In the Scala definition we rename the variables to `a'` and `b'`. This makes the derivation for the addition example, introduced later in the post, easier to follow.
Right unit:

```
m ⋆ λa.unit a = m
flatMap(m)(a => unit(a)) = m
```
Associativity. This law holds only when `a` does not appear free in `o`:

```
m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
```
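As a quick sanity check, the three laws can be tested on concrete values. The sketch below assumes a minimal `Monad` trait with only `unit` and `flatMap`, plus a hypothetical instance for Scala's `Option`. The functions `half` and `inc` are arbitrary examples that stand in for the continuations `a => n` and `b => o`. Checking sample values does not prove the laws. It only illustrates what each equation says.

```scala
// Minimal Monad trait (unit + flatMap), assumed for this sketch.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

object MonadLawsCheck {
  // Hypothetical instance for Option, used only to illustrate the laws.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  import optionMonad._

  // Arbitrary sample continuations: one that can fail, one that always succeeds.
  val half: Int => Option[Int] = x => if (x % 2 == 0) Some(x / 2) else None
  val inc: Int => Option[Int] = x => Some(x + 1)

  // Left unit: flatMap(unit(a'))(f) == f(a')
  def leftUnit(a: Int): Boolean = flatMap(unit(a))(half) == half(a)

  // Right unit: flatMap(m)(a => unit(a)) == m
  def rightUnit(m: Option[Int]): Boolean = flatMap(m)(a => unit(a)) == m

  // Associativity: g does not depend on a, matching the side condition.
  def assoc(m: Option[Int]): Boolean =
    flatMap(m)(a => flatMap(half(a))(inc)) == flatMap(flatMap(m)(half))(inc)

  val samples: List[Option[Int]] = List(None, Some(3), Some(4))

  def main(args: Array[String]): Unit = {
    println((0 to 5).forall(leftUnit)) // true
    println(samples.forall(rightUnit)) // true
    println(samples.forall(assoc))     // true
  }
}
```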
We use these laws to prove that the evaluator respects the associativity of addition.
First we define a small language of additions and a slightly modified version of the evaluator introduced in the first part of this post.
```scala
// Relies on the Monad trait (unit and flatMap) from the first part of this post.
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

// The identity monad: unit returns the value, flatMap applies the function.
type Id[A] = A

implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) =>
    M.flatMap(eval[F](t))(a => M.flatMap(eval[F](u))(b => M.unit(a + b)))
}
```
Next we show that `Add(t, Add(u, v))` and `Add(Add(t, u), v)` compute the same result.
We unfold `eval` and then use the associativity and left unit laws to simplify the first expression. We also drop the `M.` prefixes to keep the syntax readable.
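```scala
eval(Add(t, Add(u, v)))
// unfold eval for the outer Add
flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))
// unfold eval for the inner Add
flatMap(eval(t))(a => flatMap(flatMap(eval(u))(b => flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))
// associativity (applied twice)
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => flatMap(unit(b + c))(x => unit(a + x))))))
// left unit
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))
```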
The substitutions can be tricky to follow on a first read, so here is the left unit step in more detail:
```
n = unit(a + x)
a' = (b + c)
b' = x
n[a'/b'] = unit(a + (b + c))
```
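The substitution `n[a'/b']` is ordinary function application. If we write the function `b' => n` as `k`, left unit says `flatMap(unit(a'))(k) == k(a')`. The sketch below covers this step. It assumes an `Option`-based `unit`/`flatMap`, and the values for `a`, `b` and `c` are hypothetical.

```scala
object LeftUnitStep {
  // Option-based unit/flatMap, assumed here only for illustration.
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)

  // Hypothetical sample values for a, b and c.
  val a = 1
  val b = 2
  val c = 3

  // The continuation x => unit(a + x), i.e. the function b' => n with b' = x.
  val k: Int => Option[Int] = x => unit(a + x)

  // Before the left unit step: flatMap(unit(b + c))(x => unit(a + x))
  val before: Option[Int] = flatMap(unit(b + c))(k)
  // After substituting a' = (b + c) for b' = x: unit(a + (b + c))
  val after: Option[Int] = unit(a + (b + c))

  def main(args: Array[String]): Unit =
    println(before == after) // true (both are Some(6))
}
```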
We follow the same approach for the second expression:
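```scala
eval(Add(Add(t, u), v))
// unfold eval for the outer Add
flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))
// unfold eval for the inner Add
flatMap(flatMap(eval(t))(a => flatMap(eval(u))(b => unit(a + b))))(x => flatMap(eval(v))(c => unit(x + c)))
// associativity (applied twice)
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(unit(a + b))(x => flatMap(eval(v))(c => unit(x + c)))))
// left unit
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```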
Again, here is the left unit substitution:
```
n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))
```
Finally, we have:

```scala
// expr 1
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))
// expr 2
flatMap(eval(t))(a => flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```

The two expressions differ only in `a + (b + c)` versus `(a + b) + c`, so they are equal because addition is associative.
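The final step relies only on `a + (b + c) == (a + b) + c`. The two expressions therefore agree for any monad and for any values of `eval(t)`, `eval(u)` and `eval(v)`. The sketch below uses the `List` monad, which is our choice for illustration. The lists `m1`, `m2` and `m3` are hypothetical stand-ins for the three evaluations, and both forms produce the same list.

```scala
object NormalForms {
  // List as the monad: unit wraps a value, flatMap is List's own flatMap.
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)

  // Hypothetical stand-ins for eval(t), eval(u) and eval(v).
  val m1 = List(1, 10)
  val m2 = List(2, 20)
  val m3 = List(3, 30)

  // expr 1: ... unit(a + (b + c))
  val expr1: List[Int] =
    flatMap(m1)(a => flatMap(m2)(b => flatMap(m3)(c => unit(a + (b + c)))))

  // expr 2: ... unit((a + b) + c)
  val expr2: List[Int] =
    flatMap(m1)(a => flatMap(m2)(b => flatMap(m3)(c => unit((a + b) + c))))

  def main(args: Array[String]): Unit = {
    println(expr1)          // List(6, 33, 24, 51, 15, 42, 33, 60)
    println(expr1 == expr2) // true
  }
}
```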
We can use a toy example to exercise the code:

```scala
val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

// Pass the identity monad explicitly so the choice of F is visible.
val res1 = eval[Id](term1)
println(res1) // 6
val res2 = eval[Id](term2)
println(res2) // 6
```
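Because `eval` is polymorphic in `F`, the same terms can be evaluated in any monad. The self-contained sketch below instantiates `eval` at `Option`. The `optionMonad` instance is our own assumption and is not part of the post.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

object EvalInOption {
  // Hypothetical Monad instance for Option.
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // Same evaluator as above; recursive calls reuse the implicit M for F.
  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(a) => M.unit(a)
    case Add(t, u) =>
      M.flatMap(eval[F](t))(a => M.flatMap(eval[F](u))(b => M.unit(a + b)))
  }

  val term1: Term = Add(Con(1), Add(Con(2), Con(3)))
  val term2: Term = Add(Add(Con(1), Con(2)), Con(3))

  val res1: Option[Int] = eval[Option](term1)
  val res2: Option[Int] = eval[Option](term2)

  def main(args: Array[String]): Unit = {
    println(res1) // Some(6)
    println(res2) // Some(6)
  }
}
```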
Next, the paper introduces the operations `map` and `join`. Both can be implemented in terms of `flatMap` and `unit`, so we update our `Monad` definition accordingly:
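```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  // map: apply f inside F, then rewrap the result with unit
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))

  // join: remove one level of nesting by flatMapping with the identity function
  def join[A](ffa: F[F[A]]): F[A] = flatMap(ffa)(fa => fa)
}
```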
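With these default definitions, an instance only needs to supply `unit` and `flatMap`. The sketch below uses a hypothetical `List` instance to show `map` and `join` derived from those two operations.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))
  def join[A](ffa: F[F[A]]): F[A] = flatMap(ffa)(fa => fa)
}

object MapJoinDemo {
  // Hypothetical instance for List: only unit and flatMap are provided.
  val listMonad: Monad[List] = new Monad[List] {
    def unit[A](a: A): List[A] = List(a)
    def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
  }

  // map and join come from the default definitions in the trait.
  val mapped: List[Int] = listMonad.map(List(1, 2, 3))(_ * 10)
  val joined: List[Int] = listMonad.join(List(List(1, 2), Nil, List(3)))

  def main(args: Array[String]): Unit = {
    println(mapped) // List(10, 20, 30)
    println(joined) // List(1, 2, 3)
  }
}
```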
An alternative to the previous definition is to use `unit`, `join`, and `map` as the basic operators.
```scala
// `unit`, `map`, and `join` as primitives.
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]

  // flatMap: map f over fa, producing F[F[B]], then flatten it with join
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = join(map(fa)(a => f(a)))
}
```
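In this formulation an instance supplies `unit`, `map` and `join`, and `flatMap` is derived from them. The sketch below uses a hypothetical `List` instance and compares the derived `flatMap` with the standard library's `List.flatMap`.

```scala
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = join(map(fa)(a => f(a)))
}

object Monad2Demo {
  // Hypothetical instance for List: unit, map and join are the primitives.
  val listMonad2: Monad2[List] = new Monad2[List] {
    def unit[A](a: A): List[A] = List(a)
    def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
    def join[A](ffa: List[List[A]]): List[A] = ffa.flatten
  }

  val f: Int => List[Int] = x => List(x, x * 100)

  // flatMap derived as join(map(fa)(f)) ...
  val derived: List[Int] = listMonad2.flatMap(List(1, 2))(f)
  // ... compared with List's built-in flatMap.
  val builtIn: List[Int] = List(1, 2).flatMap(f)

  def main(args: Array[String]): Unit = {
    println(derived)            // List(1, 100, 2, 200)
    println(derived == builtIn) // true
  }
}
```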
©2023 daniberg.com