DEV Community

Brian Berns
The state monad in F#

This is a primer on implementing stateful computations in F# without violating functional purity. Wait, state in pure code? Is that really possible in a functional language? Yes, but not in the way you might be used to from imperative programming.

Stacks

Let's define a basic immutable stack type based on lists:

type Stack<'t> = private Stack of List<'t>

module Stack =

    let ofList list = Stack list

    /// 'a -> Stack<'a> -> Stack<'a>
    let push item (Stack items) =
        Stack (item :: items)

    /// Stack<'a> -> 'a * Stack<'a>
    let pop = function
        | Stack (head :: tail) -> head, Stack tail
        | Stack [] -> failwith "Empty stack"

Note that push and pop are pure functions: they both return a new, updated stack, rather than modifying the given stack directly. Using a stack this way can get a little tedious, because we have to explicitly track the state of the stack with a different variable at every step:

let stack = Stack.ofList [1; 2]
let item, stack' = stack |> Stack.pop
let stack'' = stack' |> Stack.push 3

// Output: 1, Stack [3; 2]
printfn "%A, %A" item stack''

Those stack variables are exhausting and distract from the purpose of the code. Fortunately, it's possible to eliminate them entirely, while still honoring the immutability of Stack.
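To make the tedium concrete, here is a minimal sketch of a slightly longer routine written by hand. The popThenPush function is a hypothetical example (it isn't part of the Stack module), and the Stack type is repeated without the private modifier so that the snippet stands alone:

```fsharp
// Standalone copy of the stack from above (without `private`, for simplicity).
type Stack<'t> = Stack of 't list

module Stack =
    let ofList list = Stack list
    let push item (Stack items) = Stack (item :: items)
    let pop = function
        | Stack (head :: tail) -> head, Stack tail
        | Stack [] -> failwith "Empty stack"

/// Hypothetical example: pops an item, then pushes values that depend on it.
/// Every intermediate stack needs its own name.
let popThenPush stack =
    let a, stack' = stack |> Stack.pop
    if a = 5 then
        let stack'' = stack' |> Stack.push 7
        a, stack''
    else
        let stack'' = stack' |> Stack.push 3
        let stack''' = stack'' |> Stack.push 8
        a, stack'''

// Pops 9, then pushes 3 and 8 on top of what's left.
let manualResult = popThenPush (Stack.ofList [9; 0; 2; 1; 0])
printfn "%A" manualResult
```

Three different names for what is conceptually one stack, and it's easy to pass the wrong one by accident.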

Stateful computations

The general pattern here is that a stateful operation takes the current state as input, performs some action on it, and returns a result along with a new, updated state: 'state -> 'result * 'state [1]. Let's put that signature into a type that represents a stateful computation:

/// A stateful computation.
type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

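Note that this type represents a computation, but doesn't actually execute it. In order to run a computation, we need to invoke the wrapped function with a given state. We'll put this function in a Stateful module, along with the ret and bind functions defined below, so that later code can refer to them as Stateful.run, Stateful.ret, and Stateful.bind:

module Stateful =

    /// 'state -> Stateful<'state, 'result> -> ('result * 'state)
    let run state (Stateful f) =
        f state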
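Here's a small sketch of the difference between defining a computation and running it. The tick computation is a hypothetical example that uses a plain int counter as its state; it isn't part of the stack code:

```fsharp
/// A stateful computation.
type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

module Stateful =
    let run state (Stateful f) = f state

/// Hypothetical computation over an int counter: returns the current
/// count as its result and increments the state.
let tick = Stateful (fun count -> (count, count + 1))

// Defining `tick` ran nothing. Each call to `run` executes it once,
// starting from whatever state we supply.
let fromZero = Stateful.run 0 tick      // (0, 1)
let fromForty = Stateful.run 40 tick    // (40, 41)
printfn "%A %A" fromZero fromForty
```

The same tick value can be run as many times as we like, from any starting state, because it's just a wrapped function.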

The state monad

Our Stateful type is a monad, which means that it supports return and bind functions. The first one (named ret here, because return is a reserved keyword in F#) creates a stateful computation that simply returns a given value without altering the current state:

/// 'result -> Stateful<'state, 'result>
let ret result =
    Stateful (fun state -> (result, state))

Binding two stateful computations together is more complex, but quite elegant:

/// ('a -> Stateful<'state, 'b>) -> Stateful<'state, 'a> -> Stateful<'state, 'b>
let bind binder stateful =
    Stateful (fun state ->
        let result, state' = stateful |> run state
        binder result |> run state')

As usual, bind is the most important part of the state monad, so let's make sure we understand what's going on. First, take a look at the signature: the two computations we're binding must work with the same type of state (e.g. Stack). The only thing that can vary between them is their result types.

So, how does the combined computation work? It starts by running the first computation on the incoming state, producing a result and a new state (state'). It then passes that result to the given binder function, giving us the second computation. Lastly, it runs that computation using the updated state.

Again, keep in mind that neither of the bound computations is actually executed within bind. We're simply defining a new computation that will run them both in sequence when called upon to do so.
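A worked example may help here. This sketch binds a hypothetical tick computation (an int counter, not part of the stack code) to itself, so we can watch the state flow from the first computation into the second:

```fsharp
type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

module Stateful =
    let run state (Stateful f) = f state
    let ret result = Stateful (fun state -> (result, state))
    let bind binder stateful =
        Stateful (fun state ->
            let result, state' = stateful |> run state
            binder result |> run state')

/// Hypothetical computation: returns the counter and increments it.
let tick = Stateful (fun count -> (count, count + 1))

/// Two ticks in sequence. The result is the pair of counts observed.
let twoTicks =
    tick |> Stateful.bind (fun first ->
        tick |> Stateful.bind (fun second ->
            Stateful.ret (first, second)))

// Nothing has run yet. Starting from 10: the first tick sees 10 and
// leaves 11, the second tick sees 11 and leaves 12.
let pair, finalCount = Stateful.run 10 twoTicks
printfn "%A, %d" pair finalCount
```

Notice that twoTicks never mentions the counter. Each bind hands the updated state to the next step behind the scenes.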

With that out of the way, it's easy to define a builder:

type StatefulBuilder() =
    let (>>=) stateful binder = Stateful.bind binder stateful
    member __.Return(result) = Stateful.ret result
    member __.ReturnFrom(stateful) = stateful
    member __.Bind(stateful, binder) = stateful >>= binder
    member __.Zero() = Stateful.ret ()
    member __.Combine(statefulA, statefulB) =
        statefulA >>= (fun _ -> statefulB)
    member __.Delay(f) = f ()

let state = StatefulBuilder()
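To see what the builder buys us, here's a sketch that writes the same computation twice: once with the builder's syntax and once with explicit calls to bind. The tick computation is again a hypothetical int counter, and the builder is a trimmed copy of the one above that calls Stateful.bind directly. The hand-written version ignores the Delay wrapper that the compiler inserts, which makes no difference here because Delay just invokes its function immediately:

```fsharp
type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

module Stateful =
    let run state (Stateful f) = f state
    let ret result = Stateful (fun state -> (result, state))
    let bind binder stateful =
        Stateful (fun state ->
            let result, state' = stateful |> run state
            binder result |> run state')

// Trimmed copy of the builder: only what this example needs.
type StatefulBuilder() =
    member __.Return(result) = Stateful.ret result
    member __.Bind(stateful, binder) = Stateful.bind binder stateful
    member __.Delay(f) = f ()

let state = StatefulBuilder()

/// Hypothetical computation: returns the counter and increments it.
let tick = Stateful (fun count -> (count, count + 1))

/// Written with the builder.
let sugared =
    state {
        let! a = tick
        let! b = tick
        return a + b
    }

/// Roughly what the compiler turns it into.
let desugared =
    tick |> Stateful.bind (fun a ->
        tick |> Stateful.bind (fun b ->
            Stateful.ret (a + b)))

// Starting from 5: a = 5, b = 6, so the result is 11 and the final state is 7.
let sugaredResult = Stateful.run 5 sugared
let desugaredResult = Stateful.run 5 desugared
printfn "%A %A" sugaredResult desugaredResult
```

Each let! becomes a call to Bind, with the rest of the block as the binder function, and return becomes a call to Return.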

Stack computations

Let's create stateful computations for our stack operations, push and pop. pop is easy because its signature is exactly what we need:

/// Stateful<Stack<'a>, 'a>
let popC = Stateful Stack.pop

We call this popC to emphasize that it defines a computation. push is nearly as easy - we just need to explicitly return () as a result:

/// 'a -> Stateful<Stack<'a>, unit>
let pushC item =
    Stateful (fun stack ->
        (), Stack.push item stack)

Note that pushC 2 is a computation that pushes 2 on a stack, while pushC 9 is a different computation that pushes 9 on a stack. The arguments to push are baked directly into the computations, so running a computation takes no input other than the state.
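As a quick sketch of that point, we can build both computations up front and then run each of them against the same starting stack. The names pushTwo, pushNine and start are made up for this example, and the Stack type is repeated without the private modifier so that the snippet stands alone:

```fsharp
type Stack<'t> = Stack of 't list

module Stack =
    let ofList list = Stack list
    let push item (Stack items) = Stack (item :: items)

type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

module Stateful =
    let run state (Stateful f) = f state

/// 'a -> Stateful<Stack<'a>, unit>
let pushC item =
    Stateful (fun stack ->
        ((), Stack.push item stack))

// Two distinct computations. Nothing has been pushed anywhere yet.
let pushTwo = pushC 2
let pushNine = pushC 9

// The only input needed to run either one is a stack.
let start = Stack.ofList [1]
let afterTwo = Stateful.run start pushTwo      // ((), Stack [2; 1])
let afterNine = Stateful.run start pushNine    // ((), Stack [9; 1])
printfn "%A %A" afterTwo afterNine
```

Since the stack is immutable, running pushTwo has no effect on what pushNine sees: both start from Stack [1].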

We're ready to define a complex stateful computation using stacks:

/// Stateful<Stack<int>, int>
let comp =
    state {
        let! a = popC
        if a = 5 then
            do! pushC 7
        else
            do! pushC 3
            do! pushC 8
        return a
    }

This is a computation that works on state of type Stack<int> and returns an int. Note that the computation doesn't explicitly refer to the stack at any point - the state monad takes care of managing it for us! Instead, we've just combined a bunch of low-level computations into a single high-level computation, step by step.

Let's run this computation to see if it works:

// Output: (9, Stack [8; 3; 0; 2; 1; 0])
let stack1 = [9; 0; 2; 1; 0] |> Stack.ofList
printfn "%A" (Stateful.run stack1 comp)

// Output: (5, Stack [7; 1])
let stack2 = [5; 1] |> Stack.ofList
printfn "%A" (Stateful.run stack2 comp)

Success!
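Because comp is itself just a Stateful value, it can be bound like popC or pushC to build an even larger computation. The sketch below repeats the definitions from this article so that it stands alone (with Stack's private modifier dropped and a trimmed builder that calls Stateful.bind directly); compTwice is a hypothetical name for the new computation:

```fsharp
type Stack<'t> = Stack of 't list

module Stack =
    let ofList list = Stack list
    let push item (Stack items) = Stack (item :: items)
    let pop = function
        | Stack (head :: tail) -> head, Stack tail
        | Stack [] -> failwith "Empty stack"

type Stateful<'state, 'result> =
    Stateful of ('state -> 'result * 'state)

module Stateful =
    let run state (Stateful f) = f state
    let ret result = Stateful (fun state -> (result, state))
    let bind binder stateful =
        Stateful (fun state ->
            let result, state' = stateful |> run state
            binder result |> run state')

type StatefulBuilder() =
    member __.Return(result) = Stateful.ret result
    member __.Bind(stateful, binder) = Stateful.bind binder stateful
    member __.Zero() = Stateful.ret ()
    member __.Combine(statefulA, statefulB) =
        Stateful.bind (fun _ -> statefulB) statefulA
    member __.Delay(f) = f ()

let state = StatefulBuilder()

let popC = Stateful Stack.pop
let pushC item = Stateful (fun stack -> ((), Stack.push item stack))

let comp =
    state {
        let! a = popC
        if a = 5 then
            do! pushC 7
        else
            do! pushC 3
            do! pushC 8
        return a
    }

/// Hypothetical: runs comp twice in a row and adds up the two results.
let compTwice =
    state {
        let! first = comp
        let! second = comp
        return first + second
    }

// First run pops 5 and pushes 7, leaving [7; 1].
// Second run pops 7 and pushes 3, then 8, leaving [8; 3; 1].
let total, finalStack = Stateful.run (Stack.ofList [5; 1]) compTwice
printfn "%d, %A" total finalStack
```

The second run of comp sees the stack left behind by the first, even though compTwice never mentions a stack.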

For more information on the state monad, the following articles are helpful (both use Haskell):


  1. To be clear, this is the same as 'state -> ('result * 'state), not ('state -> 'result) * 'state
