```python
import minitorch
from minitorch import ScalarFunction
import chalk
from mt_diagrams.drawing import r
from mt_diagrams.autodiff_draw import draw_boxes, backprop
chalk.set_svg_draw_height(300)
```
The `backward` function tells us how to compute the derivative of one operation. The chain rule tells us how to compute the derivative of two sequential operations. In this section, we show how to use these to compute the derivative for an arbitrary series of operations.
Practically, this looks like re-running our graph in reverse order, from right to left. However, we need to make sure we do this in the correct order. The key implementation challenge of backpropagation is to process each node in the correct order, i.e. to have processed every node that uses a Variable before processing that Variable itself.
Assume we have Variables $x, y$ and a Function $h(x, y)$, built from an intermediate Variable $z$:

$$
\begin{aligned}
z &= x \times y \\
h(x, y) &= \log(z) + \exp(z)
\end{aligned}
$$

We want to compute the derivatives $h'_x(x, y)$ and $h'_y(x, y)$. We assume $\times$, $+$, $\log$, and $\exp$ are all implemented as simple functions. This means that the final output Variable has constructed a graph of its history that looks like this:
```python
draw_boxes([("$x$", "$y$"), ("", ""), ("", ""), "$h(x, y)$"], [1, 2, 1], lr=True)
```
Here, starting from the left, the first arrows represent inputs $x,y$, the left node outputs $z$, the top node $\log(z)$, the bottom node $\exp(z)$ and the final right node $h(x, y)$. Forward computation proceeds left-to-right.
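As a reference point, the derivatives that backpropagation should produce here can also be worked out by hand. This assumes the graph computes $z = x \times y$ and $h(x, y) = \log(z) + \exp(z)$, as the node descriptions suggest:

$$
\begin{aligned}
h'_x(x, y) &= \left(\frac{1}{z} + \exp(z)\right) y \\
h'_y(x, y) &= \left(\frac{1}{z} + \exp(z)\right) x
\end{aligned}
$$

The shared factor $\frac{1}{z} + \exp(z)$ is the derivative with respect to $z$. It is the sum of one contribution from the top node and one from the bottom node.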
The chain rule tells us how to propagate the derivatives. We can apply the rules from the previous section right-to-left until we reach the initial Variables $x, y$, i.e. the leaf Variables.
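We could apply these rules in an arbitrary order, processing each node as soon as a derivative reaches it and aggregating the resulting values. However, this can be quite inefficient: a Variable used by several Functions would have its `backward` called once for every path from the output that reaches it. It is better to wait to call `backward` until we have accumulated all the values we will need.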
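To make the inefficiency concrete, the sketch below counts `backward` calls on a hypothetical graph of $k$ stacked "diamonds". In this graph, each intermediate node is used by two Functions whose outputs are recombined at the next level. The graph encoding and the helper names (`build_diamonds`, `eager_calls`, `accumulated_calls`) are illustrative only.

Calling `backward` as soon as each derivative arrives costs one call per path, which grows exponentially with $k$. Waiting until all contributions are accumulated costs one call per non-leaf node:

```python
def build_diamonds(k):
    """Map each node to the nodes it was computed from.

    t0 is the output; t{i} is computed from a{i} and b{i}, and both a{i}
    and b{i} are computed from t{i+1}. t{k} is the single leaf input.
    """
    parents = {}
    for i in range(k):
        parents[f"t{i}"] = [f"a{i}", f"b{i}"]
        parents[f"a{i}"] = [f"t{i + 1}"]
        parents[f"b{i}"] = [f"t{i + 1}"]
    parents[f"t{k}"] = []
    return parents


def eager_calls(parents, node):
    """Call backward every time a derivative arrives, without waiting."""
    if not parents[node]:
        return 0  # a leaf only stores the derivative it receives
    return 1 + sum(eager_calls(parents, p) for p in parents[node])


def accumulated_calls(parents):
    """Wait for all contributions: one backward call per non-leaf node."""
    return sum(1 for inputs in parents.values() if inputs)


graph = build_diamonds(10)
print(eager_calls(graph, "t0"), accumulated_calls(graph))  # 3069 30
```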
To handle this issue, we will process the nodes in topological order. We first note that our graph is directed and acyclic. Directionality comes from the `backward` function, and the lack of cycles is a consequence of the choice that every Function must create a new Variable.
A topological ordering of a directed acyclic graph is an ordering in which every node appears before all of the nodes it points to. With edges pointing backward (from each output to its inputs), this guarantees that every Variable is processed only after all of the nodes that use it. In our example, the left node cannot be processed before the top or bottom node. The ordering may not be unique: it does not tell us whether to process the top or the bottom node first.
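This constraint can be checked mechanically. The sketch encodes the example graph as a plain dictionary from each node to the nodes it was computed from. The labels such as `"log"` and `"exp"` and the helper `is_valid_backward_order` are hypothetical. The two orders that differ only in the top and bottom nodes both pass, while one that handles $z$ before the bottom node fails:

```python
def is_valid_backward_order(order, parents):
    """True if every node appears before all of the inputs it was computed from."""
    position = {node: i for i, node in enumerate(order)}
    return all(
        position[node] < position[inp]
        for node, inputs in parents.items()
        for inp in inputs
    )


# h = log(z) + exp(z), z = x * y, written as node -> inputs
graph = {"h": ["log", "exp"], "log": ["z"], "exp": ["z"], "z": ["x", "y"]}

print(is_valid_backward_order(["h", "log", "exp", "z", "x", "y"], graph))  # True
print(is_valid_backward_order(["h", "exp", "log", "z", "y", "x"], graph))  # True
print(is_valid_backward_order(["h", "log", "z", "exp", "x", "y"], graph))  # False
```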
There are several easy-to-implement algorithms for topological sorting. As graph algorithms are beyond the scope of this document, we recommend using the depth-first search algorithm described in the pseudocode section of Topological Sorting.
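A minimal depth-first search version might look like the sketch below. It uses the same hypothetical dictionary encoding of the example graph, with each node mapped to its inputs. A node is appended only after all of its inputs have been visited, so reversing that postorder puts every node before its inputs:

```python
def topological_sort(root, parents):
    """Order nodes so that each node comes before all of its inputs.

    `parents` maps a node name to the names of the inputs it was computed from.
    """
    visited = set()
    postorder = []

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for parent in parents.get(node, []):
            visit(parent)
        postorder.append(node)  # appended only after all of its inputs

    visit(root)
    return list(reversed(postorder))


graph = {"h": ["log", "exp"], "log": ["z"], "exp": ["z"], "z": ["x", "y"]}
print(topological_sort("h", graph))  # ['h', 'exp', 'log', 'z', 'y', 'x']
```

Here the bottom node (`exp`) happens to come before the top node (`log`), which is one of the valid choices. Recursion is fine for a small graph. A very deep computation could exceed Python's default recursion limit, in which case an explicit stack is the usual alternative.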
Once we have the order defined, we process each node one at a time, in order. We start at the rightmost node ($h(x, y)$), marked with a red arrow in the graph below. The starting derivative is an argument given to us.
```python
d = backprop(1)
r(d, "figs/Autograd/backprop1.svg")
```
We then process the Function with the chain rule. This calls the `backward` of $+$ and gives the derivative for the two red Variables (which correspond to $\log(z), \exp(z)$ from the forward pass). You need to track these intermediate red derivative values in a dictionary.
```python
backprop(2)
```
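This step can be sketched with a plain dictionary keyed by string labels; the labels and the `add_backward` helper are hypothetical. Since $\partial(a + b)/\partial a = \partial(a + b)/\partial b = 1$, the `backward` of $+$ hands $d_{out}$ unchanged to both inputs, and each result is added into the dictionary:

```python
def add_backward(d_out):
    """Backward of a + b: both partial derivatives are 1."""
    return d_out, d_out


# Derivatives of intermediate Variables, keyed by a label for each Variable.
derivatives = {}

# The starting derivative for the output h(x, y) is passed in; 1.0 here.
d_h = 1.0
d_log, d_exp = add_backward(d_h)
derivatives["log(z)"] = derivatives.get("log(z)", 0.0) + d_log
derivatives["exp(z)"] = derivatives.get("exp(z)", 0.0) + d_exp

print(derivatives)  # {'log(z)': 1.0, 'exp(z)': 1.0}
```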
Let us assume the next Variable in the order is the top node. We have just computed and stored the necessary derivative $d_{out}$, so we can apply the chain rule. This produces a new derivative (corresponding to $z$: left red arrow below) for us to store.
```python
backprop(3)
```
The next Variable in the order is the bottom node. Here we have an interesting result. We have a new arrow, but it corresponds to the same Variable ($z$) that we just computed. It is a useful exercise to show that, as a consequence of the two-argument chain rule, the derivative for this Variable is the sum of each of these derivatives. Practically, this means just adding the new value to the one already in your dictionary.
```python
backprop(4)
```
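The accumulation can be illustrated numerically. The sketch uses an arbitrary value $z = 2$ and sets both stored derivatives to 1. Each arrow's contribution is added into the same dictionary entry for $z$, and the total agrees with a central finite-difference estimate of the derivative of $\log(z) + \exp(z)$, as the two-argument chain rule predicts:

```python
import math

z = 2.0      # arbitrary value of z for illustration
d_log = 1.0  # derivative stored for log(z) after the + step
d_exp = 1.0  # derivative stored for exp(z) after the + step

derivatives = {}
# Top node: d/dz log(z) = 1/z, scaled by its stored d_out.
derivatives["z"] = derivatives.get("z", 0.0) + d_log * (1.0 / z)
# Bottom node: d/dz exp(z) = exp(z). This arrow targets the same Variable z,
# so its contribution is added to the entry that is already there.
derivatives["z"] = derivatives.get("z", 0.0) + d_exp * math.exp(z)


def g(t):
    # log(z) + exp(z) viewed as a function of z alone
    return math.log(t) + math.exp(t)


eps = 1e-6
numeric = (g(z + eps) - g(z - eps)) / (2 * eps)
print(derivatives["z"], numeric)  # both about 7.889
```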
After processing this Variable, the only Variables left in the order are our input leaf Variables.
```python
backprop(5)
```
When we reach the leaf variables in our order, for example $x$, we store the derivative with that variable. Since each step of this process is an application of the chain rule, we can show that this final value is $h'_x(x, y)$. The next and last step is to compute $h'_y(x, y)$.
```python
d = backprop(6)
r(d, "figs/Autograd/backprop6.svg")
```
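A finite-difference comparison is a cheap way to check the leaf derivatives. This sketch assumes the example computes $z = x \times y$ and $h(x, y) = \log(z) + \exp(z)$, with arbitrary inputs $x = 1$, $y = 2$ and a starting derivative of 1. The multiplication node's `backward` scales the accumulated derivative for $z$ by the other input:

```python
import math


def h(x, y):
    z = x * y
    return math.log(z) + math.exp(z)


x, y = 1.0, 2.0  # arbitrary input values for illustration
z = x * y

# Derivative accumulated for z from the top (log) and bottom (exp) nodes,
# assuming a starting derivative of 1.0 for h(x, y).
d_z = 1.0 / z + math.exp(z)

# The left node computed z = x * y, so dz/dx = y and dz/dy = x.
# These products are the derivatives stored on the leaves.
d_x = d_z * y
d_y = d_z * x

# Central finite differences of h as an independent check.
eps = 1e-6
num_x = (h(x + eps, y) - h(x - eps, y)) / (2 * eps)
num_y = (h(x, y + eps) - h(x, y - eps)) / (2 * eps)
print(d_x, num_x)
print(d_y, num_y)
```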
By convention, the variables $x, y$ have their derivatives stored as:

```python
x.derivative, y.derivative
```
As illustrated in the graph for the above example, each of the red arrows represents a constructed derivative that is eventually passed as $d_{out}$ to the chain rule. Starting from the rightmost arrow, whose derivative is passed in as an argument, `backpropagate` should run the following algorithm:
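1. Call topological sort to get an ordered queue of Scalars, starting from the rightmost one.
2. Create a dictionary of Scalars and their current derivatives, initialized with the derivative passed in for the rightmost Scalar.
3. For each Scalar in this backward order, pull it and its completed derivative $d_{out}$ from the dictionary:
   a. if the Scalar is a leaf, add its final derivative (`accumulate_derivative`) and move on to the next Scalar;
   b. if the Scalar is not a leaf,
      1) call `.backprop_step` on the last function with $d_{out}$
      2) loop through all the (Scalar, derivative) pairs produced by the chain rule
      3) accumulate the derivatives for each Scalar in the dictionary

Final note: only leaf Scalars should ever have a non-None `.derivative` value. All intermediate Scalars should only keep their current derivative values in the dictionary. This is a bit annoying, but it follows the behavior of PyTorch.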
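Here is a minimal end-to-end sketch of this procedure. It uses a hypothetical toy `Scalar` class, not minitorch's own. The names `backprop_step`, `accumulate_derivative`, and `derivative` echo the text, but the constructor, the `backward` closure, `is_leaf`, `topological_sort`, and `backpropagate` here are illustrative choices rather than minitorch's API. Leaf Scalars receive `.derivative`, while intermediate ones keep their values only in the local dictionary:

```python
import math


class Scalar:
    """Toy scalar that remembers how it was computed (illustration only)."""

    def __init__(self, value, inputs=(), backward=None):
        self.value = value
        self.inputs = inputs      # Scalars this one was computed from
        self.backward = backward  # maps d_out to one derivative per input
        self.derivative = None    # only ever set on leaf Scalars

    def is_leaf(self):
        return not self.inputs

    def accumulate_derivative(self, d):
        self.derivative = d if self.derivative is None else self.derivative + d

    def backprop_step(self, d_out):
        # Chain rule for the last Function: pair each input with its derivative.
        return list(zip(self.inputs, self.backward(d_out)))

    def __mul__(self, other):
        return Scalar(self.value * other.value, (self, other),
                      lambda d: (d * other.value, d * self.value))

    def __add__(self, other):
        return Scalar(self.value + other.value, (self, other),
                      lambda d: (d, d))

    def log(self):
        return Scalar(math.log(self.value), (self,),
                      lambda d: (d / self.value,))

    def exp(self):
        out = math.exp(self.value)
        return Scalar(out, (self,), lambda d: (d * out,))


def topological_sort(output):
    """Depth-first search; each Scalar comes before the Scalars it was built from."""
    visited, postorder = set(), []

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for inp in node.inputs:
            visit(inp)
        postorder.append(node)

    visit(output)
    return postorder[::-1]


def backpropagate(output, d_output=1.0):
    # Order the Scalars so every user is processed before what it uses.
    order = topological_sort(output)
    # Intermediate derivatives live only in this dictionary.
    derivatives = {output: d_output}
    for node in order:
        d_out = derivatives.get(node, 0.0)
        if node.is_leaf():
            node.accumulate_derivative(d_out)  # leaves keep their result
            continue
        for inp, d in node.backprop_step(d_out):  # apply the chain rule
            derivatives[inp] = derivatives.get(inp, 0.0) + d


x, y = Scalar(1.0), Scalar(2.0)
z = x * y
h = z.log() + z.exp()
backpropagate(h)
print(x.derivative, y.derivative)  # h'_x and h'_y at x=1, y=2
```

Because leaves add into `.derivative` rather than overwrite it, calling `backpropagate` a second time keeps accumulating. This mirrors how PyTorch accumulates gradients on leaf tensors until they are reset.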