Auto-Grad
import chalk
from chalk import hcat, latex, vstrut
from mt_diagrams.drawing import papaya
from mt_diagrams.tensor_draw import left_arrow, matrix
from minitorch import Scalar, tensor
chalk.set_svg_draw_height(300)
Next, we consider autodifferentiation in the tensor framework. We have now moved from scalars and derivatives to vectors, matrices, and tensors, which brings in some extra terminology from multivariate calculus. However, most of what we actually do requires neither complicated terminology nor much technical math. In fact, apart from some name changes, we have already built almost everything we need in Module 1.
The key idea is that, just as we had Scalar and ScalarFunction, we need to construct Tensor and TensorFunction (which we just call Function).
These new objects behave very similarly to their counterparts:
a) Tensors cannot be operated on directly, but need to be transformed through a function.
b) Functions must implement both forward and backward.
c) These transformations are tracked, which allows backpropagation through the chain rule.
All of this machinery should work out of the box.
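As a rough illustration of points (a) through (c), the sketch below records each function application on a list and then replays the list in reverse. It uses plain Python lists in place of tensors, and Context, Square, Neg, and apply are hypothetical names chosen for this example; they are not minitorch's actual classes.

```python
class Context:
    """Hypothetical holder for values that backward will need later."""

    def __init__(self):
        self.saved = ()

    def save_for_backward(self, *values):
        self.saved = values


class Square:
    """Hypothetical function: squares every element of a list."""

    @staticmethod
    def forward(ctx, xs):
        ctx.save_for_backward(xs)
        return [x * x for x in xs]

    @staticmethod
    def backward(ctx, grad_out):
        (xs,) = ctx.saved
        # d/dx x^2 = 2x, multiplied by the incoming derivative per position.
        return [2.0 * x * g for x, g in zip(xs, grad_out)]


class Neg:
    """Hypothetical function: negates every element of a list."""

    @staticmethod
    def forward(ctx, xs):
        return [-x for x in xs]

    @staticmethod
    def backward(ctx, grad_out):
        # d/dx (-x) = -1 at every position.
        return [-g for g in grad_out]


def apply(fn, xs, tape):
    """Run fn.forward and record the call so it can be replayed backward."""
    ctx = Context()
    out = fn.forward(ctx, xs)
    tape.append((fn, ctx))
    return out


tape = []
hidden = apply(Square, [1.0, 2.0, 3.0], tape)
out = apply(Neg, hidden, tape)

# Backpropagate: start from a gradient of ones and apply the chain rule
# by visiting the recorded functions in reverse order.
grad = [1.0, 1.0, 1.0]
for fn, ctx in reversed(tape):
    grad = fn.backward(ctx, grad)

print(out)   # [-1.0, -4.0, -9.0]
print(grad)  # [-2.0, -4.0, -6.0]
```

The result matches the hand derivative of $-x^2$, which is $-2x$ at each position.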
The main new terminology to know is gradient. Just as a tensor is a multidimensional array of scalars, a gradient is a multidimensional array of derivatives for these scalars. Consider the following code:
Scalar auto-derivative notation
def f(a, b, c):
    return a + b + c
a, b, c = Scalar(1), Scalar(2), Scalar(3)
out = f(a, b, c)
out.backward()
a.derivative, b.derivative, c.derivative
(1.0, 1.0, 1.0)
Tensor auto-gradient notation
tensor1 = tensor([1, 2, 3])
out = tensor1.sum()
out.backward()
# shape (3,)
tensor1.grad
[1.00 1.00 1.00]
tensor1.grad.shape
(3,)
The gradient of tensor1 is a tensor that holds the derivatives of each of its elements. Another place that gradients come into play is that backward no longer takes $d_{out}$ as an argument, but now takes $grad_{out}$, which is just a tensor consisting of all the $d_{out}$ values.
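One way to convince yourself that a gradient is nothing more than one ordinary derivative per element is to estimate each derivative separately with finite differences. The sketch below does this for the sum example on plain Python lists; numerical_grad is a hypothetical helper written for this illustration, not part of minitorch.

```python
def numerical_grad(f, xs, eps=1e-6):
    """Estimate df/dx_i for every position i with a central difference."""
    grad = []
    for i in range(len(xs)):
        up = list(xs)
        down = list(xs)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


# Same shape as the input: one derivative per element.
grad = numerical_grad(sum, [1.0, 2.0, 3.0])
print([round(g, 4) for g in grad])  # [1.0, 1.0, 1.0]
```

Each entry is computed exactly as a univariate derivative would be, holding every other element fixed.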
Note: You will find lots of different notation for gradients and multivariate terminology. For this Module, you are supposed to ignore it and stick to everything you know about derivatives. It turns out that you can do most of machine learning without ever thinking in higher dimensions.
If you think about gradient and $grad_{out}$ in this way (i.e. tensors of derivatives and $d_{out}$), then you can see how we can easily compute the gradient for tensor operations using univariate rules.
- map. Given a tensor, map applies a univariate operation to each scalar position individually. For a scalar $x$, consider computing $g(x)$. From Module 1, we know that the derivative of $f(g(x))$ is equal to $g'(x) \times d_{out}$, where $d_{out}$ is the derivative $f'(g(x))$ passed back from $f$. To compute the gradient in backward, we only need to compute the derivative for each scalar position and then apply a mul map.
def label(t, d):
    return vstrut(0.5) // latex(t) // vstrut(0.5) // d
opts = chalk.ArrowOpts(arc_height=-0.5)
opts2 = chalk.ArrowOpts(arc_height=-0.2)
d = hcat(
[
label("$f'_x(g(x))$", matrix(3, 2, "a")),
left_arrow,
label("$g'(x)$", matrix(3, 2, "b")),
label("$d_{\\text{out}}$", matrix(3, 2, "g", colormap=lambda i, j: papaya)),
],
1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts2).connect(
("b", 1, 0), ("a", 1, 0), opts2
).connect(("g", 0, 0), ("b", 0, 0), opts).connect(("g", 1, 0), ("b", 1, 0), opts)
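The map rule can be sketched on plain Python lists as two passes: compute $g'(x)$ at every position, then multiply by the matching entry of $grad_{out}$. The name map_backward and the choice of $g = \log$ (so $g'(x) = 1/x$) are assumptions made for this example only.

```python
def map_backward(dg, xs, grad_out):
    """Gradient of an elementwise g, given its derivative dg."""
    # Step 1: map the derivative g' over every position.
    local = [dg(x) for x in xs]
    # Step 2: multiply position by position with the incoming grad_out.
    return [d * g for d, g in zip(local, grad_out)]


xs = [1.0, 2.0, 4.0]
grad_out = [1.0, 0.5, 2.0]

# For g(x) = log(x), the derivative is 1 / x.
grad_in = map_backward(lambda x: 1.0 / x, xs, grad_out)
print(grad_in)  # [1.0, 0.25, 0.5]
```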
- zip. Given two tensors, zip applies a binary operation to each pair of scalars. For two scalars $x$ and $y$, consider computing $g(x, y)$. From Module 1, we know that the derivative of $f(g(x, y))$ is equal to $g_x'(x, y) \times d_{out}$ and $g_y'(x, y) \times d_{out}$. Thus to compute the gradient, we only need to compute the derivative for each scalar position and apply a mul map.
d = hcat(
[
matrix(3, 2, "a1"),
matrix(3, 2, "a"),
left_arrow,
matrix(3, 2, "b"),
matrix(3, 2, "g", colormap=lambda i, j: papaya),
],
1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts).connect(
("b", 1, 0), ("a", 1, 0), opts
).connect(("g", 0, 0), ("b", 0, 0), opts)
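For zip, the same idea produces two gradients, one per input. The sketch below assumes plain Python lists of equal length and uses $g(x, y) = x \times y$, for which $g_x' = y$ and $g_y' = x$; zip_backward is a hypothetical name for this illustration.

```python
def zip_backward(dg_dx, dg_dy, xs, ys, grad_out):
    """Gradients of an elementwise binary g with respect to both inputs."""
    grad_x = [dg_dx(x, y) * g for x, y, g in zip(xs, ys, grad_out)]
    grad_y = [dg_dy(x, y) * g for x, y, g in zip(xs, ys, grad_out)]
    return grad_x, grad_y


xs = [1.0, 2.0, 3.0]
ys = [4.0, 5.0, 6.0]
grad_out = [1.0, 1.0, 2.0]

# For g(x, y) = x * y: dg/dx = y and dg/dy = x.
grad_x, grad_y = zip_backward(lambda x, y: y, lambda x, y: x, xs, ys, grad_out)
print(grad_x)  # [4.0, 5.0, 12.0]
print(grad_y)  # [1.0, 2.0, 6.0]
```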
- reduce. Given a tensor, reduce applies an aggregation operation to one dimension. For simplicity, let's consider sum-based reductions. For scalars $x_1$ to $x_n$, consider computing $x_1 + x_2 + \ldots + x_n$. For any $x_i$ value, the derivative is 1. Therefore, the derivative for any position computed in backward is simply $d_{out}$. This means to compute the gradient, we only need to send $d_{out}$ to each position. (For other reduce operations such as product, you get different expansions, which can be calculated just by taking derivatives).
d = hcat(
[
matrix(3, 2, "a"),
left_arrow,
matrix(3, 2, "b"),
matrix(3, 2, "g", colormap=lambda i, j: papaya),
],
1,
)
d.connect(("b", 0, 0), ("a", 0, 0)).connect(("b", 1, 0), ("a", 1, 0)).connect(
("g", 0, 0), ("b", 0, 0)
)
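The reduce rule can be sketched for a 3 x 2 nested list reduced along its last dimension. For a sum, each row's $d_{out}$ is simply copied to every position in that row; for a product, the derivative at position $i$ is the product of the other entries in the row. Both sum_backward and prod_backward are hypothetical helpers written for this example.

```python
def sum_backward(rows, grad_out):
    """Send each row's d_out to every position of that row."""
    return [[g for _ in row] for row, g in zip(rows, grad_out)]


def prod_backward(rows, grad_out):
    """d(prod)/dx_i is the product of all other entries in the row."""
    grads = []
    for row, g in zip(rows, grad_out):
        row_grad = []
        for i in range(len(row)):
            others = 1.0
            for j, x in enumerate(row):
                if j != i:
                    others *= x
            row_grad.append(others * g)
        grads.append(row_grad)
    return grads


rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
grad_out = [1.0, 2.0, 3.0]  # one d_out per reduced row

sum_grad = sum_backward(rows, grad_out)
prod_grad = prod_backward(rows, grad_out)
print(sum_grad)   # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(prod_grad)  # [[2.0, 1.0], [8.0, 6.0], [18.0, 15.0]]
```

In both cases the gradient has the shape of the input, not of the reduced output.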