# Monads for functional programming in Scala, part 2

In this post we use Scala to express the monad laws described in Philip Wadler's paper *Monads for functional programming*.

The monad laws are described in terms of `left unit`, `right unit` and `associativity`.

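The paper states the laws concisely in a Haskell-like lambda-calculus notation, so we repeat them here, each followed by a pseudo-Scala version. In that notation `⋆` is the bind operator, which corresponds to our `flatMap`, and `n[a/b]` denotes `n` with `a` substituted for the free occurrences of `b`.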

1. Left unit

``````
unit a ⋆ λb.n = n[a/b]
``````

``````
flatMap(unit(a'))(b' => n) = n[a'/b']
``````

In the Scala version we rename the variables to `a'` and `b'`, which makes the derivation for the addition example later in this post easier to follow.

2. Right unit

``````
m ⋆ λa.unit a = m
``````

``````
flatMap(m)(a => unit(a)) = m
``````
3. Associativity

``````
m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
``````

``````
flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
``````
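To make the laws concrete, here is a minimal sketch that checks them for Scala's built-in `Option`. The `Monad` trait mirrors the one used in this series; the instance name `optionMonad`, the helper `half` and the functions `leftUnit`, `rightUnit` and `associativity` are hypothetical names chosen for this example. The lambda terms `λb.n` and `λb.o` become ordinary Scala functions `f` and `g`. Checking a few sample values only illustrates the laws; it does not prove them.

```scala
// Minimal Monad type class, as used in this series.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical instance for Scala's built-in Option (assumption, not from the paper).
val optionMonad: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa match {
    case Some(a) => f(a)
    case None    => None
  }
}

// Left unit: flatMap(unit(a))(f) == f(a)
def leftUnit[F[_], A, B](M: Monad[F])(a: A, f: A => F[B]): Boolean =
  M.flatMap(M.unit(a))(f) == f(a)

// Right unit: flatMap(m)(unit) == m
def rightUnit[F[_], A](M: Monad[F])(m: F[A]): Boolean =
  M.flatMap(m)(a => M.unit(a)) == m

// Associativity: binding g inside f gives the same result
// as first binding f and then binding g on the result.
def associativity[F[_], A, B, C](M: Monad[F])(m: F[A], f: A => F[B], g: B => F[C]): Boolean =
  M.flatMap(m)(a => M.flatMap(f(a))(g)) == M.flatMap(M.flatMap(m)(f))(g)

// Hypothetical sample function that can fail: halves even numbers only.
val half: Int => Option[Int] = n => if (n % 2 == 0) Some(n / 2) else None

println(leftUnit(optionMonad)(4, half))                    // true
println(rightUnit(optionMonad)(Option(7)))                 // true
println(associativity(optionMonad)(Option(8), half, half)) // true
```

The failing case matters too: with `Option(6)`, the second `half` returns `None`, and both sides of the associativity law are `None`.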

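As an application, we use these laws to show that the evaluator gives the same result however a sum is parenthesized. The argument relies on the associativity of addition; it does not prove it.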

First we define a small language of additions, together with a slightly modified version of the evaluator from part 1 of this series.

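``````scala
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

// The Monad type class from part 1: unit and bind (flatMap).
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// The identity monad: no extra effect, values are passed through unchanged.
type Id[A] = A

implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) => M.flatMap(eval[F](t))(a =>
    M.flatMap(eval[F](u))(b => M.unit(a + b)))
}
``````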
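Because `eval` relies only on `unit` and `flatMap`, the same definition should run unchanged in other monads. Here is a minimal, self-contained sketch with an `Option` instance; the name `optionMonad` and the values `resRight` and `resLeft` are our own assumptions, not taken from the paper or from part 1.

```scala
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical Option instance (assumption for this example).
implicit val optionMonad: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

// The same generic evaluator as above.
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  case Con(a)    => M.unit(a)
  case Add(t, u) => M.flatMap(eval[F](t))(a => M.flatMap(eval[F](u))(b => M.unit(a + b)))
}

// Both groupings of 1 + 2 + 3, evaluated in the Option monad.
val resRight = eval[Option](Add(Con(1), Add(Con(2), Con(3))))
val resLeft  = eval[Option](Add(Add(Con(1), Con(2)), Con(3)))
println(resRight) // Some(6)
println(resLeft)  // Some(6)
```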

Next we show that `Add(t, Add(u, v))` and `Add(Add(t, u), v)` compute the same result.

We use the left unit and associativity laws to simplify the first expression.

We also remove the `M` references in order to make the syntax clearer.

``````
eval(Add(t, Add(u, v)))

flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))

flatMap(eval(t))(a =>
  flatMap(flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))

// associativity (applied twice)
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c =>
      flatMap(unit(b + c))(x => unit(a + x)))))

// left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(a + (b + c)))))
``````

The substitutions can be tricky to follow on a first read, so here is the left unit step in more detail. It rewrites the subterm `flatMap(unit(b + c))(x => unit(a + x))`:

``````
n = unit(a + x)
a' = (b + c)
b' = x
n[a'/b'] = unit(a + (b + c))
``````
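A quick numeric check of this single step, as a sketch: it uses `Option` as the monad and the arbitrary sample values `a = 1`, `b = 2`, `c = 3`. The instance name `M` and the values `lhs` and `rhs` are illustrative assumptions.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical Option instance used only for this check.
val M: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

// Arbitrary sample values for the variables bound by the outer flatMaps.
val (a, b, c) = (1, 2, 3)

// Left-hand side of the left unit law, with a' = (b + c) and b' = x.
val lhs = M.flatMap(M.unit(b + c))(x => M.unit(a + x))
// Right-hand side n[a'/b']: (b + c) substituted for x in unit(a + x).
val rhs = M.unit(a + (b + c))

println(lhs) // Some(6)
println(rhs) // Some(6)
```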

We apply the same steps to the second expression:

``````
eval(Add(Add(t, u), v))

flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))

flatMap(flatMap(eval(t))(a =>
  flatMap(eval(u))(b => unit(a + b))))(x =>
  flatMap(eval(v))(c => unit(x + c)))

// associativity (applied twice)
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(unit(a + b))(x =>
      flatMap(eval(v))(c => unit(x + c)))))

// left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit((a + b) + c))))
``````

Again, here is the left unit substitution, applied to the subterm `flatMap(unit(a + b))(x => flatMap(eval(v))(c => unit(x + c)))`:

``````
n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))
``````
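This substitution can also be checked in code. In this hypothetical sketch, `ev` stands in for `eval(v)`, so we can try a computation that fails (`None`) as well as one that succeeds, and see that both sides of the left unit law agree. The names `M`, `n`, `lhs`, `rhs` and `ev` are assumptions for this example.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical Option instance used only for this check.
val M: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

// Arbitrary sample values for a and b.
val (a, b) = (1, 2)

// n = flatMap(eval(v))(c => unit(x + c)), with `ev` standing in for eval(v).
def n(x: Int, ev: Option[Int]): Option[Int] = M.flatMap(ev)(c => M.unit(x + c))

// Left-hand side of the law, with a' = (a + b) and b' = x.
def lhs(ev: Option[Int]): Option[Int] = M.flatMap(M.unit(a + b))(x => n(x, ev))
// n[a'/b']: (a + b) substituted for x.
def rhs(ev: Option[Int]): Option[Int] = M.flatMap(ev)(c => M.unit((a + b) + c))

println(lhs(Some(3))) // Some(6)
println(rhs(Some(3))) // Some(6)
println(lhs(None))    // None
println(rhs(None))    // None
```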

Putting the two results side by side:

``````
// expr 1
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))

// expr 2
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
``````

The two expressions differ only in `a + (b + c)` versus `(a + b) + c`, so they are equal by the associativity of addition.
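To see that this equality does not depend on a trivial monad, here is a sketch that evaluates both final expressions in the `List` monad, where a computation can return several results in order. The parameters `et`, `eu` and `ev` stand in for `eval(t)`, `eval(u)` and `eval(v)`; they, the functions `expr1` and `expr2`, and the instance `M` are assumptions for this example.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical List instance: each computation may produce several results.
val M: Monad[List] = new Monad[List] {
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

// expr 1, with et, eu, ev standing in for eval(t), eval(u), eval(v).
def expr1(et: List[Int], eu: List[Int], ev: List[Int]): List[Int] =
  M.flatMap(et)(a => M.flatMap(eu)(b => M.flatMap(ev)(c => M.unit(a + (b + c)))))

// expr 2: identical apart from the grouping of the addition.
def expr2(et: List[Int], eu: List[Int], ev: List[Int]): List[Int] =
  M.flatMap(et)(a => M.flatMap(eu)(b => M.flatMap(ev)(c => M.unit((a + b) + c))))

println(expr1(List(1, 2), List(10), List(100, 200))) // List(111, 211, 112, 212)
println(expr2(List(1, 2), List(10), List(100, 200))) // List(111, 211, 112, 212)
```

Both the values and their order agree, because the two expressions perform the same sequence of binds and differ only in the final arithmetic.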

## Toy example

A toy example exercises the evaluator on both groupings, using the identity monad:

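``````scala
val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

// Select the identity monad explicitly.
val res1 = eval[Id](term1)
println(res1) // 6

val res2 = eval[Id](term2)
println(res2) // 6
``````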

The paper next introduces the operations `map` and `join`. Both can be defined in terms of `flatMap` and `unit`, so we add them to our `Monad` definition:

``````scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))

  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}
``````
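A short usage sketch of the derived operations, with a hypothetical `List` instance named `listMonad`. Only `unit` and `flatMap` are implemented; `map` and `join` come from the trait's default definitions. The last pair of values illustrates, on one sample input, that `flatMap(fa)(f)` agrees with `join(map(fa)(f))`.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))

  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}

// Hypothetical List instance (assumption for this example).
val listMonad: Monad[List] = new Monad[List] {
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

val mapped = listMonad.map(List(1, 2, 3))(_ * 10)           // List(10, 20, 30)
val joined = listMonad.join(List(List(1, 2), Nil, List(3))) // List(1, 2, 3)

// flatMap compared with map followed by join, on one sample input.
val f: Int => List[Int] = x => List(x, x * 10)
val viaFlatMap = listMonad.flatMap(List(1, 2))(f)             // List(1, 10, 2, 20)
val viaJoinMap = listMonad.join(listMonad.map(List(1, 2))(f)) // List(1, 10, 2, 20)
```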

Alternatively, `unit`, `map` and `join` can be taken as the primitive operations, with `flatMap(fa)(f)` defined as `join(map(fa)(f))`.

``````scala
// `unit`, `map`, and `join` as primitives.