Monday, July 25, 2011

A small look at CPS in Scala and Clojure

A small and useless post this evening about attempts to write CPS (Continuation Passing Style) algorithms in both Scala and Clojure. Purists are going to shoot me, because everybody knows the JVM does not support full tail-call optimization. So what is the point of CPS on JVM languages after all?
It is a matter of training the object-oriented eye to identify other programming patterns that are quite common in functional programming. I would also like to take the opportunity to report the publication of Dean Wampler's book, Functional Programming for Java Developers. I downloaded the PDF last Saturday and started reading it. Quite easy to read, the book cleverly unifies the worlds of object-oriented programming in Java and of the functional programming paradigm, allowing the curious Java developer to extend their experience through dense and essential exercises and code samples. This book deserves to become a must-read.

Digression over, let's come back to CPS. I first came across CPS in Michael Fogus and Chris Houser's book, The Joy of Clojure, and, to my real shame, forgot about it. After all, you remember, the JVM is not fully optimized, blah, blah, blah.

Then I saw a video of Chris League explaining delimited continuations and monadic programming to the New York Scala Enthusiasts Meetup. His introduction starts with continuations. I won't dig into the content of the presentation (I am still exploring it), but I wanted to reproduce its beginning. The example is nice.

Starting with Clojure, the most obvious example is computing the factorial of a number.
The basic "mundane" approach (an expression borrowed from my favourite Clojure book) leads to the first sample:

(defn factorial [number]
  (if (zero? number)
    1
    (* number (factorial (dec number)))))

A typical recursive call, with a mandatory breaking condition:

(zero? number)

Download Clojure and run the code in the REPL:

user=> (factorial 5)
120

The result is instant.

Good. Now increase the value of number; depending on how your environment is sized, it will eventually break. On my poor, memory-starved machine it breaks around 10000:

user=> (factorial 10000)
java.lang.StackOverflowError (NO_SOURCE_FILE:0)

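Hey, remember the talk about the JVM not being optimized for full tail recursion? Well, here we are. The recursion pushes so many frames that we exceed the maximum stack depth. Strictly speaking, this first version is not even tail-recursive: the multiplication happens after the recursive call returns, so even tail-call optimization would not save it.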
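For readers following along in Scala, here is a minimal sketch of the same mundane recursion (my translation, not code from the post). It uses Long, so results are exact only up to 20!, and overflows is a hypothetical helper added only to observe the failure:

```scala
object NaiveFactorial {
  // Direct translation of the "mundane" Clojure version.
  // The multiplication runs after the recursive call returns,
  // so the call is not in tail position and every level keeps a stack frame.
  def factorial(n: Int): Long =
    if (n == 0) 1L
    else n * factorial(n - 1)

  // Hypothetical helper: true if computing factorial(n) exhausts the stack.
  def overflows(n: Int): Boolean =
    try { factorial(n); false }
    catch { case _: StackOverflowError => true }
}
```

NaiveFactorial.factorial(5) returns 120, while a very deep call such as overflows(1000000) should report a StackOverflowError on a default-sized thread stack; the exact breaking depth depends on the -Xss setting.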

Clojure offers a workaround: the special form recur, combined with the accumulator technique. We store the partial, ongoing result in an extra parameter as we recursively invoke the function, and can easily write:

(defn factorial [number]
  (letfn [(factorial-acc [n accumulator]
            (if (zero? n)
              accumulator
              (recur (dec n) (* n accumulator))))]
    (factorial-acc number 1)))

Roughly speaking, recur hands control back to the top of factorial-acc with new argument values, without consuming a new stack frame, acting like a loop or a while form. (dec n) and (* n accumulator) are evaluated before the jump. The accumulator parameter, in essence, "accumulates" the result to come.

Go ahead, try it, far beyond 10000 if you want. It works fine. I like recur. What we accumulated here is the data being processed. (With Clojure 1.3 or later, pass a BigInt literal such as 10000N: * no longer promotes longs to BigInt automatically and throws on overflow instead.)
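Scala's nearest equivalent to recur is the @tailrec annotation: the compiler turns a self-recursive call in tail position into a loop, and refuses to compile if the call is not in tail position. A minimal sketch of the accumulator version (my translation, not the post's code), using BigInt so that large results stay exact:

```scala
import scala.annotation.tailrec

object AccumulatorFactorial {
  def factorial(number: Int): BigInt = {
    // The accumulator carries the partial product, like factorial-acc above.
    @tailrec
    def factorialAcc(n: Int, accumulator: BigInt): BigInt =
      if (n == 0) accumulator
      else factorialAcc(n - 1, accumulator * BigInt(n)) // tail call: compiled into a loop

    factorialAcc(number, BigInt(1))
  }
}
```

As with recur, the arguments are computed first and the call becomes a jump, so AccumulatorFactorial.factorial(10000) completes without touching the stack limit.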

But what if, instead, we accumulated the work still to be done at the very end? We would wrap up the pending work and unwrap it once we meet the breaking condition, explicitly taking back control after each recursive call. In short, what if we wrote something like this:

(defn factorial [number]
  (letfn [(factorial-cps [n continuation]
            (if (zero? n)
              (continuation 1)
              (factorial-cps (dec n) (fn [value] (continuation (* n value))))))]
    (factorial-cps number (fn [n] n))))

What we pass to the recursive invocation is the next operation to apply once that invocation has computed its result. Through the continuation we hand over control of what happens to the result of the processing, much like a callback.
At a given step we know n, we want the factorial of n-1, and the continuation is what must happen after computing factorial(n-1) in order to obtain factorial(n).
Leaving the chaining aside, that continuation is:

(fn [value] (* n value))

where value is factorial(n - 1).

We can hardly be more functional, or more declarative, than that.
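The Clojure factorial-cps can be transcribed into Scala almost word for word. This is a sketch with names of my own choosing. Since the self-call is in tail position in a method of an object, scalac can compile it into a loop; the stack still grows when the chain of continuations finally runs, one frame per continuation, so keep n modest:

```scala
object CpsFactorial {
  // n: the number; continuation: what to do with factorial(n) once it is known.
  def factorialCps[A](n: Int, continuation: BigInt => A): A =
    if (n == 0) continuation(BigInt(1))
    else factorialCps(n - 1, (value: BigInt) => continuation(value * BigInt(n)))

  // The very first continuation is the identity function.
  def factorial(n: Int): BigInt = factorialCps(n, (value: BigInt) => value)
}
```

CpsFactorial.factorialCps(5, (v: BigInt) => v.toString) returns "120": the continuation, not the recursion, decides what the caller finally gets back.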

We chain the recursive invocations, continuation after continuation:

(fn [value] (continuation (* n value)))

What we want to get back is the value of the factorial itself, so the very first continuation is the identity function:

(fn [n] n)

The order in which all the work we have been pushing ahead is finally executed is:

((fn [n] n) ((fn [value] (* n value)) ((fn [value] (* (dec n) value)) ...)))
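To make that order concrete, here is a sketch that builds the chain for n = 3 by hand; the names k0 to k3 are mine. Each continuation multiplies by its own n and hands the product to the previous one, and the breaking condition feeds 1 into the last one pushed:

```scala
object ContinuationChain {
  val k0: Int => Int = value => value         // identity: the very first continuation
  val k1: Int => Int = value => k0(3 * value) // pushed while handling n = 3
  val k2: Int => Int = value => k1(2 * value) // pushed while handling n = 2
  val k3: Int => Int = value => k2(1 * value) // pushed while handling n = 1

  // At n = 0 the last continuation is applied to 1:
  // k3(1) = k2(1 * 1) = k1(2 * 1) = k0(3 * 2) = 6
  val result: Int = k3(1)
}
```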

A natural order indeed. The main idea is that you express your intent to take back control explicitly and declaratively, in code. This case, however, is not handled naturally by recur: even if the self-call were replaced by recur, running the chain of pushed continuations at the end still consumes one stack frame per continuation, so for large numbers your stack will blow.
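One common workaround on the JVM, which the experiment here does not use, is trampolining: each step returns a description of the next call instead of making it, and a loop runs those steps one after another. Scala's standard library ships a trampoline in scala.util.control.TailCalls (tailcall, done, TailRec.result). A sketch of the CPS factorial on top of it, with names of my own:

```scala
import scala.util.control.TailCalls._

object TrampolinedFactorial {
  // Every step returns a TailRec instead of calling the next step directly,
  // so neither the descent nor the final run of the continuation chain grows the stack.
  def factorialCps(n: Int, continuation: BigInt => TailRec[BigInt]): TailRec[BigInt] =
    if (n == 0) tailcall(continuation(BigInt(1)))
    else tailcall(factorialCps(n - 1, (value: BigInt) => tailcall(continuation(value * BigInt(n)))))

  // done(...) ends the trampoline; .result runs the steps in a loop.
  def factorial(n: Int): BigInt =
    factorialCps(n, (value: BigInt) => done(value)).result
}
```

The continuations still pile up, but on the heap as closures rather than on the stack as frames; the price is an extra allocation per step.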

If your algorithm can guarantee bounded stack usage, you get a nice declarative expression. This is striking in Scala, and the resulting code patterns might seduce DSL lovers.

Still with me? :) Take your time. No rush.

I reproduced Chris League's experiment. The goal is to compute the factorial of a number, so here we go, tests first:

import org.junit.Test
import org.scalatest.junit.{ShouldMatchersForJUnit, JUnitSuite}
import MathBox._

final class MathBoxTest extends JUnitSuite with ShouldMatchersForJUnit{
  def id: (Int) => Int = {
    (input: Int) => input
  }


  @Test
  def factorial_WithSmallNumber_ShouldHaveAMatch() {
    factorial(3, id).should(be(6))
    factorial(5, id).should(be(120))
  }

  @Test
  def decrement_WithIdContinuation_ShouldBeDecrementInput() {
    decrement(17, id).should(be(16))
  }

  @Test
  def times_WithIdContinuation_MultiplyInputs() {
    times(17, 11, id).should(be(187))
  }
}

Self-explanatory tests :)

I found the result quite elegant:

object MathBox {

  def factorial[A](value: Int, continuation: Int => A): A  = {
    lowerThan(value, 2, continuation(1),
      decrement(value, (valueMinusOne: Int) =>
        factorial(valueMinusOne, (result: Int) => 
          times(value, result, continuation))))
  }

  def lowerThan[A](first: Int, second: Int, verified: => A, invalid: => A): A = {
    if(first < second) verified else invalid
  }

  def decrement[A](value: Int,  continuation: Int => A): A = {
    continuation(value - 1)
  }

  def increment[A](value: Int,  continuation: Int => A): A = {
    continuation(value + 1)
  }


  def times[A] (x: Int, y: Int , continuation: Int => A): A = {
    continuation(x * y)
  }
}

The body of factorial is self-explanatory: once you express your intent to take control of the process, you end up writing your algorithm in a very fluent language.

  • The breaking condition is the lowerThan comparison
  • you recurse on factorial...
  • ...after decrementing the number value
  • ...and you apply a multiplication at each step

So literate...

def factorial[A](value: Int, continuation: Int => A): A = {
  lowerThan(value, 2, continuation(1),
    decrement(value, (valueMinusOne: Int) =>
      factorial(valueMinusOne, (result: Int) => times(value, result, continuation))))
}
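Unrolling factorial(3, id) by hand shows how the continuations nest. In the sketch below, the helpers are minimal copies of the MathBox ones so the file stands alone, and each element of steps is one rewriting step of the previous one; they all evaluate to 6:

```scala
object UnrolledFactorial {
  // Minimal copies of the MathBox helpers, so this sketch compiles on its own.
  def lowerThan[A](first: Int, second: Int, verified: => A, invalid: => A): A =
    if (first < second) verified else invalid
  def decrement[A](value: Int, continuation: Int => A): A = continuation(value - 1)
  def times[A](x: Int, y: Int, continuation: Int => A): A = continuation(x * y)

  def factorial[A](value: Int, continuation: Int => A): A =
    lowerThan(value, 2, continuation(1),
      decrement(value, (valueMinusOne: Int) =>
        factorial(valueMinusOne, (result: Int) => times(value, result, continuation))))

  // Defined before steps so it is initialized when steps is evaluated.
  val id: Int => Int = input => input

  // factorial(3, id) rewritten one step at a time.
  val steps: List[Int] = List(
    factorial(3, id),
    factorial(2, (r: Int) => times(3, r, id)),
    factorial(1, (r2: Int) => times(2, r2, (r: Int) => times(3, r, id))),
    times(2, 1, (r: Int) => times(3, r, id)),
    times(3, 2, id),
    id(6)
  )
}
```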

Not yet satisfied, I decided to write procedures that build lists containing all the natural numbers up to a given natural number, or down from one. The tests, naturally, are:

import org.junit.Test
import org.scalatest.junit.{ShouldMatchersForJUnit, JUnitSuite}

class TestListBuilder extends JUnitSuite with ShouldMatchersForJUnit{

  @Test
  def downTo_WithRange10_ShouldContain10Elements() {
    ListBuilder.listDownFrom(10, (x => x)).size.should(be(10))
    ListBuilder.listDownFrom(10, println)
  }
  @Test
  def upTo_WithRange10_ShouldContain10Elements() {
    ListBuilder.listUpTo(10, x=> x).size.should(be(10))
    ListBuilder.listUpTo(10, println)
  }

}

Note that you can pass the identity function as a continuation to get back the result of the recursion, but you can also pass any function you like to manipulate the resulting list. Here we trace the list by passing println.
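A short sketch of that flexibility, with minimal copies of the post's helpers so it compiles on its own: the same listDownFrom yields the list itself, its sum, or a string, depending only on the continuation handed to it.

```scala
object ListContinuations {
  // Minimal copies of decrement, zero, prepend and listDownFrom from the post.
  def decrement[A](value: Int, continuation: Int => A): A = continuation(value - 1)
  def zero[A](value: Int, stop: => A, continue: => A): A = if (value == 0) stop else continue
  def prepend[A](value: Int, xs: List[Int], continue: List[Int] => A): A = continue(value :: xs)

  def listDownFrom[A](range: Int, continue: List[Int] => A): A =
    zero(range, continue(Nil),
      decrement(range, (rangeMinusOne: Int) =>
        listDownFrom(rangeMinusOne, (xs: List[Int]) => prepend(range, xs, continue))))

  val asList: List[Int] = listDownFrom(4, (xs: List[Int]) => xs)           // List(4, 3, 2, 1)
  val total: Int = listDownFrom(4, (xs: List[Int]) => xs.sum)               // 10
  val text: String = listDownFrom(3, (xs: List[Int]) => xs.mkString(", ")) // "3, 2, 1"
}
```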

Following the same approach, I came up with the following:

import MathBox._

object ListBuilder {
  def prepend[A](value: Int, xs: List[Int], continue: (List[Int]) => A): A = {
    continue(value :: xs)
  }

  def listDownFrom[A](range: Int, continue: (List[Int]) => A): A = {
    zero(range, continue(Nil),
      decrement(range, (rangeMinusOne: Int) =>
        listDownFrom(rangeMinusOne, (xs: List[Int]) =>
          prepend(range, xs, continue))))
  }


  def listUpTo[A](maximum: Int, continue: (List[Int]) => A): A = {

    def buildListUpTo[A](maximum: Int, index: Int, continue: (List[Int]) => A): A = {
      lowerThan(maximum, index, continue(Nil),
        increment(index, (rangePlusOne: Int) =>
          buildListUpTo(maximum, rangePlusOne, (xs: List[Int]) =>
            prepend(index, xs, continue))))
    }

    buildListUpTo(maximum, 1, continue)
  }



  def same[A](cursor: Int, value: Int, stop: => A, continue: => A): A = {
    if (value == cursor) stop else continue
  }


  def zero[A](value: Int, stop: => A, continue: => A): A = {
    if (value == 0) stop else continue

  }
}

The lowerThan and zero methods are the test functions that decide when to break the recursion.

Note that to produce the ascending list I had to use a nested helper method that carries both the limit to reach and the natural number to prepend to the list.
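For comparison, here is a sketch of an alternative that is not in the post: the nested helper can be avoided by counting down as listDownFrom does and appending on the way back (xs :+ maximum) instead of prepending. The trade-off is that :+ copies the List, so building n elements costs O(n²), whereas the helper-based version keeps each step constant-time.

```scala
object ListUpToByAppend {
  def zero[A](value: Int, stop: => A, continue: => A): A = if (value == 0) stop else continue
  def decrement[A](value: Int, continuation: Int => A): A = continuation(value - 1)

  // Hypothetical variant: recurse downwards, append on the way back up.
  def listUpTo[A](maximum: Int, continue: List[Int] => A): A =
    zero(maximum, continue(Nil),
      decrement(maximum, (maximumMinusOne: Int) =>
        listUpTo(maximumMinusOne, (xs: List[Int]) => continue(xs :+ maximum))))
}
```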

Too short. But I am tired and must go to bed.

Be seeing you! :)

3 comments:

osa1 said...

In your first code snippet, you would get a StackOverflowError even if the JVM had tail-call optimization: your function doesn't make a tail call (the recursive call is not the last expression evaluated in the function).

Globulon said...

Trying

user=> (factorial 1000N)

or

user=> (factorial 10000N)

etc. works nicely. One must use 10000N instead of 10000 in Clojure 1.3.0.

The tail recursion is in

(factorial-acc [n accumulator]
  (if (zero? n)
    accumulator
    (recur (dec n) (* n accumulator))))

since

(factorial-acc number 1)

first calls factorial-acc, and recur is the last call in factorial-acc, calling itself.

I may be wrong, but it seems to work fine :)

Globulon said...

Moreover, a snippet of the generated bytecode hints at the optimization Clojure provides:

public java.lang.Object invoke(java.lang.Object, java.lang.Object);
Code:
Stack=4, Locals=3, Args_size=3
0: aload_1
1: invokestatic #44; //Method clojure/lang/Numbers.isZero:(Ljava/lang/Object;)Z
4: ifeq 14
7: aload_2
8: aconst_null
9: astore_2
10: goto 32
13: pop
14: aload_1
15: invokestatic #47; //Method clojure/lang/Numbers.dec:(Ljava/lang/Object;)Ljava/lang/Number;
18: aload_1
19: aconst_null
20: astore_1
21: aload_2
22: aconst_null
23: astore_2
24: invokestatic #51; //Method clojure/lang/Numbers.multiply:(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Number;
27: astore_2
28: astore_1
29: goto 0
32: areturn

where you can recognize the loop at offset 29 (goto 0, jumping back to the start of the method) and the exit condition at offset 10 (goto 32, jumping to the areturn instruction at offset 32).
