According to Wikipedia, the state monad “allows a programmer to attach state information of any type to a calculation. Given any value type, the corresponding type in the state monad is a function which accepts a state, then outputs a new state (of type s) along with a return value (of type t). This is similar to an environment monad, except that it also returns a new state, and thus allows modeling a mutable environment”.

In this post we’ll implement a state monad from scratch in Scala. We’ll start by replicating the logging behavior we developed in the writer monad post. The logging example illustrates the Wikipedia definition: we model a mutable environment, in this case a logger. Next we’ll develop a more traditional state monad example that simulates the setters and getters of imperative programming. Finally, we’ll work through two examples found in the literature.

The source code can be found here.

The State Monad

Let’s start by defining the State class.

case class State[S, A](runS: S => (A, S))

The State class wraps a function. The function receives a state of type S and returns a tuple (A, S), where A is the type of the computed value and S is the type of the next state.
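To make the shape concrete, here is a minimal sketch of a hand-written State value. The tick counter is a hypothetical example, not part of the original code: it returns the current counter as the value and the incremented counter as the next state.

```scala
// Same definition as in the post, repeated so the snippet stands alone.
case class State[S, A](runS: S => (A, S))

object TickDemo {
  // Hypothetical example: return the current counter, then increment it.
  val tick: State[Int, Int] = State(n => (n, n + 1))

  def main(args: Array[String]): Unit = {
    val (value, next) = tick.runS(5)
    println(s"value = $value, next state = $next") // value = 5, next state = 6
    // Nothing is mutated: running again from the same state gives the same pair.
    println(tick.runS(5)) // (5,6)
  }
}
```

Note that a State value is only a description of a step. Nothing happens until runS is applied to an initial state.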

We’ll now develop a monad instance of the State class. The mechanics are similar to the Writer monad we developed in the writer monad post.

implicit def monad[S]: Monad[({ type l[x] = State[S, x] })#l] = new Monad[({ type l[x] = State[S, x]})#l] {
  def unit[A](a: A): State[S, A] = State(s => (a, s))
  def flatMap[A, B](st: State[S, A])(f: A => State[S, B]): State[S, B] = State(s => {
    val (a, s2) = st.runS(s)
    val st2 = f(a)
    st2.runS(s2)
  })
}

Since State has two type parameters, we pin the state type S with a type lambda, which leaves a one-parameter type constructor that Monad can work with. The methods map and flatMap then convert the value type A to a different type B. The functor method map is implemented in the Monad trait. Also, notice that unit lifts a value of type A and leaves the state S untouched.
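The Monad trait itself is not shown in this post. The sketch below assumes it has roughly the following shape, with map derived from flatMap and unit. The trait definition is an assumption based on the description above, not a copy of the original source.

```scala
// Assumed shape of the Monad trait from the earlier posts (not shown in this one).
trait Monad[F[_]] {
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  // map comes for free: apply f, then lift the result back with unit.
  def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => unit(f(a)))
}

case class State[S, A](runS: S => (A, S))

object State {
  // Same instance as in the post: S is pinned, the value type varies.
  implicit def monad[S]: Monad[({ type l[x] = State[S, x] })#l] =
    new Monad[({ type l[x] = State[S, x] })#l] {
      def unit[A](a: A): State[S, A] = State(s => (a, s))
      def flatMap[A, B](st: State[S, A])(f: A => State[S, B]): State[S, B] =
        State { s =>
          val (a, s2) = st.runS(s)
          f(a).runS(s2)
        }
    }
}

object MonadTraitDemo {
  val M = State.monad[String]
  // A state step that yields 10 and appends "!" to the state.
  val start: State[String, Int] = State(s => (10, s + "!"))
  // map only touches the value; the state change made by start is kept.
  val doubled: State[String, Int] = M.map(start)(_ * 2)

  def main(args: Array[String]): Unit =
    println(doubled.runS("log")) // (20,log!)
}
```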

The implementation of flatMap can be challenging to follow at first sight. flatMap creates a new State that, when run, first executes the given State to obtain a value of type A and the next state. The value is then fed into the function f, which yields a new State with a value of type B. Finally, that State[S, B] is run with the state produced by the first step.
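The three steps can be traced with a small standalone sketch. Here flatMap is written as a plain function rather than through the Monad instance, and tick is a hypothetical counter introduced only for illustration.

```scala
case class State[S, A](runS: S => (A, S))

object FlatMapTrace {
  // Standalone flatMap with the same three steps as the monad instance.
  def flatMap[S, A, B](st: State[S, A])(f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, s2) = st.runS(s) // 1. run the first computation
      val st2 = f(a)           // 2. use its value to choose the next computation
      st2.runS(s2)             // 3. run that one with the updated state
    }

  // Hypothetical counter: yields the current count and increments it.
  val tick: State[Int, Int] = State(n => (n, n + 1))

  // Two ticks in sequence; the second tick sees the state left by the first.
  val twoTicks: State[Int, (Int, Int)] =
    flatMap(tick) { first =>
      flatMap(tick) { second =>
        State[Int, (Int, Int)](s => ((first, second), s))
      }
    }

  def main(args: Array[String]): Unit =
    println(twoTicks.runS(0)) // ((0,1),2)
}
```

Starting from 0, the first tick yields 0 and leaves 1, the second yields 1 and leaves 2, so the final state is 2 even though no variable was ever reassigned.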

Let’s test if we can map over a state instance and start building an insight on how State works.

import State._
import MonadOps._

val s1 = State[String, Int](_ => (10, "I'll start with 10."))
  .map { i => i * 2 }

println(s1.runS("")) // (20,I'll start with 10.)

What about flatMap?

val s2 = State[String, String](_ => ("Hello", "HelloOnce"))
  .flatMap { hello =>
    State[String, String](s => (hello * 2, s ++ "HelloTwice"))
  }

println(s2.runS("")) // (HelloHello,HelloOnceHelloTwice)

Logging example

We’ll use the state monad to replicate the Writer Monad we developed in a previous post. We’ll keep transforming a value while keeping a log describing the value transformations.

Our initial examples showed that we can map and flatMap over instances of the State class. Instantiating State by hand is cumbersome, though. Let’s introduce a helper method to make it easier.

We need a function that receives a value and a log message. We can make the log type generic by requiring that a monoid instance exists for it, and then use that instance to stitch logs together.

def log[A, L : Monoid](a: A, log: L): State[L, A] =
  State(s => (a, implicitly[Monoid[L]].append(s, log)))
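The log helper relies on a Monoid type class that is not defined in this post. A minimal sketch of what it might look like follows. Only the append method is confirmed by the code above; the zero member and the List instance are assumptions added for illustration.

```scala
// Assumed shape of the Monoid type class; only append is used by log.
trait Monoid[A] {
  def zero: A
  def append(a: A, b: A): A
}

object Monoid {
  implicit val stringMonoid: Monoid[String] = new Monoid[String] {
    def zero: String = ""
    def append(a: String, b: String): String = a + b
  }
  // A second, hypothetical instance to show that log is not tied to String.
  implicit def listMonoid[T]: Monoid[List[T]] = new Monoid[List[T]] {
    def zero: List[T] = Nil
    def append(a: List[T], b: List[T]): List[T] = a ::: b
  }
}

case class State[S, A](runS: S => (A, S))

object LogDemo {
  // Same helper as in the post.
  def log[A, L: Monoid](a: A, log: L): State[L, A] =
    State(s => (a, implicitly[Monoid[L]].append(s, log)))

  val viaString = log(1, "one. ").runS("start: ")
  val viaList = log(1, List("one")).runS(List("start"))

  def main(args: Array[String]): Unit = {
    println(viaString) // (1,start: one. )
    println(viaList)   // (1,List(start, one))
  }
}
```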

Before we proceed we’ll (re-)introduce the implicit class MonadOps2 that along with the implicit monad method will allow us to leverage for-comprehensions.

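// for-comprehensions for binary type constructors, (*, *) -> *
// The first parameter S (the state) is pinned and the second, A (the value), varies,
// matching the shape of the monad instance for State[S, A].
implicit class MonadOps2[S, A, F[_, _]](fa: F[S, A])(implicit M: Monad[({ type l[x] = F[S, x] })#l]) {
  def map[B](f: A => B): F[S, B] = M.map(fa)(f)
  def flatMap[B](f: A => F[S, B]): F[S, B] = M.flatMap(fa)(f)
}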

Let’s now replicate the log examples we’ve used in the writer monad post but using the state monad.

val s3: State[String, Int] = for {
  x      <- log(20, "I'll start with 20. ")
  isEven <- log(x % 2 == 0, "Is it even? ")
} yield if (isEven) x + 1 else x


println(s3.runS("")) // (21,I'll start with 20. Is it even? )
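It can help to see what the for-comprehension expands to. The sketch below is a simplified stand-in, under two stated assumptions: map and flatMap are defined directly on State instead of coming from the Monad instance and MonadOps2, and log is specialised to String so that no Monoid is needed.

```scala
// Simplified stand-in: map and flatMap are methods on State here.
case class State[S, A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, s2) = runS(s)
      (f(a), s2)
    }
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, s2) = runS(s)
      f(a).runS(s2)
    }
}

object DesugarDemo {
  // String-only version of log, to avoid needing a Monoid instance.
  def log[A](a: A, msg: String): State[String, A] = State(s => (a, s + msg))

  val sugared: State[String, Int] = for {
    x      <- log(20, "I'll start with 20. ")
    isEven <- log(x % 2 == 0, "Is it even? ")
  } yield if (isEven) x + 1 else x

  // The equivalent chain: flatMap for the first generator, map for the last.
  val desugared: State[String, Int] =
    log(20, "I'll start with 20. ").flatMap { x =>
      log(x % 2 == 0, "Is it even? ").map { isEven =>
        if (isEven) x + 1 else x
      }
    }

  def main(args: Array[String]): Unit = {
    println(sugared.runS(""))   // (21,I'll start with 20. Is it even? )
    println(desugared.runS("")) // (21,I'll start with 20. Is it even? )
  }
}
```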

While the monad primitives unit and flatMap let us work with the value of the State class, it would be useful to have an API to control the state itself. That’s the job of the get and set methods:

object State {
  def get[S]: State[S, S] = State(s => (s, s))
  def set[S](s: S): State[S, Unit] = State(_ => ((), s))
}

Let’s give it a try.

val s4: State[String, Int] = for {
  x      <- log(20, "I'll start with 20. ")
  isEven <- log(x % 2 == 0, "Is it even? ")
  msg = if (isEven) s"Let's make $x odd. " else s"$x is already odd. "
  log    <- get[String]
  _      <- set(log ++ msg)
} yield if (isEven) x + 1 else x

println(s4.runS("")) // (21,I'll start with 20. Is it even? Let's make 20 odd. )

Notice that the setter had to take care of appending the log. We also had to first read the state before “mutating” it. That seems like a common use case, so we can enrich the API for it. Let’s introduce an appendLog function.

def appendLog(log: String): State[String, Unit] =
  get[String].flatMap(s => set(s ++ log))
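The read-then-write pattern is not specific to logs. As a sketch, it could be generalised into a modify helper. The names modify and increment are hypothetical and not part of the original code, and map and flatMap are defined directly on State to keep the snippet self-contained.

```scala
// Simplified stand-in: map and flatMap are methods on State here.
case class State[S, A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, s2) = runS(s)
      (f(a), s2)
    }
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, s2) = runS(s)
      f(a).runS(s2)
    }
}

object ModifyDemo {
  def get[S]: State[S, S] = State(s => (s, s))
  def set[S](s: S): State[S, Unit] = State(_ => ((), s))

  // Hypothetical helper: read the state, apply f, write the result back.
  def modify[S](f: S => S): State[S, Unit] =
    get[S].flatMap(s => set(f(s)))

  // appendLog expressed through modify.
  def appendLog(log: String): State[String, Unit] = modify[String](_ + log)
  // The same helper works for any state type, for example a counter.
  def increment: State[Int, Unit] = modify[Int](_ + 1)

  val logged = appendLog("world").runS("hello ")
  val counted = increment.flatMap(_ => increment).runS(0)

  def main(args: Array[String]): Unit = {
    println(logged)  // ((),hello world)
    println(counted) // ((),2)
  }
}
```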

We can now rewrite our previous example replacing the get and set operations with appendLog.

val s5: State[String, Int] = for {
  x      <- log(20, "I'll start with 20. ")
  isEven <- log(x % 2 == 0, "Is it even? ")
  msg = if (isEven) s"Let's make $x odd. " else s"$x is already odd. "
  _      <- appendLog(msg)
} yield if (isEven) x + 1 else x

println(s5.runS("")) // (21,I'll start with 20. Is it even? Let's make 20 odd. )

Finally, we can also add a helper method to create state instances:

object State {
  def state[S, A](a: A): State[S, A] = monad[S].unit(a)
}

Let’s give it a try.

val s6: State[String, Int] = for {
  x      <- state[String, Int](20)
  isEven <- state[String, Boolean](x % 2 == 0)
  odd    <- state[String, Int](if (isEven) x + 1 else x)
  _      <- set(s"We had $x, it was${if (isEven) "" else " not"} even, so we have $odd")
} yield odd

println(s6.runS("")) // (21,We had 20, it was even, so we have 21)

Extra examples

The following example was taken from the book Functional Programming in Scala with a few minor modifications to use the implementation we’ve developed in this post.

The state monad is used to zip a list with the indexes of its elements.

def zipWithIndex[A](xs: List[A]): List[(A, Int)] = {
  val M = implicitly[Monad[({ type l[x] = State[Int, x] })#l]]
  xs.foldLeft(M.unit(List[(A, Int)]())) ((acc, a) => for {
    xs <- acc
    n  <- get[Int]
    _  <- set(n + 1)
  } yield (a, n) :: xs).runS(0)._1.reverse
}

val s7 = zipWithIndex(List('A', 'B', 'C'))
println(s7) // List((A,0), (B,1), (C,2))
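For comparison, here is a sketch of the same traversal with the counter threaded by hand instead of through State. This is not from the book; it only shows the bookkeeping that get and set hide.

```scala
object ZipWithIndexByHand {
  // The counter travels through the fold as the second component of the accumulator.
  def zipWithIndex[A](xs: List[A]): List[(A, Int)] = {
    val (pairs, _) = xs.foldLeft((List.empty[(A, Int)], 0)) {
      case ((acc, n), a) => ((a, n) :: acc, n + 1)
    }
    pairs.reverse // elements were prepended, so restore the original order
  }

  val result = zipWithIndex(List('A', 'B', 'C'))

  def main(args: Array[String]): Unit = {
    println(result) // List((A,0), (B,1), (C,2))
    // The standard library method gives the same pairs.
    println(List('A', 'B', 'C').zipWithIndex) // List((A,0), (B,1), (C,2))
  }
}
```

In both versions the reverse is needed for the same reason: each new pair is prepended to the accumulated list.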

The following example was taken from the ebook From Simple IO to Monad Transformers.

It implements a stack in terms of the state monad.

val r8: State[List[Char], Unit] = for {
  c <- pop[Char]
  _ <- push[Char]('a')
  _ <- push[Char](c)
} yield ()

val r9 = r8.runS(List('c', 't'))._2.mkString("")

println(r9) // cat

Here’s the definition of push and pop.

object State {
  def push[A](a: A): State[List[A], Unit] = State(s => ((), a :: s))
  def pop[A]: State[List[A], A] = State { xs => (xs.head, xs.tail) } // unsafe: throws on an empty list
}
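Since pop throws on an empty stack, a safer variant could return an Option instead. The sketch below is one possible approach; popOption is a hypothetical name and is not part of the ebook or the original code.

```scala
case class State[S, A](runS: S => (A, S))

object SafeStack {
  def push[A](a: A): State[List[A], Unit] = State(s => ((), a :: s))

  // Hypothetical safe variant: yields None on an empty stack and leaves it empty.
  def popOption[A]: State[List[A], Option[A]] =
    State[List[A], Option[A]] {
      case x :: rest => (Some(x), rest)
      case Nil       => (None, Nil)
    }

  val fromFull = popOption[Char].runS(List('c', 't'))
  val fromEmpty = popOption[Char].runS(Nil)

  def main(args: Array[String]): Unit = {
    println(fromFull)  // (Some(c),List(t))
    println(fromEmpty) // (None,List())
  }
}
```

The trade-off is that callers now have to handle the Option, which makes the empty-stack case explicit in the types.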