 # daniberg.com

## Monads for functional programming in Scala, part 2

In this post we use Scala to state the monad laws from the paper *Monads for functional programming*.

The monad laws are described in terms of `left unit`, `right unit` and `associativity`.

The Haskell / lambda calculus syntax used in the paper defines the laws clearly, so we repeat each one here, followed by a pseudo-Scala definition.

#### 1. Left unit

```unit a ⋆ λb.n = n[a/b]
```
```flatMap(unit(a'))(b' => n) = n[a'/b']
```

We change the Scala definition to use `a'` and `b'`, which makes it easier to follow the derivation in the addition example introduced later in the text.

#### 2. Right unit

```m ⋆ λa.unit a = m
```
```flatMap(m)(a => unit(a)) = m
```

#### 3. Associativity

```m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
```
```flatMap(m)(a => flatMap(n)(b => o)) = flatMap(flatMap(m)(a => n))(b => o)
```
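The three laws can also be spot-checked on concrete values. The following is a minimal sketch, not a proof: it assumes a hypothetical `OptionMonad` instance and sample values `a`, `m`, `n`, and `o` that do not appear in the paper. Since the laws treat `n` and `o` as expressions with free variables, they are modelled here as functions.

```scala
object MonadLawsCheck {
  // Same shape as the Monad definition used in this series.
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  // Hypothetical instance, used only for this check.
  object OptionMonad extends Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }
  import OptionMonad._

  // Sample values (assumptions made for illustration).
  val a = 3
  val m: Option[Int] = Some(4)
  val n: Int => Option[Int] = x => Some(x + 1)
  val o: Int => Option[Int] = x => Some(x * 2)

  // 1. Left unit: unit a ⋆ λb.n = n[a/b]
  val leftUnit: Boolean = flatMap(unit(a))(n) == n(a)

  // 2. Right unit: m ⋆ λa.unit a = m
  val rightUnit: Boolean = flatMap(m)(x => unit(x)) == m

  // 3. Associativity: m ⋆ (λa.n ⋆ λb.o) = (m ⋆ λa.n) ⋆ λb.o
  val associativity: Boolean =
    flatMap(m)(x => flatMap(n(x))(o)) == flatMap(flatMap(m)(n))(o)
}
```

All three values should be `true` for these inputs. A check like this only exercises the chosen samples; it does not establish the laws in general.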

### Associativity of addition

We use these laws to prove that addition of terms is associative under the evaluator.

First we define a language to describe additions and a slightly modified version of the evaluator introduced in the first part of this post.

```sealed trait Term
case class Con(i: Int) extends Term
case class Add(t1: Term, t2: Term) extends Term

type Id[A] = A

implicit val idMonad: Monad[Id] = new Monad[Id] {
  def unit[A](a: A): Id[A] = a
  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
}

// eval :: Term -> M Int
def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
  // unit a
  case Con(a) => M.unit(a)
  // eval t ⋆ λa.eval u ⋆ λb.unit(a + b)
  case Add(t, u) => M.flatMap(eval(t)(M))(a =>
    M.flatMap(eval(u)(M))(b => M.unit(a + b)))
}
```

Next we show that `Add(t, Add(u, v))` and `Add(Add(t, u), v)` compute the same result.

We use the left unit and associativity laws to simplify the first expression.

We also remove the `M` references in order to make the syntax clearer.

```eval(Add(t, Add(u, v)))

flatMap(eval(t))(a => flatMap(eval(Add(u, v)))(x => unit(a + x)))

flatMap(eval(t))(a =>
  flatMap(flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(b + c))))(x => unit(a + x)))

// associativity, then left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit(a + (b + c)))))
```

The substitutions can be tricky to follow on a first read. The left unit substitution is explained in more detail:

```n = unit(a + x)
a' = (b + c)
b' = x
n [a'/b'] = unit(a + (b + c))
```
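The last rewrite in the derivation above can be broken into smaller steps: associativity first moves `x => unit(a + x)` inward, and only then does left unit apply. The sketch below replays those steps on concrete values so that each stage can be compared. The `OptionMonad` instance and the stand-ins `a`, `evalU`, and `evalV` are assumptions made for illustration, not part of the original derivation.

```scala
object LeftUnitSteps {
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  // Hypothetical Option instance, used only to make the steps executable.
  object OptionMonad extends Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }
  import OptionMonad._

  // Stand-ins for the outer variable a and for eval(u) and eval(v).
  val a = 1
  val evalU: Option[Int] = Some(2)
  val evalV: Option[Int] = Some(3)

  // Starting point: the inner expression before any law is applied.
  val start =
    flatMap(flatMap(evalU)(b => flatMap(evalV)(c => unit(b + c))))(x => unit(a + x))

  // Associativity, applied twice, moves `x => unit(a + x)` inward.
  val afterAssoc =
    flatMap(evalU)(b => flatMap(evalV)(c => flatMap(unit(b + c))(x => unit(a + x))))

  // Left unit rewrites flatMap(unit(b + c))(x => unit(a + x)) to unit(a + (b + c)).
  val afterLeftUnit =
    flatMap(evalU)(b => flatMap(evalV)(c => unit(a + (b + c))))
}
```

With these stand-ins all three values are `Some(6)`, which matches the substitution `n[a'/b']` with `a' = (b + c)` and `b' = x`.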

The same approach is used for the second expression:

```eval(Add(Add(t, u), v))

flatMap(eval(Add(t, u)))(x => flatMap(eval(v))(c => unit(x + c)))

flatMap(flatMap(eval(t))(a =>
  flatMap(eval(u))(b => unit(a + b))))(x =>
    flatMap(eval(v))(c => unit(x + c)))

// associativity, then left unit
flatMap(eval(t))(a =>
  flatMap(eval(u))(b =>
    flatMap(eval(v))(c => unit((a + b) + c))))
```

Again, we show the left unit substitution:

```n = flatMap(eval(v))(c => unit(x + c))
a' = (a + b)
b' = x
n[a'/b'] = flatMap(eval(v))(c => unit((a + b) + c))
```

Finally we have

```// expr 1
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit(a + (b + c)))))

// expr 2
flatMap(eval(t))(a =>
  flatMap(eval(u))(b => flatMap(eval(v))(c => unit((a + b) + c))))
```

and the result follows by the associativity of addition.

### Toy example

We can use a toy example to exercise the code:

```val term1 = Add(Con(1), Add(Con(2), Con(3)))

val res1 = eval(term1)
println(res1) // 6

val term2 = Add(Add(Con(1), Con(2)), Con(3))
val res2 = eval(term2)
println(res2) // 6
```
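Because `eval` is generic in the monad, the same comparison can be repeated with a different instance. The sketch below is self-contained and restates the definitions from this post. The `optionMonad` instance and the explicit `eval[Id]` / `eval[Option]` calls are assumptions added for illustration; passing the instance explicitly avoids relying on implicit resolution.

```scala
object ToyExample {
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  sealed trait Term
  case class Con(i: Int) extends Term
  case class Add(t1: Term, t2: Term) extends Term

  type Id[A] = A

  val idMonad: Monad[Id] = new Monad[Id] {
    def unit[A](a: A): Id[A] = a
    def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
  }

  // Hypothetical second instance, not part of the original post.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  def eval[F[_]](term: Term)(implicit M: Monad[F]): F[Int] = term match {
    case Con(a) => M.unit(a)
    case Add(t, u) =>
      M.flatMap(eval[F](t)(M))(a => M.flatMap(eval[F](u)(M))(b => M.unit(a + b)))
  }

  // The two nestings of the same three constants.
  val term1 = Add(Con(1), Add(Con(2), Con(3)))
  val term2 = Add(Add(Con(1), Con(2)), Con(3))

  val id1: Int = eval[Id](term1)(idMonad)
  val id2: Int = eval[Id](term2)(idMonad)
  val opt1: Option[Int] = eval[Option](term1)(optionMonad)
  val opt2: Option[Int] = eval[Option](term2)(optionMonad)
}
```

Both nestings evaluate to `6` under `Id` and to `Some(6)` under `Option`, as the derivation predicts.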

### Monad definition revisited

Next, the paper introduces the operations `map` and `join`. Both can be implemented in terms of `flatMap` and `unit`. We update our `Monad` definition accordingly:

```trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  def map[A, B](fa: F[A])(f: A => B): F[B] =
    flatMap(fa)(a => unit(f(a)))

  def join[A](ffa: F[F[A]]): F[A] =
    flatMap(ffa)(fa => fa)
}
```
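To see the derived operations at work, here is a small sketch that uses them with `Option`. The `optionMonad` instance is an assumption made for this example; only `unit` and `flatMap` are implemented, and `map` and `join` come from the trait.

```scala
object MapJoinExample {
  trait Monad[F[_]] {
    def unit[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

    // Derived from flatMap and unit.
    def map[A, B](fa: F[A])(f: A => B): F[B] =
      flatMap(fa)(a => unit(f(a)))

    // Derived from flatMap alone: flattens one level of nesting.
    def join[A](ffa: F[F[A]]): F[A] =
      flatMap(ffa)(fa => fa)
  }

  // Hypothetical instance, used only for illustration.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // map applies a plain function inside the monad.
  val mapped: Option[Int] = optionMonad.map(Option(2))(_ + 1) // Some(3)

  // join removes one layer of Option.
  val joined: Option[Int] = optionMonad.join(Option(Option(1))) // Some(1)
  val joinedEmpty: Option[Int] = optionMonad.join(Option(Option.empty[Int])) // None
}
```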

An alternative to the previous definition is to use `unit`, `join`, and `map` as the basic operators.

```// `unit`, `map`, and `join` as primitives.