Monday, 22 August 2011

A Taste of Dependent Types in Scala

One thing that has been crossing my path very frequently lately is dependent typing. While I'm not an expert on the topic, I found some very interesting ideas in various papers that I would like to be able to use in Scala. Dependent types let types depend on values, which helps verify that your program does what it's supposed to do by letting you prove theorems about it in its types.
Since Scala is not dependently typed, we need to get a little tricky, but let's give it a try.

Think about the type signature of the following method called zap (which stands for zip-apply):


def zap[A, B](l: List[A], fs: List[A => B]): List[B]


Just by looking at the signature, it should be obvious what the method does: you have a list l that provides elements of type A and a list fs that provides a function from A to B for each element in l. Let's look at the implementation:


def zap[A, B](l: List[A], fs: List[A => B]): List[B] =
  (l, fs) match {
    case (_, Nil) | (Nil, _) => Nil
    case (a :: as, b :: bs) => b(a) :: zap(as, bs)
  }


That's a pretty straightforward example, but there is a problem with this code: what if we provide lists of different lengths? If fs is longer than l, some functions are never executed; if it is shorter, we lose data, because zap silently stops at the end of the shorter list. So what can we do? We could throw an exception, but that's already too late: our program is already in a state where things don't make sense anymore. We have to avoid this situation altogether, and the best way to achieve that is to make such code fail to compile in the first place.
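To make the problem concrete, here is a small sketch that runs the List-based zap from above on matching and mismatched inputs. The object names ListZap and ListZapDemo and the helper inc are hypothetical names made up for this illustration.

```scala
object ListZap {
  // The List-based zap from above, unchanged apart from layout.
  def zap[A, B](l: List[A], fs: List[A => B]): List[B] =
    (l, fs) match {
      case (_, Nil) | (Nil, _) => Nil
      case (a :: as, b :: bs)  => b(a) :: zap(as, bs)
    }
}

object ListZapDemo {
  // Hypothetical helper used only for this demonstration.
  val inc: Int => Int = _ + 1

  // Same length: every function is applied to exactly one element.
  val sameLength: List[Int] = ListZap.zap(List(1, 2, 3), List(inc, inc, inc))

  // One function too few: the element 3 is silently dropped.
  val lostData: List[Int] = ListZap.zap(List(1, 2, 3), List(inc, inc))

  // One function too many: the third function is never executed.
  val unusedFunction: List[Int] = ListZap.zap(List(1, 2), List(inc, inc, inc))
}
```

Both mismatched calls type-check and quietly return List(2, 3). The standard library behaves the same way: l.zip(fs).map { case (a, f) => f(a) } also ignores the leftover elements of the longer list.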

In a dependently typed language, we could create a list that encodes its length in its own type and write zap so that it only accepts lists of the same length. But since we are not in a dependently typed language, we cannot do that. This concludes this blog post, see you in another episode… OK, just kidding. We actually CAN do that in Scala. Here is how:

The first thing we need is natural numbers in the type system, to encode the length of our list. That is easy:


sealed trait Nat
sealed trait Z extends Nat
sealed trait S[A <: Nat] extends Nat


Z is our zero here, and S[A <: Nat] is the successor of the natural number A, so we can build the natural numbers recursively: Z == 0, S[Z] == 1, S[S[Z]] == 2, etc.
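To check that these types really behave like the numbers they stand for, we can ask the compiler to count them. The following is a minimal sketch using an implicit type class; ToInt, zero, succ and NatDemo are hypothetical names introduced for this illustration and are not part of the code in this post.

```scala
sealed trait Nat
sealed trait Z extends Nat
sealed trait S[A <: Nat] extends Nat

// Hypothetical type class: turns a type-level natural number into a runtime Int.
trait ToInt[N <: Nat] {
  def value: Int
}

object ToInt {
  // Base case: Z stands for 0.
  implicit val zero: ToInt[Z] = new ToInt[Z] { val value: Int = 0 }

  // Inductive case: if N stands for n, then S[N] stands for n + 1.
  implicit def succ[N <: Nat](implicit prev: ToInt[N]): ToInt[S[N]] =
    new ToInt[S[N]] { val value: Int = prev.value + 1 }

  // Summons the instance for N and returns the number it encodes.
  def apply[N <: Nat](implicit t: ToInt[N]): Int = t.value
}

object NatDemo {
  val zero: Int  = ToInt[Z]
  val one: Int   = ToInt[S[Z]]
  val three: Int = ToInt[S[S[S[Z]]]]
}
```

The counting happens during implicit resolution: for S[S[S[Z]]] the compiler applies succ three times on top of zero.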

Now let's start with the fun part, the list itself.

The first thing we need is the trait for our NList.


trait NList[A <: Nat, +B] {
  type Length = A
  def zap[C](l: NList[Length, B => C]): NList[Length, C]
}


We can now see that the type of zap enforces that both lists have the same length A. Now we can create our first case, where we are zapping two lists of length 0 (or Z). The implementation is rather trivial.


case object NNil extends NList[Z, Nothing] {
  def zap[A](l: NList[Length, Nothing => A]): NNil.type =
    NNil
}


The next case is not as simple. (If anyone has a better idea of how to implement this, I would love to see it. Please drop a comment.)


case class NCons[A <: Nat, B](head: B, tail: NList[A, B])
    extends NList[S[A], B] {
  def zap[C](l: NList[Length, B => C]): NList[Length, C] =
    l match {
      case NCons(f, fs) =>
        NCons(f(head), tail zap fs)
    }
}


We can see that our NCons is actually an NList that is one element longer than its tail, which has length A. The pattern match inside our zap method only has to handle the cons case, since the type signature does not allow anything else: l must have length S[A], and NNil, whose length is Z, can never have that type.
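One way to see this length bookkeeping is to write the types out by hand. The sketch below repeats the definitions from this post so that it compiles on its own; LengthDemo and its values are hypothetical names for this illustration. Each ascription only type-checks because an NCons is exactly one S[...] longer than its tail.

```scala
sealed trait Nat
sealed trait Z extends Nat
sealed trait S[A <: Nat] extends Nat

trait NList[A <: Nat, +B] {
  type Length = A
  def zap[C](l: NList[Length, B => C]): NList[Length, C]
}

case object NNil extends NList[Z, Nothing] {
  def zap[A](l: NList[Length, Nothing => A]): NNil.type =
    NNil
}

case class NCons[A <: Nat, B](head: B, tail: NList[A, B])
    extends NList[S[A], B] {
  def zap[C](l: NList[Length, B => C]): NList[Length, C] =
    l match {
      // l has length S[A], so NNil (length Z) cannot show up here.
      case NCons(f, fs) =>
        NCons(f(head), tail zap fs)
    }
}

// Hypothetical demo values, not from the original post.
object LengthDemo {
  // NNil has length Z; covariance in B lets it stand in for a list of Int.
  val empty: NList[Z, Int] = NNil
  // Each NCons wraps the length of its tail in one more S.
  val one: NList[S[Z], Int]    = NCons(1, empty)
  val two: NList[S[S[Z]], Int] = NCons(2, one)

  // A list of functions with the same length, so zap is accepted.
  val inc: Int => Int = _ + 1
  val incs = NCons(inc, NCons(inc, NNil))
  val zapped: NList[S[S[Z]], Int] = two zap incs
}
```

Replacing the type of one with NList[Z, Int] or NList[S[S[Z]], Int] makes the ascription fail to compile.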

Now let's put that little bugger to the test:

First, let's define some NLists:


scala> val l1 = NCons(1, NCons(2, NCons(3, NNil)))
scala> val l2 = NCons(1, NCons(2, NNil))
scala> val f = (_: Int) + 1
scala> val f1 = NCons(f, NCons(f, NCons(f, NNil)))
scala> val f2 = NCons(f, NCons(f, NNil))


Zapping lists with the same length:


scala> l1 zap f1
res1: nlist.NList[l1.Length,Int] = NCons(2,NCons(3,NCons(4,NNil)))

scala> l2 zap f2
res2: nlist.NList[l2.Length,Int] = NCons(2,NCons(3,NNil))


This works just fine. Now let's see what happens if we zap two lists of different lengths:


scala> l1 zap f2
<console>:14: error: type mismatch;
 found   : nlist.NCons[nlist.S[nlist.Z],Int => Int]
 required: nlist.NList[l1.Length,Int => ?]
       l1 zap f2


Exactly what we wanted! This code does not even compile since it doesn't make any sense.
This is a very basic example of what we can do with the Scala type system. Stay tuned for more!
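As a small taste of what else this enables, a function can abstract over the length N and still insist that its arguments agree on it. The sketch below repeats the definitions from this post so that it is self-contained; SameLength, applyAll, show, xs, fs and out are hypothetical names introduced for this illustration.

```scala
sealed trait Nat
sealed trait Z extends Nat
sealed trait S[A <: Nat] extends Nat

trait NList[A <: Nat, +B] {
  type Length = A
  def zap[C](l: NList[Length, B => C]): NList[Length, C]
}

case object NNil extends NList[Z, Nothing] {
  def zap[A](l: NList[Length, Nothing => A]): NNil.type =
    NNil
}

case class NCons[A <: Nat, B](head: B, tail: NList[A, B])
    extends NList[S[A], B] {
  def zap[C](l: NList[Length, B => C]): NList[Length, C] =
    l match {
      case NCons(f, fs) =>
        NCons(f(head), tail zap fs)
    }
}

object SameLength {
  // Hypothetical helper: works for lists of any length N, but both
  // arguments must share that N, so mismatched lengths are rejected.
  def applyAll[N <: Nat](xs: NList[N, Int], fs: NList[N, Int => String]): NList[N, String] =
    xs zap fs

  val show: Int => String = n => s"#$n"
  val xs  = NCons(1, NCons(2, NNil))
  val fs  = NCons(show, NCons(show, NNil))
  val out = applyAll(xs, fs)
  // applyAll(xs, NCons(show, NNil)) would not compile: N cannot be both S[S[Z]] and S[Z].
}
```

Because NList is invariant in its length parameter, the compiler has to find a single N that fits both arguments, which is exactly the "same length" requirement.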

If you would like to play around with the code, check out my gist.

4 comments:

Unknown said…

Since I learned about dependent types I was thinking that at least we can encode the array length in Scala's type system. Nice to see it working :)
Thanks

raichoo said…

I think there is a lot more possible with the Scala type system, we just need to push the envelope. ;)

Sanjoy Das said…

Interesting.

I used to do similar stuff too (http://playingwithpointers.com/archives/756 and http://playingwithpointers.com/archives/716) but have since then moved to languages with real dependent types. ;)

raichoo said…

Same here. What's your weapon of choice? I'm tinkering with Idris, Agda and Coq. :)