import math
import chalk
from mt_diagrams.autodiff_draw import draw_boxes
from mt_diagrams.plots import plot_function
import minitorch
from minitorch import ScalarFunction
We are going to use the computation graph to automatically compute derivatives of arbitrary Python functions. The trick behind this autodifferentiation is to implement the derivative of each individual function, and then use the chain rule to compute a derivative for any scalar value.
Our computation graph was made up of individual atomic functions $f(x)$. For each of these functions we are now going to implement a backward method to provide this local derivative information.
The API for backward is to compute $d \cdot f'(x)$ where $f'(x)$ is the derivative of the function and $d$ is a value passed to backward (discussed more below).
For a simple function $f(x) = -x$, we can consult our derivative rules and get $f'(x) = -1$. Therefore the backward is,
class Neg(ScalarFunction):
    @staticmethod
    def forward(ctx, x):
        return -x

    @staticmethod
    def backward(ctx, d):
        f_prime = -1
        return f_prime * d
draw_boxes(["$d \\cdot f'(x)$", "$d$"], [1], lr=False)
Note that backward works a bit differently from the mathematical notation. Sometimes the derivative $f'(x)$ depends directly on $x$; however, backward does not take $x$ as an argument. This is where the context argument ctx comes in: forward saves the values it needs on ctx, and backward reads them back.
Consider a function Sin, $f(x) = \sin(x)$ which has derivative $f'(x) = \cos(x)$. We need to write it in code as,
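class Sin(ScalarFunction):
    @staticmethod
    def forward(ctx, x):
        # Save x so that backward can evaluate cos(x) later.
        ctx.save_for_backward(x)
        return math.sin(x)

    @staticmethod
    def backward(ctx, d):
        (x,) = ctx.saved_values
        f_prime = math.cos(x)
        return f_prime * d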
def d_call(x):
    # Run forward to populate ctx, then call backward with d = 1.
    ctx = minitorch.Context()
    Sin.forward(ctx, x)
    return Sin.backward(ctx, 1)
plot_function("f(x) = sin(x)", lambda x: Sin.apply(x).data)
plot_function("1 * f'(x) = cos(x)", d_call)
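To make the role of ctx concrete, here is a minimal sketch that does not depend on minitorch. The names `SimpleContext`, `SinSketch`, and `central_difference` are hypothetical stand-ins written for this illustration; they mimic only the save_for_backward / saved_values pattern used above. The sketch compares the value returned by backward with a numerical estimate of the derivative.

```python
import math


class SimpleContext:
    """Hypothetical stand-in for a context object (not minitorch's class)."""

    def __init__(self):
        self.saved_values = ()

    def save_for_backward(self, *values):
        # Remember the forward inputs so that backward can read them later.
        self.saved_values = values


class SinSketch:
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return math.sin(x)

    @staticmethod
    def backward(ctx, d):
        (x,) = ctx.saved_values
        return math.cos(x) * d


def central_difference(f, x, eps=1e-6):
    # Numerical estimate of f'(x), used only as a sanity check.
    return (f(x + eps) - f(x - eps)) / (2 * eps)


ctx = SimpleContext()
out = SinSketch.forward(ctx, 0.5)
analytic = SinSketch.backward(ctx, 1.0)  # 1 * cos(0.5)
numeric = central_difference(math.sin, 0.5)
scaled = SinSketch.backward(ctx, 3.0)    # 3 * cos(0.5)
```

With d = 1, backward returns the plain derivative; any other d simply scales it, which is what the chain rule will rely on.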
For functions that take multiple arguments, backward returns one derivative per input argument, each scaled by $d$. For example, if the function computes $f(x, y)$, backward returns $d \cdot f'_x(x, y)$ and $d \cdot f'_y(x, y)$.
class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, d):
        # Compute f'_x(x, y) * d, f'_y(x, y) * d
        x, y = ctx.saved_values
        f_x_prime = y
        f_y_prime = x
        return f_x_prime * d, f_y_prime * d
draw_boxes([("$d \\cdot f_x'(x, y)$", "$d \\cdot f_y'(x, y)$"), "$d$"], [1], lr=False)
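As a sanity check on the two-argument case, the following sketch compares the pair returned by a multiplication backward against numerical partial derivatives. It is written with plain functions rather than minitorch classes; `mul_forward`, `mul_backward`, and `partial_difference` are hypothetical helper names used only here.

```python
def mul_forward(x, y):
    return x * y


def mul_backward(x, y, d):
    # Returns (d * f'_x, d * f'_y) for f(x, y) = x * y.
    return y * d, x * d


def partial_difference(f, args, index, eps=1e-6):
    # Central difference with respect to args[index], holding the rest fixed.
    up = list(args)
    down = list(args)
    up[index] += eps
    down[index] -= eps
    return (f(*up) - f(*down)) / (2 * eps)


x, y, d = 3.0, 4.0, 2.0
dx, dy = mul_backward(x, y, d)  # (8.0, 6.0)
num_dx = d * partial_difference(mul_forward, (x, y), 0)
num_dy = d * partial_difference(mul_forward, (x, y), 1)
```

Each entry of the returned pair lines up with one input, in the same order as the arguments to forward.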
Chain Rule
This section discusses the implementation of the chain rule for univariate differentiation. Before reading, review the mathematical definition of the chain rule.
Computing backward gives a way to compute the derivative for simple functions, but what if we have more complex functions? Let's go through each of the different cases to compute the derivatives.
- One argument
- Two arguments
- Multiple uses
One argument
Let us say that we have a complex function $h(x) = f(g(x))$. We want to compute $h'(x)$. For simplicity we use $z = g(x)$, and draw $h$ as two boxes left to right.
draw_boxes(["$x$", "$z = g(x)$", "$f(g(x))$"], [1, 1])
The chain rule tells us how to compute this term. Specifically, it gives the following formula.
$$\begin{aligned} d &= 1 \cdot f'(z) \\ h'_x(x) &= d \cdot g'(x) \end{aligned}$$
The above derivative function tells us to compute the derivative of the right-most function ($f$), and then multiply it by the derivative of the left function ($g$).
Here is where the perspective of thinking of functions as boxes pays off. We simply reverse the order.
draw_boxes(["$d\\cdot g'(x)$", "$f'(z)$", "$1$"], [1, 1], lr=False)
The $d$ multiplier passed to backward of the first box (left) should be the value returned by backward of the second box. The 1 at the end starts off the chain rule process with an initial value for $d$.
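The reversed pass can be traced by hand in a few lines. The sketch below assumes $g(x) = -x$ and $f(z) = \sin(z)$, reusing the two functions introduced earlier, and uses plain helper functions (`neg_backward` and `sin_backward` are hypothetical names, not minitorch API) so that it runs on its own.

```python
import math


def neg_backward(x, d):
    # g(x) = -x, so g'(x) = -1.
    return -1.0 * d


def sin_backward(z, d):
    # f(z) = sin(z), so f'(z) = cos(z).
    return math.cos(z) * d


x = 0.7
z = -x                        # forward through the left box g
h = math.sin(z)               # forward through the right box f

d = sin_backward(z, 1.0)      # right box first: d = 1 * f'(z)
h_prime = neg_backward(x, d)  # then left box: h'(x) = d * g'(x)
```

Here $h(x) = \sin(-x)$, so the result should agree with the closed form $h'(x) = -\cos(x)$.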
Two arguments
Next is the case of a two-argument function. We will write this as $h(x, y) = f(g(x, y))$ where $z = g(x, y)$.
draw_boxes([("$x$", "$y$"), "$z = g(x, y)$", "$h(x, y)$"], [1, 1])
Applying the chain rule we get the following equations.
$$\begin{aligned} d &= 1 \cdot f'(z) \\ h'_x(x, y) &= d \cdot g'_x(x, y) \\ h'_y(x, y) &= d \cdot g'_y(x, y) \end{aligned}$$
Drawing this again with boxes.
draw_boxes(
    [("$d \\cdot g'_x(x, y)$", "$d \\cdot g'_y(x, y)$"), "$f'(z)$", "$1$"],
    [1, 1],
    lr=False,
)
Note that this shows that the second box ($f$) does not care how many arguments the first box ($g$) has, as long as it passes back $d$ which is enough for the chain rule to work.
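A small numeric trace of the two-argument case, assuming $g(x, y) = x \cdot y$ and $f(z) = \sin(z)$. The helpers `mul_backward` and `sin_backward` are hypothetical plain functions written for this sketch, not minitorch API.

```python
import math


def mul_backward(x, y, d):
    # g(x, y) = x * y, so g'_x = y and g'_y = x.
    return y * d, x * d


def sin_backward(z, d):
    # f(z) = sin(z), so f'(z) = cos(z).
    return math.cos(z) * d


x, y = 0.5, 2.0
z = x * y                         # forward through g
h = math.sin(z)                   # forward through f

d = sin_backward(z, 1.0)          # d = 1 * f'(z)
h_x, h_y = mul_backward(x, y, d)  # (d * g'_x, d * g'_y)
```

The second box hands back a single number $d$; the first box is the only one that needs to know it has two inputs.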
Multiple Uses
Finally, what happens when one value is used by two later boxes? We will write this as $h(x) = f(z_1, z_2)$ where $z_1 = z_2 = g(x)$.
draw_boxes(["$x$", ("$z_1 = g(x)$", "$z_2 = g(x)$"), "$h(x)$"], [1, 1])
Derivatives are linear, so the $d$ term that comes from the second box is just the sum of the two individual derivatives.
$$\begin{aligned} d &= 1 \cdot f'_{z_1}(z_1, z_2) + 1 \cdot f'_{z_2}(z_1, z_2) \\ h'_x(x) &= d \cdot g'_x(x) \end{aligned}$$
Specifically in terms of boxes, this means that if an output is used multiple times, we should sum together the derivative terms. This rule is important, as it means that we cannot call backward until we have aggregated together all the values that we need to calculate $d$.
draw_boxes(
    ["$d \\cdot g'_x(x)$", ("$f'_{z_1}(z_1, z_2)$", "$f'_{z_2}(z_1, z_2)$"), "$1$"],
    [1, 1],
    lr=False,
)
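The summation rule can be checked on a concrete case. The sketch below assumes $g(x) = \sin(x)$ and $f(z_1, z_2) = z_1 \cdot z_2$, so that $h(x) = \sin^2(x)$ with the known derivative $h'(x) = \sin(2x)$. The helper names are hypothetical and the code is independent of minitorch.

```python
import math


def mul_backward(z1, z2, d):
    # f(z1, z2) = z1 * z2, so f'_{z1} = z2 and f'_{z2} = z1.
    return z2 * d, z1 * d


def sin_backward(x, d):
    # g(x) = sin(x), so g'(x) = cos(x).
    return math.cos(x) * d


x = 0.3
z = math.sin(x)  # g(x), used as both z1 and z2
h = z * z        # f(z1, z2)

d_z1, d_z2 = mul_backward(z, z, 1.0)
d = d_z1 + d_z2  # sum the contributions before calling backward on g
h_prime = sin_backward(x, d)

# Calling backward on g with only one contribution gives half the answer.
partial_only = sin_backward(x, d_z1)
```

Passing only `d_z1` to the first box would return half of the correct derivative in this example, which is why every contribution to $d$ must be collected before backward is called.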