Monday, January 16, 2012

Changing my state of mind with a Monad in Scala

Houdy.

While looking for more real-world examples to complete a previous blog entry, I found myself struggling with the State Monad while trying to solve what I assumed to be a typical State Monad problem.
My attempts to solve this specific problem did not succeed, but in the meantime I successfully reproduced the canonical stack manipulation example from Learn You a Haskell (LYAH). Although I am not satisfied with the result, I would like to present this "kata" and ask for your feedback so that I can reach the expected goal.
I especially thank Nilanjan Raychaudhuri, the author of Scala in Action, for his precious help. Reading chapter 10 of his book confirmed I was heading in the right direction.

Reproducing the LYAH example remains a fruitful exercise in the sense that it constrains you to use Scala idioms (typing, self-type annotations, etc.) and forces you to think about some of the inner mechanics of for comprehensions.

To show why state management deserves attention in functional programming languages, Miran Lipovaca presents a three coins problem, which simulates the results of tossing a coin three times, and a stack manipulation problem. In an imperative language, the random generator internals or the stack internals would simply be mutable objects, modified in place to generate new numbers or alter the stack state.
In a pure functional language, we manipulate immutable data: we have to create a new value object each time the equivalent of a state change occurs. But what if we could separate the flow of data from the handling of the state change itself?
That is precisely our purpose: embedding a change of state in a dedicated instance. The secret lies in representing this change of state abstractly as a function:

def apply(state: S): (T, S)

where S is the type of the state to be changed, and T is the type of the result of the stateful computation. The enclosing class hosts the apply function (so its instances can be applied like functions in Scala) and embodies the context that holds the state management. You apply the context to get your result value:

contextInstance(previousState) = (result, newState)
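To make this signature concrete with the three coins problem mentioned above, here is a minimal sketch of a coin toss written as a plain function of type S => (T, S), with the state threaded by hand and no context class yet. The Gen class, nextBit, toss and threeCoins are hypothetical names for illustration; Gen is a toy linear congruential generator that I assume reuses the multiplier, increment and 48-bit mask of java.util.Random's documented algorithm (without its initial seed scrambling), not LYAH's actual generator.

```scala
// Minimal sketch: Gen is a hypothetical immutable generator, a toy linear
// congruential generator reusing the multiplier, increment and 48-bit mask of
// java.util.Random's documented algorithm (without its initial seed scrambling).
object ThreeCoinsByHand {
  final case class Gen(seed: Long) {
    // Returns a coin AND a new generator; this Gen itself is never mutated.
    def nextBit: (Boolean, Gen) = {
      val next = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1)
      ((next >>> 47) == 1L, Gen(next)) // top bit of the 48-bit state
    }
  }

  // A stateful computation with the shape of apply: S => (T, S),
  // here Gen => (Boolean, Gen).
  val toss: Gen => (Boolean, Gen) = g => g.nextBit

  // Without a state context, every intermediate generator is threaded by hand.
  def threeCoins(g0: Gen): ((Boolean, Boolean, Boolean), Gen) = {
    val (c1, g1) = toss(g0)
    val (c2, g2) = toss(g1)
    val (c3, g3) = toss(g2)
    ((c1, c2, c3), g3)
  }
}
```

Each val pattern has to pass the previous generator along explicitly; that plumbing is exactly what a state context is meant to hide.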

As a bonus, you also get the altered state. In the case of a stack, the state is the stack content. The manipulation contexts will be class instances implementing the stack operations, such as pop and push. We will represent a stack state as a List of items of type A:

List[A]

Consequently, if we choose to name our state context StateMonad, the pop and push operations can be gathered in a Stack object like the following:

object Stack {
  def push[A](x: A) = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  def pop[A] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) =
      state match {
        case x :: xs => (Some(x), xs)
        case _ => (None, state)
      }
  }
}

taking on faith, for now, that a StateMonad trait has already been defined. We have at least established that our StateMonad trait will be parameterized as:

trait StateMonad[+T, S]
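Before map and flatMap come into play, a trait declaring only the abstract apply should already be enough to compile the Stack object above and try the two contexts on their own. A minimal sketch, wrapped in a hypothetical BareStateSketch object so that it does not clash with the fuller definitions later in the post:

```scala
// Minimal sketch; BareStateSketch is a hypothetical wrapper name.
object BareStateSketch {
  // Only the abstract state transition for now: S => (T, S).
  trait StateMonad[+T, S] {
    def apply(state: S): (T, S)
  }

  object Stack {
    // Pushing has no meaningful result, hence Unit.
    def push[A](x: A) = new StateMonad[Unit, List[A]] {
      def apply(state: List[A]) = ((), x :: state)
    }

    // Popping yields the top item, if any, and the remaining stack.
    def pop[A] = new StateMonad[Option[A], List[A]] {
      def apply(state: List[A]) =
        state match {
          case x :: xs => (Some(x), xs)
          case _       => (None, state)
        }
    }
  }
}
```

At this stage the contexts can only be applied one at a time; chaining them is what map and flatMap add.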

When pushing data on top of a stack, I expect no result, so I return (), the only value of type Unit (roughly Java's void):

scala> import Stack._
import Stack._

scala> push(5)(List())
res0: (Unit, List[Int]) = ((),List(5))

while the result of running a pop context may contain an optional item of type A, depending on whether the previous stack state is empty or holds at least one element:

scala> import Stack._
import Stack._

scala> pop(List(1))
res1: (Option[Int], List[Int]) = (Some(1),List())

scala> pop(List())
res2: (Option[Nothing], List[Nothing]) = (None,List())

I believe the pattern matching in the pop method body is self-explanatory. Chaining state modifications can then be achieved by defining both map and flatMap. The map method transforms the result embedded in the context, producing a new StateMonad that applies the expected transformation:

def map[U](f: T => U) = new StateMonad[U, S] 

while a flatMap method simplifies chaining a context to the next one, computed from its result:

def flatMap[U](f: T => StateMonad[U, S]) = new StateMonad[U, S]

How so? Just as we did last time:

scala> import com.promindis.user._
import com.promindis.user._

scala> import Stack._
import Stack._

scala> val result = push(3).flatMap{ _ =>
     |       push(5).flatMap{_ =>
     |         push(7).flatMap{_ =>
     |           push(9).flatMap{_ =>
     |             pop.map{_ => ()}
     |           }
     |         }
     |       }
     |     }
result: java.lang.Object with com.promindis.state.StateMonad[Unit,List[Int]] = com.promindis.state.StateMonad$$anon$1@124e407

scala> result(List())
res2: (Unit, List[Int]) = ((),List(7, 5, 3))

scala>

The benefit of map and flatMap becomes obvious when using more idiomatic Scala expressions such as for comprehensions, which the compiler translates into the lines of code above:

scala> import com.promindis.user._
import com.promindis.user._

scala> import Stack._
import Stack._

scala> val result = for {
     |       _ <- push(3)
     |       _ <- push(5)
     |       _ <- push(7)
     |       _ <- push(9)
     |       _ <- pop
     |     } yield ()
result: java.lang.Object with com.promindis.state.StateMonad[Unit,List[Int]] = com.promindis.state.StateMonad$$anon$1@7a6088

scala> result(List(1))
res3: (Unit, List[Int]) = ((),List(7, 5, 3, 1))

scala>
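In the comprehension above, the final generator discards pop's result. As a minimal sketch, keeping that result shows map at work: the compiler turns yield into a final map call, so only the result is transformed while the state flows through unchanged. YieldSketch and topTimesTen are hypothetical names, and I pass the type argument to pop explicitly (pop[Int]) to keep inference simple.

```scala
// Minimal sketch; YieldSketch and topTimesTen are hypothetical names.
object YieldSketch {
  trait StateMonad[+T, S] { owner =>
    def apply(state: S): (T, S)

    def flatMap[U](f: T => StateMonad[U, S]): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        f(a)(y)
      }
    }

    def map[U](f: T => U): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        (f(a), y)
      }
    }
  }

  def push[A](x: A): StateMonad[Unit, List[A]] = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  def pop[A]: StateMonad[Option[A], List[A]] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) = state match {
      case x :: xs => (Some(x), xs)
      case _       => (None, state)
    }
  }

  // Translated to push(3).flatMap(_ => push(5).flatMap(_ => pop[Int].map(top => ...))):
  // the yield expression only touches the result, never the stack.
  val topTimesTen: StateMonad[Int, List[Int]] = for {
    _   <- push(3)
    _   <- push(5)
    top <- pop[Int]
  } yield top.getOrElse(0) * 10
}
```

Applied to List(1), topTimesTen pushes 3 and 5, pops 5 back and maps it to 50, leaving List(3, 1) as the new state.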

The full implementation of the StateMonad trait then becomes:

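package com.promindis.state

trait StateMonad[+T, S]  {
  owner => // self alias: lets nested anonymous instances refer to this container

  // A state transition: consumes a state, returns a result and the new state.
  def apply(state: S): (T, S)

  // Runs this container, feeds its result to f, then runs the container
  // returned by f on the intermediate state.
  def flatMap[U](f: T => StateMonad[U,S]) = new StateMonad[U, S] {
    override def apply(state: S) = {
      val (a, y) = owner(state)
      f(a)(y)
    }
  }

  // Runs this container and transforms its result, leaving the state untouched.
  def map[U](f: T => U) = new StateMonad[U, S] {
    def apply(state: S) = {
      val (a, y) = owner(state)
      (f(a), y)
    }
  }
}

object StateMonad {
  // Wraps a plain value in a context that passes the state through unchanged.
  def apply[T, S](value: T) = new StateMonad[T, S] {
    def apply(state: S) = (value, state)
  }
}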
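The companion object's apply deserves a word: it wraps a plain value into a context that leaves the state untouched, which is the role return plays for LYAH's State monad. A minimal sketch of it feeding a value into push; UnitSketch and pushSeven are hypothetical names, and the trait is repeated so the block compiles on its own.

```scala
// Minimal sketch; UnitSketch and pushSeven are hypothetical names.
object UnitSketch {
  trait StateMonad[+T, S] { owner =>
    def apply(state: S): (T, S)

    def flatMap[U](f: T => StateMonad[U, S]): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        f(a)(y)
      }
    }

    def map[U](f: T => U): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        (f(a), y)
      }
    }
  }

  object StateMonad {
    // Wraps a plain value; the state passes through untouched.
    def apply[T, S](value: T): StateMonad[T, S] = new StateMonad[T, S] {
      def apply(state: S) = (value, state)
    }
  }

  def push[A](x: A): StateMonad[Unit, List[A]] = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  // The value to push is injected with StateMonad.apply instead of being a literal.
  val pushSeven: StateMonad[Unit, List[Int]] = for {
    n <- StateMonad[Int, List[Int]](7)
    _ <- push(n)
  } yield ()
}
```

Applying StateMonad[Int, List[Int]](42) to List(1) simply gives back (42, List(1)), and pushSeven turns List(1) into List(7, 1).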

The map function produces a new container instance that applies the original container to the incoming state, then applies the transformation f to its result, leaving the resulting state as is. The self alias owner helps in referencing the original container from the apply method body of the new anonymous StateMonad instance:

owner =>

How do we extract the result from the previous container? Again, by applying the previous container itself:

val (a, y) = owner(state)

The result of the new anonymous StateMonad container will be:

(f(a), y)
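To check this reading of map, here is a minimal sketch comparing pop mapped through a formatting function with the same two steps written out by hand. MapByHand, describe, viaMap and byHand are hypothetical names, and only the parts of the trait needed here are repeated.

```scala
// Minimal sketch; MapByHand, describe, viaMap and byHand are hypothetical names.
object MapByHand {
  trait StateMonad[+T, S] { owner =>
    def apply(state: S): (T, S)

    def map[U](f: T => U): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state) // run the original container
        (f(a), y)                 // transform the result, keep the state y
      }
    }
  }

  def pop[A]: StateMonad[Option[A], List[A]] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) = state match {
      case x :: xs => (Some(x), xs)
      case _       => (None, state)
    }
  }

  val describe: Option[Int] => String = {
    case Some(x) => "popped " + x
    case None    => "empty stack"
  }

  // Through map...
  def viaMap(s: List[Int]): (String, List[Int]) = pop[Int].map(describe)(s)

  // ...and the same steps spelled out by hand.
  def byHand(s: List[Int]): (String, List[Int]) = {
    val (a, y) = pop[Int](s)
    (describe(a), y)
  }
}
```

Both functions give ("popped 4", List(2)) for the stack List(4, 2): map never looks at the state, it only rewrites the result.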

The body of the apply method of the StateMonad instance returned by flatMap performs:
  • the application of the previous container (to extract the previous result a and state y),
  • then the application of the function f to the result a, which produces a new StateMonad instance f(a),
  • and finally the application of f(a) to the intermediate state y.
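The three steps listed above can be traced on a two-step chain, push(3) followed by pop. A minimal sketch, with hypothetical names FlatMapByHand, chained and traced, where traced replays flatMap's body by hand for a given initial state:

```scala
// Minimal sketch; FlatMapByHand, chained and traced are hypothetical names.
object FlatMapByHand {
  trait StateMonad[+T, S] { owner =>
    def apply(state: S): (T, S)

    def flatMap[U](f: T => StateMonad[U, S]): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state) // 1. run the previous container
        val next = f(a)           // 2. build the next container from its result
        next(y)                   // 3. run it on the intermediate state
      }
    }
  }

  def push[A](x: A): StateMonad[Unit, List[A]] = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  def pop[A]: StateMonad[Option[A], List[A]] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) = state match {
      case x :: xs => (Some(x), xs)
      case _       => (None, state)
    }
  }

  val chained: StateMonad[Option[Int], List[Int]] = push(3).flatMap(_ => pop[Int])

  // The same three steps, written out by hand.
  def traced(s0: List[Int]): (Option[Int], List[Int]) = {
    val (a, y) = push(3)(s0) // step 1: for s0 = List(1), a = () and y = List(3, 1)
    val f: Unit => StateMonad[Option[Int], List[Int]] = _ => pop[Int]
    f(a)(y)                  // steps 2 and 3: for s0 = List(1), (Some(3), List(1))
  }
}
```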
We have chained the previous container's state change to the state change produced by f(a). This whole chaining is itself transparently hosted by an enclosing StateMonad instance. The complete stack example can be reproduced as follows:

package com.promindis.user

import com.promindis.state._

object Stack {
  def push[A](x: A) = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  def pop[A] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) =
      state match {
        case x :: xs => (Some(x), xs)
        case _ => (None, state)
      }
  }
}

object UseState {
  import Stack._
  def main(args: Array[String]): Unit = {
    val result = for {
      _ <- push(3)
      _ <- push(5)
      _ <- push(7)
      _ <- push(9)
      _ <- pop
    } yield ()

    println(result(List(1))._2)

    val otherResult = push(3).flatMap{ _ =>
      push(5).flatMap{_ =>
        push(7).flatMap{_ =>
          push(9).flatMap{_ =>
            pop.map{_ => ()}
          }
        }
      }
    }

    println(otherResult(List(1))._2)
  }
}
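To see why flatMap, and not only map, is needed, here is a sketch modeled on LYAH's stackStuff example, where the computation chosen after pop depends on the popped value. StackStuffSketch is a hypothetical wrapper name, and comparing the popped Option with Some(5) is my adaptation to pop returning an Option.

```scala
// Minimal sketch; StackStuffSketch is a hypothetical name, modeled on LYAH's stackStuff.
object StackStuffSketch {
  trait StateMonad[+T, S] { owner =>
    def apply(state: S): (T, S)

    def flatMap[U](f: T => StateMonad[U, S]): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        f(a)(y)
      }
    }

    def map[U](f: T => U): StateMonad[U, S] = new StateMonad[U, S] {
      def apply(state: S) = {
        val (a, y) = owner(state)
        (f(a), y)
      }
    }
  }

  def push[A](x: A): StateMonad[Unit, List[A]] = new StateMonad[Unit, List[A]] {
    def apply(state: List[A]) = ((), x :: state)
  }

  def pop[A]: StateMonad[Option[A], List[A]] = new StateMonad[Option[A], List[A]] {
    def apply(state: List[A]) = state match {
      case x :: xs => (Some(x), xs)
      case _       => (None, state)
    }
  }

  // The second step is picked from the popped value: only flatMap can do that.
  val stackStuff: StateMonad[Unit, List[Int]] = for {
    a <- pop[Int]
    _ <- (if (a == Some(5)) push(5) else push(3).flatMap(_ => push(8)))
  } yield ()
}
```

Starting from List(9, 0, 2, 1, 0), the popped 9 is not 5, so 3 and 8 get pushed and the final stack is List(8, 3, 0, 2, 1, 0); starting from List(5, 1), the 5 is pushed back.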


The example works fine as in LYAH, but I am not satisfied with the result for two reasons.
  • The first is that I still have not been able to link this implementation to the Monad definition from my previous post.
  • The second is that I still have to reproduce some real-life examples, like Nilanjan Raychaudhuri's, in order to stress-test these definitions.
All feedback and suggestions are welcome, as usual.

Until then, I have to do a little Haskell, study the Disruptor some more and practice some katas.

 Be seeing you !!! :)
