Automatic Differentiation with a simple Python implementation
Explaining automatic differentiation with a simple example and easy-to-read Python scripts which can be downloaded here: forward-mode and reverse-mode.
This post explains how forward- and reverse-mode automatic differentiation algorithms work with the Wikipedia example $f(x_1, x_2) = x_1 * x_2 + \sin(x_1)$ and Python implementations. Autodiff is conceptually easy to understand, but nontrivial to implement. Therefore, unless you can write your own autodiff code, you cannot confidently say that you understand how it works.
If you have some basic understanding of autodiff, you can skip my writing and jump to the implementation.
The first step is to divide $f$ into operation units labeled by $w^{n}_i$, where $n$ stands for the hierarchy level and $i$ for the index within that level. The lowest hierarchy corresponds to the variables, i.e., $x_1$ and $x_2$, and the highest hierarchy corresponds to the function, i.e., $f$. Therefore, we have
\[w^2 = w^1_1 + w^1_2;\\ w^1_1 = w^0_1 * w^0_2, \quad w^1_2 = \sin(w^0_1); \\ w^0_1 = x_1, \quad w^0_2 = x_2.\]
Forward mode
Suppose we are interested in $\partial f/\partial x_1$ while keeping $x_2$ fixed. The forward mode corresponds to evaluating $\partial f/\partial x_1$ from bottom to top. We write $\dot{w} = \partial w/\partial x_1$ and work all the way up:
\[\dot{w}_1^0 = 1, \quad \dot{w}_2^0 = 0; \\ \dot{w}_1^1 = \frac{\partial w^1_1}{\partial w_1^0} \dot{w}_1^0 + \frac{\partial w^1_1}{\partial w_2^0} \dot{w}_2^0, \quad \dot{w}_2^1 = \frac{\partial w^1_2}{\partial w_1^0} \dot{w}_1^0 + \frac{\partial w^1_2}{\partial w_2^0} \dot{w}_2^0; \\ \dot{w}^2 = \frac{\partial w^2}{\partial w_1^1} \dot{w}_1^1 + \frac{\partial w^2}{\partial w_2^1} \dot{w}_2^1\]Therefore, we need to keep track of the gradients of all operation units and build them up. Because many gradient expressions involve the values of the variables, e.g., $\partial (x_1 x_2)/\partial x_1 = x_2$, we also need to keep track of the values of the operation units.
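As a concrete illustration, take $(x_1, x_2) = (0, 1)$, the values used in the scripts below. Seeding $\dot{w}_1^0 = 1$, $\dot{w}_2^0 = 0$ gives
\[\dot{w}_1^1 = w_2^0 \cdot 1 + w_1^0 \cdot 0 = 1, \quad \dot{w}_2^1 = \cos(w_1^0) \cdot 1 = 1, \quad \dot{w}^2 = \dot{w}_1^1 + \dot{w}_2^1 = 2,\]
so $\partial f/\partial x_1 = x_2 + \cos(x_1) = 2$ at this point.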
If we need the derivative with respect to $x_2$, we set $\dot{w}_1^0 = 0, \dot{w}_2^0 = 1$ and redo the above. Therefore, the forward mode requires $\mathcal{O}(N M)$ operations to evaluate the full gradient, where $N$ is the number of variables and $M$ is the number of operation units.
Implementation
We use a class to keep track of the value $w_i^n$ and the gradient $\dot{w}_i^n$, and we have to define the operation units, because each operation now returns not only the value but also the gradient. A vanilla Python implementation is given below.
We first define the data structure fw_var:
'''
Forward-mode automatic differentiation for the function:
f = x1*x2 + sin(x1)
'''
import numpy as np
class fw_var():
    # define a new data structure to store both value and gradient
    def __init__(self, v=0, g=1):
        self.val = v   # value
        self.grad = g  # gradient
Next we define the operation units.
def f_prod(x, y):
    # evaluate x * y
    var = fw_var()
    var.val = x.val * y.val
    var.grad = x.val * y.grad + x.grad * y.val
    return var

def f_add(x, y):
    # evaluate x + y
    var = fw_var()
    var.val = x.val + y.val
    var.grad = x.grad + y.grad
    return var

def f_sin(x):
    # evaluate sin(x), chain rule: d sin(x) = cos(x) * dx
    var = fw_var()
    var.val = np.sin(x.val)
    var.grad = np.cos(x.val) * x.grad
    return var
And finally we build $f(x_1, x_2)$ and evaluate $\partial f/\partial x_1$.
def func(x1, x2):
    # The function to be evaluated
    # f = x1*x2 + sin(x1)
    return f_add(f_prod(x1, x2), f_sin(x1))

x1 = fw_var(0, 1)
x2 = fw_var(1, 0)
var = func(x1, x2)
print(f"Value of f: {var.val}")
print(f"Grad wrt x1: {var.grad}")
One can also overload the +, -, * and / operators by adding __add__(), __sub__(), __mul__() and __truediv__() methods to the class fw_var. You can find many fancier versions online :)
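A minimal sketch of such an overloaded class might look like the following (this is just one possible variant, not the downloadable script; sin still needs a plain function since it has no operator):
import numpy as np

class fw_var():
    # value/gradient pair, as before
    def __init__(self, v=0, g=1):
        self.val = v
        self.grad = g

    def __add__(self, other):
        # (u + v)' = u' + v'
        return fw_var(self.val + other.val, self.grad + other.grad)

    def __mul__(self, other):
        # (u * v)' = u' v + u v'
        return fw_var(self.val * other.val,
                      self.grad * other.val + self.val * other.grad)

def f_sin(x):
    # sin(x), with the chain rule cos(x) * dx
    return fw_var(np.sin(x.val), np.cos(x.val) * x.grad)

x1 = fw_var(0, 1)  # seed d/dx1
x2 = fw_var(1, 0)
var = x1 * x2 + f_sin(x1)
print(var.val, var.grad)  # expect 0.0 and 2.0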
Reverse mode
The reverse mode is not as straightforward as the forward mode. However, it requires fewer operations, at the cost of more memory. We start from the hierarchy again
\[w^2 = w^1_1 + w^1_2;\\ w^1_1 = w^0_1 * w^0_2, \quad w^1_2 = \sin(w^0_1); \\ w^0_1 = x_1, \quad w^0_2 = x_2.\]Now we work from top to bottom, and the quantity we care about is $\tilde{w}_i^n = \partial f/\partial w_i^n$. The routine is
\[\tilde{w}^2 = \frac{\partial f}{\partial w^2} = 1; \\ \tilde{w}^1_1 = \tilde{w}^2 \frac{\partial w^2 }{\partial w^1_1}, \quad \tilde{w}^1_2 = \tilde{w}^2 \frac{\partial w^2 }{\partial w^1_2}; \\ \tilde{w}^0_1 = \tilde{w}^1_1 \frac{\partial w^1_1 }{\partial w^0_1} + \tilde{w}^1_2 \frac{\partial w^1_2}{\partial w^0_1} = \frac{\partial f}{\partial w^0_1}, \quad \tilde{w}^0_2 = \tilde{w}^1_1 \frac{\partial w^1_1}{\partial w^0_2} + \tilde{w}^1_2 \frac{\partial w^1_2}{\partial w^0_2} = \frac{\partial f}{\partial w^0_2}\]Now we see that, in order to evaluate $\tilde{w}^k$, we need $\tilde{w}^{k+1}$ as well as the coefficients $\partial w^{k+1}/ \partial w^k$.
We call $w^{k+1}$ the parents and $w^k$ the children. To evaluate the gradient of a child, we need to find all of its parents and the coefficient attached to each parent. However, that direction is not easy: while building $f$, each operation naturally knows its inputs (its children) but not who will later use its output. The easy direction is the opposite one: given a parent, find all of its children recursively and push the accumulated gradient down. The recursion ends at $x_1$ and $x_2$, which have no children.
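For concreteness, at the same point $(x_1, x_2) = (0, 1)$ as before, the backward sweep gives
\[\tilde{w}^2 = 1; \quad \tilde{w}^1_1 = 1, \quad \tilde{w}^1_2 = 1; \quad \tilde{w}^0_1 = w_2^0 + \cos(w_1^0) = 2, \quad \tilde{w}^0_2 = w_1^0 = 0,\]
so a single pass yields both $\partial f/\partial x_1 = 2$ and $\partial f/\partial x_2 = 0$. The recursive implementation below reproduces exactly this.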
Implementation
First, we define a slightly more complicated data structure re_var, where self.children is a list of tuples (coeff, child); the coefficient corresponds to $\partial\, \mathrm{parent}/\partial\, \mathrm{child}$.
import numpy as np

class re_var():
    # the data structure for reverse mode
    def __init__(self, v=1, children=None):
        # use None instead of [] to avoid a mutable default argument
        self.val = v  # value of the variable
        self.children = children if children is not None else []  # list of tuples (coeff, re_var)
        self.grad = 0  # gradient
Then we define the unit operations.
def f_add(x, y):
    # x + y
    var = re_var()
    var.val = x.val + y.val
    var.children = [(1, x), (1, y)]  # child x with coefficient 1, and y with 1
    return var

def f_prod(x, y):
    # x * y
    var = re_var()
    var.val = x.val * y.val
    var.children = [(y.val, x), (x.val, y)]
    return var

def f_sin(x):
    # sin(x)
    var = re_var()
    var.val = np.sin(x.val)
    var.children = [(np.cos(x.val), x)]
    return var
Next we define a recursive function that visits all the offspring of a given re_var object and accumulates the corresponding coefficients into the offspring's gradient values.
def calc_grad(var, acc_grad=1):
    # var: the unit at the top hierarchy
    # acc_grad: accumulated gradient from parents
    # evaluate the gradient recursively, so every unit has a gradient
    var.grad += acc_grad
    for coef, child in var.children:
        calc_grad(child, coef * acc_grad)
Finally, we define $x_1$, $x_2$ and $f(x_1, x_2)$, and evaluate the gradients.
def func(x1, x2):
    return f_add(f_prod(x1, x2), f_sin(x1))

# define x1 and x2
x1 = re_var(0)
x2 = re_var(1)
var = func(x1, x2)
calc_grad(var)
print("value: ", var.val)
print("gradient wrt x1: ", x1.grad)
print("gradient wrt x2: ", x2.grad)
The complexity of evaluating the full gradient is $\mathcal{O}(M+N)$, which is much smaller than the $\mathcal{O}(MN)$ of the forward mode.