A Simply Typed Lambda-Calculus for Forward Automatic Differentiation

I am proud to announce that my first computer science paper has been accepted to MFPS XXVIII! In this post I would like to give a down-to-earth introduction to the subject of the paper and to explain the problem it is attempting to solve.

I am working on the design and implementation of a functional programming language with support for automatic differentiation. Automatic differentiation (AD) is a powerful technique for computing derivatives of functions given by programs in programming languages. Let us contrast AD with two other techniques for programmatically computing derivatives of functions.

First, there is numerical differentiation, which approximates the derivative of a function f by Newton’s difference quotient

\displaystyle f'(x)\approx\frac{f(x+h)-f(x)}{h}

for a small value of h. The choice of a suitable h is a non-trivial problem because of the intricacies of floating point arithmetic. If h is too small, you are going to subtract two nearly equal numbers, which may cause extreme loss of accuracy. In fact, due to rounding errors, the difference in the numerator is going to be zero if h is small enough. On the other hand, if h is not sufficiently small, then the difference quotient is a poor estimate of the derivative.
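The trade-off is easy to observe. The following is a minimal sketch, assuming the test function sin at the point 1 (where the exact derivative is cos 1); the names diffQuot, errorAt and errorTable are made up for this illustration.

```haskell
-- Forward difference quotient (f (x + h) - f x) / h.
diffQuot :: (Double -> Double) -> Double -> Double -> Double
diffQuot f x h = (f (x + h) - f x) / h

-- Absolute error of the estimate of sin' 1 = cos 1 for a given step h.
errorAt :: Double -> Double
errorAt h = abs (diffQuot sin 1 h - cos 1)

-- Step sizes 1e-1, 1e-2, ..., 1e-20 paired with the resulting error.
errorTable :: [(Double, Double)]
errorTable = [ (h, errorAt h) | k <- [1 .. 20 :: Int], let h = 10 ^^ negate k ]
```

Running `mapM_ print errorTable` in GHCi should show the error shrinking at first and then growing again. For the smallest steps, 1 + h rounds to 1, the numerator is exactly zero, and the error equals cos 1.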

Second, you are probably familiar with symbolic differentiation, which works by applying the rules for computing derivatives you have learned in your calculus course (sum rule, product rule, chain rule) and by using a table of derivatives of elementary functions:

\displaystyle  \begin{array}{rcl} (f + g)'(x) & = & f'(x) + g'(x) \\ (f\,\cdot\, g)'(x) & = & f'(x)\cdot g(x) + f(x)\cdot g'(x)\\ (f\circ g)'(x) & = & f'(g(x))\cdot g'(x)\\ \exp'\,(x) & = & \exp\,(x)\\ \log'\,(x) & = & 1/x\\ \sin'\,(x) & = & \cos\,(x)\\ \cos'\,(x) & = & -\sin\,(x)\\ & \dots & \end{array}

Unlike numerical differentiation, symbolic differentiation is exact. However, there are also two main drawbacks to the symbolic approach to differentiation. First, symbolic differentiation requires access to the source code of the computation, and places restrictions on that source code. Second, and more importantly, symbolic differentiation suffers from the loss of sharing. What does that mean? I find the following example illuminating. Consider the problem of computing the derivative of a product of n functions: f(x) = f_1(x) \cdot f_2(x) \cdot \ldots \cdot f_n(x). Applying the product rule, we arrive at the expression for the derivative, which has size quadratic in n:

\displaystyle  \begin{array}{rcl}f'(x) & = & f'_1(x)\cdot f_2(x)\cdot\ldots\cdot f_n(x)\\ & + & f_1(x)\cdot f'_2(x)\cdot\ldots\cdot f_n(x)\\ & + & \dots \\ & + & f_1(x)\cdot f_2(x)\cdot\ldots\cdot f'_n(x). \end{array}

Evaluating it naively would result in evaluating each f_i(x) n-1 times. A more sophisticated implementation may try, for example, to eliminate common subexpressions before evaluation, but AD elegantly eliminates (hey, a pun!) the necessity of this step.
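The quadratic growth can be seen with a toy symbolic differentiator. This is only a sketch: the Expr type and the functions deriv, size and productOf are invented for this illustration, no simplification is performed, and every factor is taken to be sin x for simplicity.

```haskell
-- A tiny expression language in one variable (illustrative only).
data Expr
    = Var
    | Lit Double
    | Add Expr Expr
    | Mul Expr Expr
    | Neg Expr
    | Sin Expr
    | Cos Expr
    deriving Show

-- Textbook differentiation rules, applied without any simplification.
deriv :: Expr -> Expr
deriv Var       = Lit 1
deriv (Lit _)   = Lit 0
deriv (Add a b) = Add (deriv a) (deriv b)
deriv (Mul a b) = Add (Mul (deriv a) b) (Mul a (deriv b))
deriv (Neg a)   = Neg (deriv a)
deriv (Sin a)   = Mul (Cos a) (deriv a)
deriv (Cos a)   = Neg (Mul (Sin a) (deriv a))

-- Number of nodes in the expression tree.
size :: Expr -> Int
size Var       = 1
size (Lit _)   = 1
size (Add a b) = 1 + size a + size b
size (Mul a b) = 1 + size a + size b
size (Neg a)   = 1 + size a
size (Sin a)   = 1 + size a
size (Cos a)   = 1 + size a

-- The product sin x * sin x * ... * sin x with n factors (n >= 1).
productOf :: Int -> Expr
productOf n = foldl1 Mul (replicate n (Sin Var))
```

The product itself grows linearly (`size (productOf n)` is 3n - 1), while `size (deriv (productOf n))` grows quadratically, because each application of the product rule copies the whole undifferentiated prefix.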

AD simultaneously manipulates values and derivatives, leading to more sharing of the different instances of the derivative of a subterm in the computation of the derivative of a bigger term. Like symbolic differentiation, and unlike numerical differentiation, AD is exact: there is no truncation error, and in fact the answer produced by AD coincides with that produced by symbolic differentiation. Furthermore, AD is efficient: it offers strong complexity guarantees (in particular, evaluation of the derivative takes no more than a constant factor times as many operations as evaluation of the function).

AD comes in several variations: forward mode, reverse mode, as well as mixtures thereof. I will only focus on forward mode AD here.

Forward mode AD makes use of dual numbers. Let us recall the definition. It resembles that of complex numbers, with the difference that dual numbers are obtained from real numbers by adjoining a symbol \varepsilon with the property that \varepsilon^2=0. Thus, the set \mathbb{D} of dual numbers is the set of pairs of real numbers, each pair being written as a formal sum a+a'\varepsilon. The arithmetic operations on dual numbers are defined by the formulas

\displaystyle  \begin{array}{rcl} (a + a'\varepsilon) + (b + b'\varepsilon) & = & (a + b) + (a' + b')\varepsilon,\\ (a + a'\varepsilon)\times(b + b'\varepsilon) & = & (a \times b) + (a\times b' + a'\times b)\varepsilon \end{array}

(and similar formulas for subtraction and division).
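For completeness, the remaining two formulas can be written out as follows. They follow from the same rule \varepsilon^2=0, with division defined only when b \ne 0:

\displaystyle  \begin{array}{rcl} (a + a'\varepsilon) - (b + b'\varepsilon) & = & (a - b) + (a' - b')\varepsilon,\\ (a + a'\varepsilon)\,/\,(b + b'\varepsilon) & = & \dfrac{a}{b} + \dfrac{a'\times b - a\times b'}{b^2}\,\varepsilon. \end{array}

The division formula can be checked by multiplying the right-hand side by b + b'\varepsilon and discarding the \varepsilon^2 term, which gives back a + a'\varepsilon.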

We refer to the coefficients a and a' as the primal value and the perturbation, respectively. A dual number a+a'\varepsilon can be thought of as a small perturbation of the real number a.

Every differentiable function f: \mathbb{R}\to\mathbb{R} induces a function f_* : \mathbb{D}\to\mathbb{D}, called the pushforward of f and defined by the formula

\displaystyle f_*(x + x'\varepsilon) = f(x) + f'(x)x'\varepsilon,

which is essentially the formal Taylor series of f truncated at degree 1. It follows from the properties of differentiation that the pushforward transform commutes with sum, product, and composition, as well as with subtraction and division.
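To see why, here is the calculation for product and composition, using only the definition of the pushforward and the multiplication rule for dual numbers. For the product, the product rule gives

\displaystyle  \begin{array}{rcl} (f\cdot g)_*(x + x'\varepsilon) & = & f(x)\cdot g(x) + (f'(x)\cdot g(x) + f(x)\cdot g'(x))\,x'\varepsilon\\ & = & (f(x) + f'(x)x'\varepsilon)\times(g(x) + g'(x)x'\varepsilon)\\ & = & f_*(x + x'\varepsilon)\times g_*(x + x'\varepsilon), \end{array}

and for composition, the chain rule gives

\displaystyle  \begin{array}{rcl} (f\circ g)_*(x + x'\varepsilon) & = & f(g(x)) + f'(g(x))\cdot g'(x)\,x'\varepsilon\\ & = & f_*(g(x) + g'(x)x'\varepsilon)\\ & = & f_*(g_*(x + x'\varepsilon)). \end{array}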

OK, so here is the idea of forward mode AD: overload arithmetic operations and elementary mathematical functions to operate not only on real numbers but also on dual numbers. In particular, the extension of each basic function from real numbers to dual numbers is given by the pushforward of that function:

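-- A dual number: primal value and perturbation.
data D = D Double Double deriving Show

instance Num D where
    D x x' + D y y' = D (x + y) (x' + y')
    D x x' * D y y' = D (x * y) (x * y' + x' * y)
    negate (D x x') = D (negate x) (negate x')
    abs (D x x')    = D (abs x) (signum x * x')
    signum (D x _)  = D (signum x) 0
    fromInteger n   = D (fromInteger n) 0  -- constants carry zero perturbation
instance Fractional D where
    D x x' / D y y' = D (x / y) ((x' * y - x * y') / (y * y))
    fromRational r  = D (fromRational r) 0
instance Floating D where
    pi = D pi 0
    exp (D x x')  = D (exp x) (exp x * x')
    log (D x x')  = D (log x) (x' / x)
    sin (D x x')  = D (sin x) (cos x * x')
    cos (D x x')  = D (cos x) (negate (sin x) * x')
    sqrt (D x x') = D (sqrt x) (x' / (2 * sqrt x))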

Here in the example I am using Haskell because it supports overloading and because I am more familiar with it. But the same trick can be performed in any language supporting overloading (e.g., Scheme or C++).

What overloading achieves is that any function built out of the overloaded primitives contains information about its own derivative. Any such function is generic: it can accept arguments of both types, real numbers as well as dual numbers. Effectively, any function in the program can compute two mathematically different functions: the primal function when given an argument that is a real number, and its pushforward when given an argument that is a dual number. This follows from the properties of pushforward.

This allows us to compute the derivative of f at some point x by simply evaluating f at the point x+\varepsilon and extracting the coefficient in front of \varepsilon. For example, define

f :: Floating a => a -> a
f z = sqrt (3 * sin z)

and try it out in GHCi:

*Main> f (D 2 1)
D 1.6516332160855343 (-0.3779412091869595)

More formally, define the derivative operator D by D\, f\, x = E\, (f\, (x+\varepsilon)), where the accessor E is defined by E\, (x + x'\varepsilon) = x'.
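In Haskell, these two definitions might look as follows. This is a minimal sketch: the names pert (for the accessor E) and diff (for the operator D, renamed to avoid a clash with the constructor D) are my own choices for this illustration, and only the Num instance is included to keep it short.

```haskell
-- A dual number: primal value and perturbation.
data D = D Double Double deriving Show

instance Num D where
    D x x' + D y y' = D (x + y) (x' + y')
    D x x' * D y y' = D (x * y) (x * y' + x' * y)
    negate (D x x') = D (negate x) (negate x')
    abs (D x x')    = D (abs x) (signum x * x')
    signum (D x _)  = D (signum x) 0
    fromInteger n   = D (fromInteger n) 0

-- The accessor E: extract the coefficient in front of epsilon.
pert :: D -> Double
pert (D _ x') = x'

-- The derivative operator: evaluate f at x + epsilon, then apply the accessor.
diff :: (D -> D) -> Double -> Double
diff f x = pert (f (D x 1))

-- A generic function: usable on real numbers and on dual numbers alike.
cubic :: Num a => a -> a
cubic z = z * z * z + 2 * z
```

Here `cubic (3 :: Double)` evaluates the primal function and gives 33.0, while `diff cubic 3` evaluates the pushforward at 3 + \varepsilon and gives 29.0, in agreement with the derivative 3z^2 + 2 at z = 3.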

Many applications of AD require that derivative operators nest properly. In other words, we would like to be able to differentiate functions that internally differentiate some other functions, and get correct results! For example, you may want to use AD for bi-level optimization, where the values of the function you are optimizing are found as optima of some other function.

Let us look at the example:

\displaystyle D\,(\lambda x.\, x\times (D\,(\lambda y.\, x\times y)\,2))\,1.

Note that the inner function that is being differentiated depends on x, which is free in the body of the function, but is bound by the enclosing lambda.

It is easy to compute the correct answer: the inner function being differentiated is linear in y, and its derivative is x at any point, therefore the outer function becomes x^2 and its derivative at 1 is 2.

On the other hand, what does our formalism produce? Let us compute. To visually distinguish between different invocations of E, I am going to use brackets of different shapes:

\displaystyle  \begin{array}{cl} & D\,(\lambda x.\, x\times (D\,(\lambda y.\, x\times y)\,2))\,1\\ = & E\,[(\lambda x.\, x\times (D\, (\lambda y.\, x\times y)\,2)) (1+\varepsilon)]\\ = & E\,[(1+\varepsilon)\times (D\, (\lambda y.\, (1+\varepsilon)\times y)\, 2)]\\ = & E\,[(1+\varepsilon)\times E\,\{(\lambda y.\, (1+\varepsilon)\times y) (2+\varepsilon)\}]\\ = & E\,[(1+\varepsilon)\times E\,\{(1+\varepsilon)\times  (2+\varepsilon)\}]\\ = & E\,[(1+\varepsilon)\times E\,\{2 + 3\varepsilon\}]\\ = & E\,[(1+\varepsilon) \times 3]\\ = & 3 \ne 2 \end{array}

Uh-oh! What happened?

The root of this error, known as “perturbation confusion”, is our failure to distinguish between the perturbations introduced by the inner and outer invocations of D. There are ways to solve this problem, for instance by tagging perturbations with a fresh \varepsilon every time D is invoked and incurring the bookkeeping overhead of keeping track of which \varepsilon is associated with which invocation of D. Nonetheless, I hope the example serves to illustrate the value and nontriviality of a clear semantics for a \lambda-calculus with AD. In the paper I make some first steps towards tackling this problem.
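The wrong calculation above can be replayed in Haskell, together with one way of keeping the two perturbations apart. Note that the sketch below does not use the runtime tagging mentioned above, and it is not the approach of the paper. Instead it nests dual numbers, so that the inner invocation works with dual numbers whose components are themselves dual numbers. The names Dual, lift, diff, confused and nested are invented for this illustration.

```haskell
-- A dual number over an arbitrary numeric type a: primal value and perturbation.
data Dual a = Dual a a deriving Show

instance Num a => Num (Dual a) where
    Dual x x' + Dual y y' = Dual (x + y) (x' + y')
    Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)
    negate (Dual x x')    = Dual (negate x) (negate x')
    abs (Dual x x')       = Dual (abs x) (signum x * x')
    signum (Dual x _)     = Dual (signum x) 0
    fromInteger n         = Dual (fromInteger n) 0

-- Embed a value as a constant: zero perturbation.
lift :: Num a => a -> Dual a
lift c = Dual c 0

-- Derivative operator: evaluate at x + epsilon and extract the perturbation.
diff :: Num a => (Dual a -> Dual a) -> a -> a
diff f x = let Dual _ x' = f (Dual x 1) in x'

-- Both invocations share a single epsilon, as in the calculation above:
-- the inner function multiplies by x, which already carries the outer
-- perturbation, and the inner result is lifted back only afterwards.
confused :: Double
confused = diff (\x -> x * lift (diff (\y -> x * y) 2)) 1

-- The inner invocation gets its own epsilon one level up: x is lifted
-- into the inner level as a constant before it meets y.
nested :: Double
nested = diff (\x -> x * diff (\y -> lift x * y) 2) 1
```

With these definitions `confused` evaluates to 3.0 and `nested` to 2.0. The only difference between the two is where lift is placed, and the type checker accepts both, so in this sketch the burden of getting it right still falls on the programmer.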