CPS, or continuation-passing style, is an intermediate representation for programs, particularly functional programs. It’s used in compilers for languages such as SML and Scheme.
In CPS, there are two rules: first, that function/operator arguments must always be trivial; second, that function calls do not return. From this, a lot falls out.
In this post, we’ll introduce CPS by building a simple (Plotkin[1]) CPS transform for a small Scheme-like language. We’ll sketch some optimizations on the IR. Then we’ll look at a couple of the common ways to actually compile the IR for execution.
- We have integers: 5
- We have some operations on the integers: (+ 1 2), (< 3 4) (returns 1 or 0)
- We can bind variables: (let ((x 1)) x) / (letrec ...)?
- We can create single-parameter functions[2]: (lambda (x) (+ x 1)), and they can close over variables
- We can call functions: (f x)
- We can branch: (if (< x y) x y) (where we have decided to use 0 and 1 as booleans)
We’re going to implement a recursive function called cps incrementally, starting with the easy forms of the language and working up from there. Many people like implementing the compiler both in Scheme and for Scheme, but I find that all the quasiquoting makes everything fussier than it should be and obscures the lesson, so we’re doing it in Python.
This means we have a nice clear separation of code and data. Our Python code is the compiler and we’ll lean on Python lists for S-expressions. Here’s what some sample Scheme programs might look like as Python lists:
5
["+", 1, 2]
["let", ["x", 1], "x"]
["lambda", ["x"], ["+", "x", 1]]
Our cps function will take two arguments. The first argument, exp, is the expression to compile. The second argument, k, is a continuation. We have to do something with our values, but CPS requires that functions never return. So what do we do? Call another function, of course.
This means that the top-level invocation of cps will be passed some useful top-level continuation like print-to-screen or write-to-file. All child invocations of cps will be passed either that continuation, a manufactured continuation, or a continuation variable.
cps(["+", 1, 2], "$print-to-screen")
# ...or...
cps(["+", 1, 2], ["cont", ["v"], ...])
So a continuation is just another function. Kind of.
While you totally can generate real first-class functions for use as continuations, it can often be useful to partition your CPS IR by separating them. All real (user) functions will take a continuation as a last parameter—for handing off their return values—and can arbitrarily escape, whereas all continuations are generated and allocated/freed in a stack-like manner. (We could even implement them using a native stack if we wanted. See “Partitioned CPS” and “Recovering the stack” from Might’s page.)
For this reason we syntactically distinguish IR function forms ["fun", ["x", "k"], ...] from IR continuation forms ["cont", ["x"], ...]. Similarly, we distinguish function calls ["f", "x", "k"] (which always carry a continuation argument) from continuation calls ["$call-cont", "k", "x"] (where $call-cont is a special form known to the compiler).
Let’s look at how we compile integers into CPS:
def cps(exp, k):
    match exp:
        case int(_):
            return ["$call-cont", k, exp]
    raise NotImplementedError(type(exp))  # TODO

cps(5, "k")  # ["$call-cont", "k", 5]
Integers satisfy the triviality requirement. So do all constant data (if we had strings, floats, etc.), variable references, and even lambda expressions. None of these require recursive evaluation, which is the core of the triviality requirement. Everything else has to be linearized: our nested AST gets broken into a sequence of small operations whose arguments are all trivial.
Variables are similarly straightforward to compile. We leave the variable names as-is for now in our IR, so we need not keep an environment parameter around.
def cps(exp, k):
    match exp:
        case int(_) | str(_):
            return ["$call-cont", k, exp]
    raise NotImplementedError(type(exp))  # TODO

cps("x", "k")  # ["$call-cont", "k", "x"]
Now let’s look at function calls. Function calls are the first type of expression that requires recursively evaluating subexpressions. To evaluate (f x), for example, we evaluate f, then x, then do a function call. The order of evaluation is not important to this post; it is a semantic property of the language being compiled.
To convert to CPS, we have to both do recursive compilation of the arguments and also synthesize our first continuations!
To evaluate a subexpression, which could be arbitrarily complex, we have to make a recursive call to cps. Unlike in a conventional compiler, this call doesn’t hand back the subexpression’s value. Instead, you pass it a continuation (does the word “callback” help here?) to do future work once that value has a name. To generate compiler-internal names, we have a gensym function that isn’t interesting and returns unique strings.
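Since gensym is only described in passing, here is one minimal way it could look. This is an assumption, not the post’s actual helper: it uses a single global counter shared by every stem, which matches how the sample outputs in this post number their names (for example "v0" followed by "k1"). Those outputs each start counting from zero, as if the counter were reset per example.

```python
import itertools

# Hypothetical helper: one counter shared across all stems, so names never
# collide even when different prefixes ("v", "k", ...) are mixed.
_counter = itertools.count()


def gensym(stem: str = "v") -> str:
    """Return a fresh compiler-internal name such as "v0" or "k1"."""
    return f"{stem}{next(_counter)}"


def reset_gensym() -> None:
    """Restart numbering at 0, e.g. to reproduce per-example outputs."""
    global _counter
    _counter = itertools.count()
```

A shared counter is the simplest way to guarantee uniqueness; per-stem counters would also work but make it easier to accidentally produce the same name twice.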
The only thing that differentiates it from, for example, a JavaScript callback is that it’s not a Python function but instead a function in the generated code.
def cps(exp, k):
    match exp:
        case [func, arg]:
            vfunc = gensym()
            varg = gensym()
            return cps(func, ["cont", [vfunc],
                              cps(arg, ["cont", [varg],
                                        [vfunc, varg, k]])])
        # ...

cps(["f", 1], "k")
# ["$call-cont", ["cont", ["v0"],
#                 ["$call-cont", ["cont", ["v1"],
#                                 ["v0", "v1", "k"]],
#                  1]],
#  "f"]
Note that our generated function call from (f x) now also has a continuation argument that was not there before. This is because (f x) does not return anything; instead, it passes its result to the given continuation.
Calls to primitive operators like + are our other interesting case. Like function calls, we evaluate the operands recursively and pass in an additional continuation argument. Note that not all CPS implementations do this for simple math operators; some choose to allow simple arithmetic to actually “return” values.
def gensym(stem="v"): ...  # returns a fresh name such as "v0" or "k1"

def cps(exp, k):
    match exp:
        case [op, x, y] if op in ["+", "-", "<"]:
            vx = gensym()
            vy = gensym()
            return cps(x, ["cont", [vx],
                           cps(y, ["cont", [vy],
                                   [f"${op}", vx, vy, k]])])
        # ...

cps(["+", 1, 2], "k")
# ["$call-cont", ["cont", ["v0"],
#                 ["$call-cont", ["cont", ["v1"],
#                                 ["$+", "v0", "v1", "k"]],
#                  2]],
#  1]
We also create a special form for the operator in our CPS IR that begins with $. So + gets turned into $+, and so on. This helps distinguish operator invocations from function calls.
Now let’s look at creating functions. Lambda expressions such as (lambda (x) (+ x 1)) need to create a function at run time, and that function’s body contains some code. To “return”, we use $call-cont as usual, but we also have to remember to create a new fun form with a continuation parameter (and then thread that through to the function body).
def cps(exp, k):
    match exp:
        case ["lambda", [arg], body]:
            vk = gensym("k")
            return ["$call-cont", k,
                    ["fun", [arg, vk], cps(body, vk)]]
        # ...

cps(["lambda", ["x"], "x"], "k")
# ["$call-cont", "k",
#  ["fun", ["x", "k0"],
#   ["$call-cont", "k0", "x"]]]
Next up in this mini language is our if expression: (if cond iftrue iffalse), where all of cond, iftrue, and iffalse can be arbitrarily nested expressions. This just means we call cps recursively three times.
We also add a new compiler builtin, ($if cond iftrue iffalse), that takes one trivial expression (the condition) and decides which of the branches to execute. This is roughly equivalent to a conditional jump in machine code.
The straightforward implementation works just fine, but can you see what might go wrong?
def cps(exp, k):
    match exp:
        case ["if", cond, iftrue, iffalse]:
            vcond = gensym()
            return cps(cond, ["cont", [vcond],
                              ["$if", vcond,
                               cps(iftrue, k),
                               cps(iffalse, k)]])
        # ...

cps(["if", 1, 2, 3], "k")
# ["$call-cont", ["cont", ["v0"],
#                 ["$if", "v0",
#                  ["$call-cont", "k", 2],
#                  ["$call-cont", "k", 3]]],
#  1]
The problem is that our continuation, k, need not be a continuation variable; it could be an arbitrarily complicated expression. Our implementation copies it into the compiled code twice, and since each copy may itself contain ifs that copy their own continuations, the worst case is exponential program growth. Instead, let’s bind it to a name and then use the name twice.
def cps(exp, k):
    match exp:
        case ["if", cond, iftrue, iffalse]:
            vcond = gensym()
            vk = gensym("k")
            return ["$call-cont", ["cont", [vk],
                                   cps(cond, ["cont", [vcond],
                                              ["$if", vcond,
                                               cps(iftrue, vk),
                                               cps(iffalse, vk)]])],
                    k]
        # ...

cps(["if", 1, 2, 3], "k")
# ["$call-cont", ["cont", ["k1"],
#                 ["$call-cont", ["cont", ["v0"],
#                                 ["$if", "v0",
#                                  ["$call-cont", "k1", 2],
#                                  ["$call-cont", "k1", 3]]],
#                  1]],
#  "k"]
Last, let can be handled by using a continuation, just as we’ve bound temporary variables in previous examples. You could also handle it by desugaring it into ((lambda (x) body) value), but that will generate a lot more administrative overhead for your optimizer to get rid of later.
def cps(exp, k):
    match exp:
        case ["let", [name, value], body]:
            return cps(value, ["cont", [name],
                               cps(body, k)])
        # ...

cps(["let", ["x", 1], ["+", "x", 2]], "k")
# ['$call-cont', ['cont', ['x'],
#                 ['$call-cont', ['cont', ['v0'],
#                                 ['$call-cont', ['cont', ['v1'],
#                                                 ['$+', 'v0', 'v1', 'k']],
#                                  2]],
#                  'x']],
#  1]
There you have it. A working Mini-Scheme to CPS converter. My implementation is ~30 lines of Python code. It’s short and sweet but you might have noticed some shortcomings…
Now, you might have noticed that we’re giving names to a lot of trivial expressions: unnecessary cont forms used like let bindings. Why name the integer 3 if it’s trivial?
One approach people take to avoid this is meta-continuations, which I think many people call the “higher-order transform”. Instead of always generating conts, we can sometimes pass in a compiler-level (in this case, Python) function instead.
See Matt Might’s article and what I think is a working Python implementation.
This approach, though occasionally harder to reason about and more complex, cuts down on a significant amount of code before it is ever emitted. For few-pass compilers, for resource-constrained environments, for enormous programs, … this makes a lot of sense.
Another approach, potentially easier to reason about, is to lean on your optimizer. You’ll probably want an optimizer anyway, so you might as well use it to cut down your intermediate code too.
Just like any IR, it’s possible to optimize by doing recursive rewrites. We won’t implement any here (for now… maybe I’ll come back to this) but will sketch out a few common ones.
The simplest one is probably constant folding, like turning (+ 3 4) into 7.
The CPS equivalent looks kind of like this:
["$+", 3, 4, "k"]     # => ["$call-cont", "k", 7]
["$if", 1, "t", "f"]  # => t
["$if", 0, "t", "f"]  # => f
An especially important optimization, particularly if using the simple CPS transformation that we’ve been using, is beta reduction. This is the process of turning expressions like ((lambda (x) (+ x 1)) 2) into (+ 2 1) by substituting the 2 for x. For example, in CPS:
["$call-cont", ["cont", ["k1"],
                ["$call-cont", ["cont", ["v0"],
                                ["$if", "v0",
                                 ["$call-cont", "k1", 2],
                                 ["$call-cont", "k1", 3]]],
                 1]],
 "k"]
# into
["$call-cont", ["cont", ["v0"],
                ["$if", "v0",
                 ["$call-cont", "k", 2],
                 ["$call-cont", "k", 3]]],
 1]
# into
["$if", 1,
 ["$call-cont", "k", 2],
 ["$call-cont", "k", 3]]
# into (via constant folding)
["$call-cont", "k", 2]
Substitution has to be scope-aware, and therefore requires threading an environment parameter through your optimizer.
As an aside: even if you “alphatise” your expressions to make them have unique variable bindings and names, you might accidentally create second bindings of the same names when substituting. For example:
# substitute(haystack, name, replacement)
substitute(["+", "x", "x"], "x", ["let", ["x0", 1], "x0"])
This would create two bindings of x0, which violates the global uniqueness property.
You may not always want to perform this rewrite. To avoid code blowup, you may only want to substitute if the function or continuation’s parameter name appears zero or one times in the body. Or, if it occurs more than once, substitute only if the expression being substituted is an integer/variable. This is a simple heuristic that will avoid some of the worst-case scenarios but may not be maximally beneficial—it’s a local optimum.
Another thing to be aware of is that substitution may change evaluation order. So even if you only have one parameter reference, you may not want to substitute:
((lambda (f) (begin (g) f))
 (do-a-side-effect))
As written, do-a-side-effect will be called before g, and its result will become f. If you substitute do-a-side-effect for f in your optimizer, g will be called before do-a-side-effect. You can be more aggressive if your analyzer tells you which functions are side-effect free, but otherwise… be careful with function calls.
There are also more advanced optimizations but they go beyond an introduction to CPS and I don’t feel confident enough to sketch them out.
Alright, we’ve done a bunch of CPS→CPS transformations. Now we would like to execute the optimized code. To do that, we have to transform out of CPS into something designed to be executed.
In this section we’ll list a couple of approaches to generating executable code from CPS but we won’t implement any.
You can generate naive C code pretty directly from CPS. The fun and cont forms become top-level C functions. In order to support closures, you need to do free variable analysis and allocate closure structures for each. (See also the approach in scrapscript in the section called “functions”.) Then you can do a very generic calling convention where you pass closures around. Unfortunately, this is not very efficient and doesn’t guarantee tail-call elimination.
To support tail-call elimination, you can use trampolines. This mostly involves allocating a frame-like closure on the heap with each tail-call. If you have a garbage collector, this isn’t too bad; the frames do not live very long. In fact, if you instrument the factorial example in Eli’s blog post, you can see that the trampoline frames live only until the next one gets allocated.
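To see why those frames are so short-lived, here is a Python analogue rather than C (and not Eli’s code); fact_cps and trampoline are hypothetical names. Each tail call is returned as a thunk, a small heap-allocated closure, and the driver loop calls it. The previous thunk becomes garbage as soon as the next one is returned, so the Python stack never grows. The continuation closures still pile up, since they are the program’s real control stack.

```python
def fact_cps(n, k):
    """Factorial in CPS; every tail call is returned as a thunk instead of being made."""
    if n == 0:
        return lambda: k(1)
    # The continuation multiplies by n, and it too returns a thunk rather than calling k.
    return lambda: fact_cps(n - 1, lambda v: lambda: k(n * v))


def trampoline(thunk):
    """Run thunks until a non-callable result appears (assumes the answer is not a function)."""
    while callable(thunk):
        thunk = thunk()
    return thunk
```

trampoline(lambda: fact_cps(5, lambda v: v)) evaluates to 120, and much larger inputs run without hitting Python’s recursion limit.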
This observation led to the development of the Cheney on the MTA algorithm, which uses the C call stack as a young generation for the garbage collector. It uses setjmp and longjmp to unwind the stack. This approach is used in CHICKEN Scheme and Cyclone Scheme. Take a look at Baker’s 1994 implementation.
If you don’t want to do any of this trampoline stuff, you can also do the One Big Switch approach, which stuffs each of the funs and conts into a case in a massive switch statement. Calls become gotos. You can manage your stack roots pretty easily in one contiguous array. However, as you might imagine, larger Scheme programs might cause trouble with some C compilers.
Last, you need not generate C. You can also do your own lowering from CPS into a lower-level IR and then to some kind of assembly language.
You have seen how to produce CPS, how to optimize it, and how to eliminate it. There’s much more to learn, if you are interested. Please send me material if you find it useful.
I had originally planned to write a CPS-based optimizer and code generator for scrapscript, but I got stuck on the finer details of compiling pattern matching to CPS. Maybe I will return to this in the future by desugaring it to nested ifs or something.
Check out the code.
Thanks to Vaibhav Sagar and Kartik Agaram for giving feedback on this post. Thanks to Olin Shivers for an excellent course on compiling functional programming languages.
[1] Earlier this year, my grandmother mentioned offhand that she was getting brunch with the Plotkins. I, midway through a course by Olin Shivers on compiling functional programming languages, did a double take. Surely she couldn’t mean… but yep, apparently my grandmother and Gordon Plotkin are friends!
[2] It’s a several-line change to the compiler to handle multiple parameters, but for this post it’s just noise, so I leave it as an exercise.