Mike's corner of the web.

Solving the expression problem with union types in Shed

Monday 30 April 2012 22:31

The expression problem is a tricky problem in many languages that asks: given a set of functions that operate over a set of types, how do we allow both the set of functions and the set of types that those functions operate over to be extended without losing type safety? If you're not familiar with the problem, I recommend reading the explanation by the author of Magpie. For our purposes, we'll use an abstract syntax tree for mathematical expressions as our data type. To start, let's have two sorts of node: addition operators and literals.

interface Node {}

def AddNode class(myLeft: Node, myRight: Node) implements Node => {
    public def left fun() => myLeft;
    public def right fun() => myRight;
}

def LiteralNode class(myValue: Double) implements Node => {
    public def value fun() => myValue;
}

(As an aside: due to the design of the language, we can't give the arguments to a class the same name as its getter, for instance def value fun() => value, since the body of the function would refer to the function rather than the class argument. Prepending each of the arguments with my is a poor solution, and although I have several ideas on how to rectify this, I'm still pondering the simplest, cleanest solution.)

Suppose we want to implement a function evaluate that evaluates the expression to a single value. Our first attempt at an implementation might look like this:

def evaluate fun(node: Node) : Double =>
    match(node,
        case(AddNode, evaluateAdd),
        case(LiteralNode, evaluateLiteral)
    );

def evaluateAdd fun(add: AddNode) =>
    evaluate(add.left()) + evaluate(add.right());

def evaluateLiteral fun(literal: LiteralNode) =>
    literal.value();
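For comparison, here's a rough sketch of this first attempt in TypeScript rather than Shed. It rests on a few assumptions rather than on Shed's actual semantics: ExprNode stands in for Node (renamed to avoid clashing with the DOM's global Node type), instanceof checks play the role of match/case, and NegateNode is a hypothetical later addition.

```typescript
// Sketch only: TypeScript, not Shed. ExprNode stands in for Shed's Node.
interface ExprNode {}

class AddNode implements ExprNode {
  readonly left: ExprNode;
  readonly right: ExprNode;
  constructor(left: ExprNode, right: ExprNode) {
    this.left = left;
    this.right = right;
  }
}

class LiteralNode implements ExprNode {
  readonly value: number;
  constructor(value: number) {
    this.value = value;
  }
}

// Like match/case: dispatch on the runtime class of the node.
function evaluate(node: ExprNode): number {
  if (node instanceof AddNode) return evaluateAdd(node);
  if (node instanceof LiteralNode) return evaluateLiteral(node);
  // ExprNode is open, so the compiler cannot know these cases are exhaustive.
  throw new Error("evaluate: unsupported node type");
}

function evaluateAdd(add: AddNode): number {
  return evaluate(add.left) + evaluate(add.right);
}

function evaluateLiteral(literal: LiteralNode): number {
  return literal.value;
}

// Hypothetical later addition: it type-checks as an ExprNode...
class NegateNode implements ExprNode {
  readonly value: ExprNode;
  constructor(value: ExprNode) {
    this.value = value;
  }
}

// ...but evaluate only finds out it can't handle it at runtime.
let negateFailed = false;
try {
  evaluate(new NegateNode(new LiteralNode(1)));
} catch {
  negateFailed = true;
}
```

The compiler happily accepts NegateNode as an ExprNode, so the missing case only shows up when evaluate throws.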

There's one immediate problem with this solution: it's not type-safe. If somebody adds another implementation of Node, then evaluate no longer covers all possible cases. The solution to this problem is to define a union type:

type StandardNode = AddNode | LiteralNode;

and update evaluate by changing the type of its argument:

def evaluate fun(node: StandardNode) : Double =>
    match(node,
        case(AddNode, evaluateAdd),
        case(LiteralNode, evaluateLiteral)
    );

def evaluateAdd fun(add: AddNode) =>
    evaluate(add.left()) + evaluate(add.right());

def evaluateLiteral fun(literal: LiteralNode) =>
    literal.value();

This makes evaluate type-safe, but has the unintended consequence of making evaluateAdd unsafe: add.left() and add.right() both have the type Node, yet evaluate only accepts the narrower type StandardNode. We fix this by adding a type parameter to AddNode:

def AddNode class[T] => (myLeft: T, myRight: T) implements Node => {
    public def left fun() => myLeft;
    public def right fun() => myRight;
}

We then change the type of evaluateAdd's argument and update the definition of StandardNode:

def evaluateAdd fun(add: AddNode[StandardNode]) =>
    evaluate(add.left()) + evaluate(add.right());
    
type StandardNode = AddNode[StandardNode] | LiteralNode;

(Note that at this point the interface Node isn't really necessary any more, although there might be other reasons to keep it around.)
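A rough TypeScript analogue of this step might look like the following. It assumes a discriminated union tagged by a kind field in place of Shed's class-based match; kind, makeAdd and makeLiteral are hypothetical names, not part of Shed.

```typescript
// Sketch only: a kind-tagged union stands in for Shed's classes and match.
interface AddNode<T> {
  readonly kind: "add";
  readonly left: T;
  readonly right: T;
}

interface LiteralNode {
  readonly kind: "literal";
  readonly value: number;
}

// Mirrors: type StandardNode = AddNode[StandardNode] | LiteralNode;
type StandardNode = AddNode<StandardNode> | LiteralNode;

function evaluate(node: StandardNode): number {
  switch (node.kind) {
    case "add":
      return evaluateAdd(node);
    case "literal":
      return evaluateLiteral(node);
    default: {
      // If StandardNode gains a member that isn't handled above,
      // node is no longer `never` here and this line fails to compile.
      const unhandled: never = node;
      return unhandled;
    }
  }
}

// The children of AddNode<StandardNode> are StandardNodes, so no casts are needed.
function evaluateAdd(add: AddNode<StandardNode>): number {
  return evaluate(add.left) + evaluate(add.right);
}

function evaluateLiteral(literal: LiteralNode): number {
  return literal.value;
}

// Hypothetical constructor helpers.
function makeAdd<T>(left: T, right: T): AddNode<T> {
  return { kind: "add", left, right };
}

function makeLiteral(value: number): LiteralNode {
  return { kind: "literal", value };
}

const tree = makeAdd<StandardNode>(
  makeLiteral(1),
  makeAdd<StandardNode>(makeLiteral(2), makeLiteral(3)),
);
```

The `never` assignment is what gives the closed union its safety: the compiler checks that every member of StandardNode has a case.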

Suppose we now add NegateNode and the associated union type ExtendedNode:

def NegateNode class[T] => (myValue: T) => {
    public def value fun() => myValue;
}

type ExtendedNode =
    AddNode[ExtendedNode] |
    NegateNode[ExtendedNode] |
    LiteralNode;

ExtendedNode cannot reuse the definition of StandardNode since AddNode[ExtendedNode] is a subtype of ExtendedNode but not a subtype of StandardNode. The solution is to introduce another type parameter, this time on StandardNode and ExtendedNode:

type StandardNode[T] = AddNode[T] | LiteralNode;

type ExtendedNode[T] = StandardNode[T] | NegateNode[T];

We can then add the appropriate type parameters to the argument of evaluate:

def evaluate fun(node: StandardNode[StandardNode]) : Double =>
    match(node,
        case(AddNode[StandardNode[StandardNode]], evaluateAdd),
        case(LiteralNode, evaluateLiteral)
    );

But this doesn't work either: the type parameter of StandardNode is itself a StandardNode, which also needs a type parameter, which is again a StandardNode... and so on. The solution is to add yet more types that fix the type parameter to themselves:

type StandardNodeF = StandardNode[StandardNodeF];
type ExtendedNodeF = ExtendedNode[ExtendedNodeF];

def evaluate fun(node: StandardNodeF) : Double =>
    match(node,
        case(AddNode[StandardNodeF], evaluateAdd),
        case(LiteralNode, evaluateLiteral)
    );

To evaluate an instance of ExtendedNodeF, we'd need to define the following:

def evaluateExtended fun(node: ExtendedNodeF) : Double =>
    match(node,
        case(AddNode[ExtendedNodeF], evaluateAddExtended),
        case(NegateNode[ExtendedNodeF], evaluateNegate),
        case(LiteralNode, evaluateLiteral)
    );

def evaluateAddExtended fun(add: AddNode[ExtendedNodeF]) =>
    evaluateExtended(add.left()) + evaluateExtended(add.right());
    
def evaluateNegate fun(negate: NegateNode[ExtendedNodeF]) =>
    -evaluateExtended(negate.value());
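The same structure can be sketched in TypeScript, assuming kind-tagged interfaces rather than Shed classes. One assumption worth flagging: instead of writing StandardNodeF = StandardNode<StandardNodeF> directly, the sketch spells out each fixed point by hand, to stay clear of TypeScript's restrictions on recursive type aliases, and then checks that each expansion agrees with the parameterised union. assertNever is a hypothetical helper, and the duplication between evaluateAdd and evaluateAddExtended is left in deliberately.

```typescript
// Sketch only: kind-tagged interfaces stand in for Shed's classes.
interface AddNode<T> { readonly kind: "add"; readonly left: T; readonly right: T; }
interface NegateNode<T> { readonly kind: "negate"; readonly value: T; }
interface LiteralNode { readonly kind: "literal"; readonly value: number; }

// Unions parameterised over the type of the children.
type StandardNode<T> = AddNode<T> | LiteralNode;
type ExtendedNode<T> = StandardNode<T> | NegateNode<T>;

// Fixed points, written out as StandardNode<StandardNodeF> and
// ExtendedNode<ExtendedNodeF> expand to.
type StandardNodeF = AddNode<StandardNodeF> | LiteralNode;
type ExtendedNodeF = AddNode<ExtendedNodeF> | NegateNode<ExtendedNodeF> | LiteralNode;

// Type-level check that each expansion matches the parameterised union.
const asStandard = (node: StandardNodeF): StandardNode<StandardNodeF> => node;
const asExtended = (node: ExtendedNodeF): ExtendedNode<ExtendedNodeF> => node;

// Hypothetical helper: only callable once every case has been handled.
function assertNever(node: never): never {
  throw new Error(`unhandled node: ${JSON.stringify(node)}`);
}

function evaluate(node: StandardNodeF): number {
  switch (node.kind) {
    case "add":
      return evaluateAdd(node);
    case "literal":
      return evaluateLiteral(node);
    default:
      return assertNever(node);
  }
}

function evaluateAdd(add: AddNode<StandardNodeF>): number {
  return evaluate(add.left) + evaluate(add.right);
}

function evaluateExtended(node: ExtendedNodeF): number {
  switch (node.kind) {
    case "add":
      return evaluateAddExtended(node);
    case "negate":
      return evaluateNegate(node);
    case "literal":
      return evaluateLiteral(node);
    default:
      return assertNever(node);
  }
}

// Identical to evaluateAdd apart from the child type and the recursive call.
function evaluateAddExtended(add: AddNode<ExtendedNodeF>): number {
  return evaluateExtended(add.left) + evaluateExtended(add.right);
}

function evaluateNegate(negate: NegateNode<ExtendedNodeF>): number {
  return -evaluateExtended(negate.value);
}

function evaluateLiteral(literal: LiteralNode): number {
  return literal.value;
}
```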

Writing evaluateNegate is reasonable, but evaluateAddExtended is virtually identical to evaluateAdd. The only differences are the type parameter of AddNode and the function we use to evaluate the sub-nodes. So, we introduce a type parameter and an argument to abstract over both:

def evaluateAdd fun[T] => fun(evaluator: Function[T, Double]) =>
    fun(add: AddNode[T]) =>
        evaluator(add.left()) + evaluator(add.right());
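In TypeScript, the curried evaluateAdd might be sketched as a function that takes the evaluator and returns a function over AddNode<T>. The kind-tagged AddNode interface is an assumption of a discriminated-union encoding, not part of Shed.

```typescript
// Sketch only: a kind-tagged interface stands in for Shed's AddNode class.
interface AddNode<T> {
  readonly kind: "add";
  readonly left: T;
  readonly right: T;
}

// Curried like the Shed version: first the evaluator for children of type T,
// then the AddNode<T> to evaluate.
function evaluateAdd<T>(evaluator: (node: T) => number): (add: AddNode<T>) => number {
  return (add) => evaluator(add.left) + evaluator(add.right);
}

// Usage: the children here are strings, evaluated by their length.
const addLengths = evaluateAdd<string>((text) => text.length);
const stringPair: AddNode<string> = { kind: "add", left: "ab", right: "cde" };
```

Because evaluateAdd only sees its children through evaluator, it works even when the children aren't expression nodes at all.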

We can also perform a similar transformation on evaluateNegate and evaluate:

def evaluateNegate fun[T] => fun(evaluator: Function[T, Double]) =>
    fun(negate: NegateNode[T]) =>
        -evaluator(negate.value());

def evaluate fun[T] => fun(evaluator: Function[T, Double]) =>
    fun(node: StandardNode[T]) : Double =>
        match(node,
            case(AddNode[T], evaluateAdd[T](evaluator)),
            case(LiteralNode, evaluateLiteral)
        );
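A TypeScript sketch of the generic evaluate, assuming kind-tagged interfaces in place of Shed classes (assertNever is a hypothetical helper, a common idiom for exhaustiveness checks), shows that it handles exactly one layer of StandardNode<T> and hands the children to evaluator:

```typescript
// Sketch only: kind-tagged interfaces stand in for Shed's classes.
interface AddNode<T> {
  readonly kind: "add";
  readonly left: T;
  readonly right: T;
}

interface LiteralNode {
  readonly kind: "literal";
  readonly value: number;
}

type StandardNode<T> = AddNode<T> | LiteralNode;

// Hypothetical helper: only callable once every case has been handled.
function assertNever(node: never): never {
  throw new Error(`unhandled node: ${JSON.stringify(node)}`);
}

function evaluateAdd<T>(evaluator: (node: T) => number): (add: AddNode<T>) => number {
  return (add) => evaluator(add.left) + evaluator(add.right);
}

function evaluateLiteral(literal: LiteralNode): number {
  return literal.value;
}

// Evaluates one layer of StandardNode<T>; children of type T go to evaluator.
function evaluate<T>(evaluator: (node: T) => number): (node: StandardNode<T>) => number {
  return (node) => {
    switch (node.kind) {
      case "add":
        return evaluateAdd(evaluator)(node);
      case "literal":
        return evaluateLiteral(node);
      default:
        return assertNever(node);
    }
  };
}

// Usage: with plain numbers as children, no recursion is involved.
const evaluateShallow = evaluate<number>((n) => n);
```

Using plain numbers as children shows that evaluate no longer recurses on its own; any recursion comes from the evaluator it's given.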

Now we can rewrite evaluateExtended to use evaluate:

def evaluateExtended fun[T] => fun(evaluator: Function[T, Double]) =>
    fun(node: ExtendedNode[T]) : Double =>
        match(node,
            case(StandardNode[T], evaluate[T](evaluator)),
            case(NegateNode[T], evaluateNegate[T](evaluator))
        );

If we want to call evaluate or evaluateExtended, we need to use the same trick as we did with StandardNode and ExtendedNode to instantiate the functions:

def evaluateF fun(node: StandardNodeF) : Double =>
    evaluate[StandardNodeF](evaluateF)(node);

def evaluateExtendedF fun(node: ExtendedNodeF) : Double =>
    evaluateExtended[ExtendedNodeF](evaluateExtendedF)(node);
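Putting the pieces together, here's a TypeScript sketch of the whole extended evaluator, including the knot-tying evaluateF and evaluateExtendedF. The kind field, the Evaluator alias and assertNever are assumptions of this encoding rather than anything in Shed, and the fixed-point types are spelled out by hand to stay clear of TypeScript's restrictions on recursive type aliases.

```typescript
// Sketch only: kind-tagged interfaces stand in for Shed's classes.
interface AddNode<T> { readonly kind: "add"; readonly left: T; readonly right: T; }
interface NegateNode<T> { readonly kind: "negate"; readonly value: T; }
interface LiteralNode { readonly kind: "literal"; readonly value: number; }

type StandardNode<T> = AddNode<T> | LiteralNode;
type ExtendedNode<T> = StandardNode<T> | NegateNode<T>;

// Fixed points, written out as StandardNode<StandardNodeF> and
// ExtendedNode<ExtendedNodeF> expand to.
type StandardNodeF = AddNode<StandardNodeF> | LiteralNode;
type ExtendedNodeF = AddNode<ExtendedNodeF> | NegateNode<ExtendedNodeF> | LiteralNode;

// Plays the role of Function[T, Double].
type Evaluator<T> = (node: T) => number;

// Hypothetical helper: only callable once every case has been handled.
function assertNever(node: never): never {
  throw new Error(`unhandled node: ${JSON.stringify(node)}`);
}

function evaluateAdd<T>(evaluator: Evaluator<T>): Evaluator<AddNode<T>> {
  return (add) => evaluator(add.left) + evaluator(add.right);
}

function evaluateNegate<T>(evaluator: Evaluator<T>): Evaluator<NegateNode<T>> {
  return (negate) => -evaluator(negate.value);
}

function evaluateLiteral(literal: LiteralNode): number {
  return literal.value;
}

function evaluate<T>(evaluator: Evaluator<T>): Evaluator<StandardNode<T>> {
  return (node) => {
    switch (node.kind) {
      case "add":
        return evaluateAdd(evaluator)(node);
      case "literal":
        return evaluateLiteral(node);
      default:
        return assertNever(node);
    }
  };
}

// Reuses evaluate for the StandardNode<T> cases; only NegateNode is new.
function evaluateExtended<T>(evaluator: Evaluator<T>): Evaluator<ExtendedNode<T>> {
  return (node) => {
    switch (node.kind) {
      case "add":
      case "literal":
        return evaluate(evaluator)(node);
      case "negate":
        return evaluateNegate(evaluator)(node);
      default:
        return assertNever(node);
    }
  };
}

// Tie the knot: each function passes itself in as the evaluator for children.
function evaluateF(node: StandardNodeF): number {
  return evaluate<StandardNodeF>(evaluateF)(node);
}

function evaluateExtendedF(node: ExtendedNodeF): number {
  return evaluateExtended<ExtendedNodeF>(evaluateExtendedF)(node);
}
```

Since TypeScript erases types at runtime, evaluateExtended can't match on StandardNode<T> as a whole the way the Shed version does, so it lists the kinds belonging to StandardNode<T> explicitly before delegating to evaluate.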

Hopefully you can now see how you'd extend the solution to include further node types. Although not covered here, it's also possible to write functions or classes that help combine evaluators and, more generally, to write functions in this style with a bit less boilerplate.

If we imagine an ideal solution to the expression problem, we might argue that this solution is a little verbose, and I'd be inclined to agree. The question is: is it unnecessarily verbose? There's an argument to be made that this solution exposes the essential complexity of solving the expression problem, whereas less verbose solutions hide this complexity rather than remove it. On the one hand, hiding the complexity lets you express the same ideas more succinctly, without the clutter of the machinery that makes the solution work; in the solution I've just described, we have to constantly pass around the type parameter T and the evaluator argument. On the other hand, if you want to understand what's going on, you don't have to look very far, since everything is passed around explicitly.

On the whole, I think it's simpler than some solutions I've seen to the expression problem, and the verbosity isn't all-consuming. Pretty good for a first go, I reckon.

Topics: Language design, Shed

Thoughts? Comments? Feel free to drop me an email at hello@zwobble.org. You can also find me on Twitter as @zwobble.