# Monads for functional programming in Scala, part 2

In this post we use Scala to define the monad laws, following the paper *Monads for functional programming*.

The monad laws are described in terms of `left unit`, `right unit` and `associativity`.

The Haskell / lambda calculus syntax used in the paper states the laws clearly, so we repeat each one here, followed by a pseudo-Scala definition.

- Left unit

```
unit a ⋆ λb.n = n[a/b]
```

```
flatMap(unit(a'))(b' => n) = n[a'/b']
```

We change the Scala definition to use `a'` and `b'`. This makes it easier to follow the derivation of the addition example introduced later in the text.

- Right unit

```
m ⋆ λa.unit a = m
```

```
flatMap(m)(a => unit(a)) = m
```

- Associativity

```
m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
```

```
flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
```
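The laws are easier to internalise with something to run. Below is a minimal sketch, assuming the `Monad` trait from part 1 declares just `unit` and `flatMap` (restated here so the block compiles on its own). Each law becomes a Boolean check, with a function `f` standing in for `λb.n`, so that `f(a)` plays the role of `n[a/b]`. The `Option` instance and the names `optionMonad`, `leftUnit`, `rightUnit`, `associativity`, `half` and `inc` are hypothetical, chosen only for illustration.

```scala
// Minimal Monad type class, restated so this block compiles on its own.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

// Hypothetical Option instance used only for illustration.
val optionMonad: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

// Left unit: flatMap(unit(a))(f) == f(a)
def leftUnit[F[_], A, B](M: Monad[F], a: A, f: A => F[B]): Boolean =
  M.flatMap(M.unit(a))(f) == f(a)

// Right unit: flatMap(m)(a => unit(a)) == m
def rightUnit[F[_], A](M: Monad[F], m: F[A]): Boolean =
  M.flatMap(m)(a => M.unit(a)) == m

// Associativity: flatMap(m)(a => flatMap(f(a))(g)) == flatMap(flatMap(m)(f))(g)
def associativity[F[_], A, B, C](M: Monad[F], m: F[A], f: A => F[B], g: B => F[C]): Boolean =
  M.flatMap(m)(a => M.flatMap(f(a))(g)) == M.flatMap(M.flatMap(m)(f))(g)

// Sample functions: half fails on odd numbers, inc always succeeds.
val half: Int => Option[Int] = x => if (x % 2 == 0) Some(x / 2) else None
val inc: Int => Option[Int] = x => Some(x + 1)

println(leftUnit(optionMonad, 4, half))                   // true
println(rightUnit(optionMonad, Option.empty[Int]))        // true
println(associativity(optionMonad, Option(3), half, inc)) // true
```

Checking a handful of values like this can catch an unlawful instance, but it does not prove the laws.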

## Associativity of addition

Using these laws, we prove that addition in our evaluator is associative.

First, we define a small language to describe additions, together with a slightly modified version of the evaluator introduced in part 1.

```
// Assumes the Monad[F[_]] trait from part 1, which declares unit and flatMap
sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

// Identity monad: unit returns its argument and flatMap applies f directly
type Id[A] = A
implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) =>
    M.flatMap(eval[F](t))(a =>
      M.flatMap(eval[F](u))(b => M.unit(a + b)))
}
```
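With `Id`, `unit` returns its argument unchanged and `flatMap(fa)(f)` is plain function application `f(fa)`, so `eval[Id]` should agree with an ordinary recursive evaluator. The sketch below checks this on one term. It restates the `Monad` trait (assumed to declare only `unit` and `flatMap`, as in part 1) and the evaluator so that it compiles on its own; `evalDirect` and `sample` are hypothetical names introduced for the comparison.

```scala
// Monad with unit and flatMap, restated so this block compiles on its own.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

type Id[A] = A
implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  case Con(a)    => M.unit(a)
  case Add(t, u) => M.flatMap(eval[F](t))(a => M.flatMap(eval[F](u))(b => M.unit(a + b)))
}

// Hypothetical plain evaluator with no monad involved.
def evalDirect(term: Term): Int = term match {
  case Con(a)    => a
  case Add(t, u) => evalDirect(t) + evalDirect(u)
}

val sample = Add(Con(10), Add(Con(20), Con(12)))
val viaId: Int = eval[Id](sample)
val direct: Int = evalDirect(sample)

println(viaId)  // 42
println(direct) // 42
```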

Next we show that `Add(t, Add(u, v))` and `Add(Add(t, u), v)` compute the same result.

We use the left unit and associativity laws to simplify the first expression.

We also remove the `M` references to make the syntax clearer.

```
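eval(Add(t, Add(u, v)))
// definition of eval
flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))
// definition of eval
flatMap(eval(t))(a =>
  flatMap(flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))
// associativity
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(flatMap(eval(v))(c => unit(b + c)))(x => unit(a + x))))
// associativity
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => flatMap(unit(b + c))(x => unit(a + x)))))
// left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(a + (b + c)))))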
```

The substitutions can be tricky to follow on a first read, so the left unit substitution is explained in more detail:

```
n = unit(a + x)
a' = (b + c)
b' = x
n[a'/b'] = unit(a + (b + c))
```
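The rewrites above can also be spot-checked on concrete values. This sketch assumes that `eval(t)`, `eval(u)` and `eval(v)` have already produced the hypothetical values `mt`, `mu` and `mv`, and uses `Option` as the monad (any lawful monad would do). The helpers `unit`, `flatMap` and `steps` are illustrative, not part of the original code. Every step of the derivation for `Add(t, Add(u, v))` should yield the same value.

```scala
// Option-specialised unit and flatMap; Option is just one lawful monad to test with.
def unit[A](a: A): Option[A] = Some(a)
def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)

// Each derivation step for eval(Add(t, Add(u, v))), with mt, mu and mv
// standing in for eval(t), eval(u) and eval(v).
def steps(mt: Option[Int], mu: Option[Int], mv: Option[Int]): List[Option[Int]] = {
  // eval(Add(u, v)) unfolded by the definition of eval
  val evalUV = flatMap(mu)(b => flatMap(mv)(c => unit(b + c)))
  List(
    // definition of eval
    flatMap(mt)(a => flatMap(evalUV)(x => unit(a + x))),
    // associativity
    flatMap(mt)(a => flatMap(mu)(b => flatMap(flatMap(mv)(c => unit(b + c)))(x => unit(a + x)))),
    // associativity
    flatMap(mt)(a => flatMap(mu)(b => flatMap(mv)(c => flatMap(unit(b + c))(x => unit(a + x))))),
    // left unit
    flatMap(mt)(a => flatMap(mu)(b => flatMap(mv)(c => unit(a + (b + c)))))
  )
}

println(steps(Some(1), Some(2), Some(3))) // List(Some(6), Some(6), Some(6), Some(6))
println(steps(Some(1), None, Some(3)))    // List(None, None, None, None)
```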

We apply the same approach to the second expression:

```
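eval(Add(Add(t, u), v))
// definition of eval
flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))
// definition of eval
flatMap(flatMap(eval(t))(a =>
  flatMap(eval(u))(b => unit(a + b))))(x =>
  flatMap(eval(v))(c => unit(x + c)))
// associativity
flatMap(eval(t))(a =>
  flatMap(flatMap(eval(u))(b => unit(a + b)))(x =>
    flatMap(eval(v))(c => unit(x + c))))
// associativity
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(unit(a + b))(x => flatMap(eval(v))(c => unit(x + c)))))
// left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit((a + b) + c))))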
```

Again, we show the left unit substitution:

```
n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))
```

Finally, we have

```
// expr 1
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))
// expr 2
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```

and the two expressions are equal because integer addition is associative: `a + (b + c) == (a + b) + c`.
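The last step relies on `a + (b + c)` being equal to `(a + b) + c` for the `Int` values the evaluator works with. Scala's `Int` addition wraps around on overflow (it is arithmetic modulo 2^32), and modular addition is still associative, so the step holds even at the edges of the range. The sketch below spot-checks a few hypothetical samples, including overflowing ones, and contrasts this with `Double`, where rounding after each addition breaks associativity. The names `assocHolds`, `samples`, `leftSum` and `rightSum` are illustrative only.

```scala
// Associativity check for Int addition; Int arithmetic wraps on overflow.
def assocHolds(a: Int, b: Int, c: Int): Boolean = a + (b + c) == (a + b) + c

// Hypothetical samples, several of which overflow in an intermediate sum.
val samples: List[(Int, Int, Int)] = List(
  (1, 2, 3),
  (Int.MaxValue, 1, -1),
  (Int.MaxValue, Int.MaxValue, Int.MinValue),
  (-7, Int.MinValue, 42)
)

val intAssociative: Boolean = samples.forall { case (a, b, c) => assocHolds(a, b, c) }
println(intAssociative) // true

// Double addition rounds after each step, so grouping can change the result.
val leftSum: Double = (0.1 + 0.2) + 0.3
val rightSum: Double = 0.1 + (0.2 + 0.3)
println(leftSum == rightSum) // false
```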

## Toy example

We can exercise the code with a toy example:

```
val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))
// F does not appear in the Term argument, so Id is passed explicitly
val res1 = eval[Id](term1)
println(res1) // 6
val res2 = eval[Id](term2)
println(res2) // 6
```
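Because `eval` is polymorphic in `F`, the same terms can be evaluated in other monads too. As a sketch, with hypothetical `Option` and `List` instances (restating the `Monad` trait, assumed to declare `unit` and `flatMap`, and the evaluator so the block is self-contained), both terms give the same result in each monad. Since the term language has no failure or choice, every result is simply `unit(6)`.

```scala
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  case Con(a)    => M.unit(a)
  case Add(t, u) => M.flatMap(eval[F](t))(a => M.flatMap(eval[F](u))(b => M.unit(a + b)))
}

// Hypothetical instances for two standard-library types.
implicit val optionMonad: Monad[Option] = new Monad[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
}

implicit val listMonad: Monad[List] = new Monad[List] {
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

val term1 = Add(Con(1), Add(Con(2), Con(3)))
val term2 = Add(Add(Con(1), Con(2)), Con(3))

println(eval[Option](term1)) // Some(6)
println(eval[Option](term2)) // Some(6)
println(eval[List](term1))   // List(6)
println(eval[List](term2))   // List(6)
```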

## Monad definition revisited

Next in the paper, the operations `map` and `join` are introduced. Both can be implemented in terms of `flatMap` and `unit`. We update our Monad definition accordingly:

```
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  // map and join are derived from unit and flatMap
  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))
  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}
```
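To see the derived operations in action, here is a sketch with a hypothetical `List` instance that implements only `unit` and `flatMap`; `map` and `join` come from the trait's default definitions. The results match what the standard library's `List.map` and `List.flatten` would give.

```scala
// Monad with map and join derived from unit and flatMap, as in the trait above.
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))
  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}

// Hypothetical List instance: only unit and flatMap are written by hand.
val listMonad: Monad[List] = new Monad[List] {
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
}

val doubled: List[Int] = listMonad.map(List(1, 2, 3))(_ * 2)
val flattened: List[Int] = listMonad.join(List(List(1, 2), Nil, List(3)))

println(doubled)   // List(2, 4, 6)
println(flattened) // List(1, 2, 3)
```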

An alternative to the previous definition is to use `unit`, `join`, and `map` as the basic operators.

```
// `unit`, `map`, and `join` as primitives.
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]
  // flatMap is derived: map f over fa, then flatten the nested F
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
    join(map(fa)(f))
}
```
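As a sketch of the alternative formulation, the hypothetical `optionMonad2` instance below implements only `unit`, `map` and `join` by pattern matching, and `flatMap` is derived as `join` after `map`. The helper `half` is also illustrative. The derived `flatMap` should agree with `Option`'s built-in `flatMap`.

```scala
trait Monad2[F[_]] {
  def unit[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def join[A](ffa: F[F[A]]): F[A]
  // Derived: map f over fa, then flatten the nested F.
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] =
    join(map(fa)(f))
}

// Hypothetical Option instance written against the alternative primitives.
val optionMonad2: Monad2[Option] = new Monad2[Option] {
  def unit[A](a: A): Option[A] = Some(a)
  def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa match {
    case Some(a) => Some(f(a))
    case None    => None
  }
  def join[A](ffa: Option[Option[A]]): Option[A] = ffa match {
    case Some(fa) => fa
    case None     => None
  }
}

val half: Int => Option[Int] = x => if (x % 2 == 0) Some(x / 2) else None

val r1 = optionMonad2.flatMap(Some(8))(half)           // Some(4)
val r2 = optionMonad2.flatMap(Some(3))(half)           // None
val r3 = optionMonad2.flatMap(Option.empty[Int])(half) // None
```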