Scalar
import chalk
from IPython.display import SVG
from mt_diagrams.autodiff_draw import draw_boxes
from mt_diagrams.show_expression import make_graph
import minitorch
from minitorch import Context, ScalarFunction
chalk.set_svg_draw_height(300)
In the last section, we discussed two ways to compute derivatives. Symbolic derivatives require access to the full symbolic function, whereas numerical derivatives require only a black-box function. The first is precise but rigid; the second is imprecise but more flexible. This module introduces a third approach, known as autodifferentiation, which is a tradeoff between the symbolic and numerical methods.
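The following sketch makes the precision tradeoff concrete. It is plain Python, not part of MiniTorch. It compares a derivative worked out by hand with a central-difference estimate that only evaluates the function as a black box. The function $f(x) = x^3$, the helper name estimate_derivative, and the step size epsilon = 1e-6 are arbitrary choices for illustration.

```python
from typing import Callable


def f(x: float) -> float:
    # The function under study: f(x) = x^3.
    return x ** 3


def f_prime_symbolic(x: float) -> float:
    # Derivative worked out by hand from the formula: f'(x) = 3x^2.
    return 3 * x ** 2


def estimate_derivative(fn: Callable[[float], float], x: float, epsilon: float = 1e-6) -> float:
    # Hypothetical helper: a central difference that treats fn as a black box
    # and only evaluates it at two points on either side of x.
    return (fn(x + epsilon) - fn(x - epsilon)) / (2 * epsilon)


exact = f_prime_symbolic(2.0)
approx = estimate_derivative(f, 2.0)
print(exact, approx, abs(exact - approx))
```

The estimate lands close to the exact value 12, but generally not exactly on it. The symbolic version is exact, but only because we had the formula for $f$ in hand.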
Autodifferentiation works by collecting information about the computation path used within the function, and then transforming this information into a procedure for computing derivatives. Unlike the black-box method, autodifferentiation will allow us to use this information to compute each step more precisely.
However, to collect information about the computation path, we need to track the internal computation of the function. This is hard to do because Python does not directly expose how a function uses its inputs: all we get is the output. This page describes one method for tracking computation.
Overriding Numbers
Since we do not have access to the underlying language interpreter, we are going to build a system that tracks the mathematical operations applied to each number:
- Replace all numbers with a proxy class, which we will call Scalar.
- Replace all mathematical functions with proxy operators.
- Remember what operators were applied to each Scalar.
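The three steps above can be sketched in a few lines of plain Python. The Proxy class and the proxy_add and proxy_mul functions are hypothetical stand-ins, not MiniTorch's API. Each proxy operator computes on the raw floats and records which operation produced the result, along with its inputs.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Proxy:
    # Hypothetical stand-in for Scalar: a wrapped float plus a record of
    # which operation produced it and from which inputs (None for leaves).
    value: float
    op: Optional[str] = None
    inputs: Tuple["Proxy", ...] = ()


def proxy_add(a: Proxy, b: Proxy) -> Proxy:
    # Proxy operator: compute on the raw floats, then remember what happened.
    return Proxy(a.value + b.value, op="add", inputs=(a, b))


def proxy_mul(a: Proxy, b: Proxy) -> Proxy:
    return Proxy(a.value * b.value, op="mul", inputs=(a, b))


p1 = Proxy(10.0)
p2 = Proxy(30.0)
p_sum = proxy_add(p1, p2)
print(p_sum.value, p_sum.op, [i.value for i in p_sum.inputs])  # 40.0 add [10.0, 30.0]
```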
The following code shows the result of this approach.
x1 = minitorch.Scalar(10)
x2 = minitorch.Scalar(30)
y = x1 + x2
y.history
ScalarHistory(last_fn=<class 'minitorch.scalar_functions.Add'>, ctx=Context(no_grad=False, saved_values=()), inputs=[Scalar(10.000000), Scalar(30.000000)])
A Scalar should behave exactly like a number. The goal is that the user cannot tell the difference, while we use the extra information to implement the operations we need.
Functions
When working with these new numbers, we restrict ourselves to a small set of mathematical functions $f$ of one or two arguments. Graphically, we will think of functions as little boxes. For example, a one-argument function looks like this:
draw_boxes(["$x$", "$f(x)$"], [1])
Internally, the box unwraps the content of $x$, manipulates it, and returns a new value with the saved history. We can chain together two of these functions to produce more complex functions.
draw_boxes(["$x$", "$g(x)$", "$f(g(x))$"], [1, 1])
Similarly, a two-argument function unwraps the content of both inputs $x$ and $y$, manipulates them, and returns a new wrapped version:
draw_boxes([("$x$", "$y$"), "$f(x, y)$"], [1])
Finally, we can create more complex functions that chain these together in various ways.
draw_boxes([("", ""), ("", ""), ("", ""), ""], [1, 2, 1], lr=True)
Implementation
We will implement tracking using the Scalar class. It wraps a single number (stored as an attribute) together with its history.
x = minitorch.Scalar(10)
To implement functions, there is a corresponding class, ScalarFunction. We will need to reimplement each mathematical function we would like to use by inheriting from this class.
For example, say our function is Neg, $g(x) = -x$:
class Neg(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        # Negate the unwrapped input.
        return -x
Or, say the function is Mul, $f(x, y) = x \times y$, which multiplies $x$ by $y$:
class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float, y: float) -> float:
        # Multiply the two unwrapped inputs.
        return x * y
Within the forward function, $x$ and $y$ are always unwrapped numbers: forward takes plain floats and returns a plain float.
If we have a scalar $x$, we can apply the functions above like this:
z = Neg.apply(x)
out = Neg.apply(z)
# or
out2 = Mul.apply(x, z)
Note that we do not call forward directly; instead we call apply. Internally, apply unwraps the inputs to standard numbers, calls forward on them, and then wraps the output float in a new Scalar together with the history it needs.
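Here is a minimal sketch of the unwrap, forward, wrap cycle that apply has to perform. The Mini* classes are hypothetical simplifications written for illustration. They are not MiniTorch's actual Scalar, ScalarHistory, Context, or ScalarFunction. In particular, the context here stores nothing.

```python
from dataclasses import dataclass
from typing import List, Optional, Type


class MiniContext:
    # Hypothetical placeholder; a real context would also hold values
    # saved during forward for later use.
    pass


@dataclass
class MiniHistory:
    # Which function produced a value, its context, and its inputs.
    last_fn: Optional[Type["MiniFunction"]] = None
    ctx: Optional[MiniContext] = None
    inputs: Optional[List["MiniScalar"]] = None


@dataclass
class MiniScalar:
    # A wrapped float; values created directly by the user have no history.
    data: float
    history: Optional[MiniHistory] = None


class MiniFunction:
    @staticmethod
    def forward(ctx: MiniContext, *vals: float) -> float:
        raise NotImplementedError

    @classmethod
    def apply(cls, *vals: MiniScalar) -> MiniScalar:
        # 1. Unwrap: forward only ever sees plain floats.
        raw = [v.data for v in vals]
        ctx = MiniContext()
        # 2. Run the subclass's forward on the unwrapped values.
        result = cls.forward(ctx, *raw)
        # 3. Wrap the float result together with how it was produced.
        history = MiniHistory(last_fn=cls, ctx=ctx, inputs=list(vals))
        return MiniScalar(result, history)


class MiniNeg(MiniFunction):
    @staticmethod
    def forward(ctx: MiniContext, x: float) -> float:
        return -x


class MiniMul(MiniFunction):
    @staticmethod
    def forward(ctx: MiniContext, x: float, y: float) -> float:
        return x * y


a = MiniScalar(10.0)
b = MiniNeg.apply(a)
c = MiniNeg.apply(b)
d = MiniMul.apply(a, b)
print(c.data, c.history.last_fn.__name__)  # 10.0 MiniNeg
print(d.data)  # -100.0
```

Because forward never sees a wrapped value, each function subclass only describes its arithmetic. All of the bookkeeping lives in apply.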
print(out.history)
ScalarHistory(last_fn=<class '__main__.Neg'>, ctx=Context(no_grad=False, saved_values=()), inputs=[Scalar(-10.000000)])
Here out has remembered the graph that led to its creation.
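Every result keeps references to its inputs, so the whole graph can be recovered by walking backward from the output. The sketch below uses a hypothetical Node record, built by hand to mirror out = Neg(Neg(x)) and out2 = Mul(x, Neg(x)) from above. It lists the operations so that every node appears after its inputs, which is a depth-first topological sort. This illustrates the idea and is not MiniTorch's own traversal code.

```python
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple


@dataclass(eq=False)
class Node:
    # Hypothetical graph node: a value, the name of the function that made it
    # (None for a leaf created by the user), and the input nodes.
    value: float
    fn_name: Optional[str] = None
    inputs: Tuple["Node", ...] = ()


def topological_order(out: Node) -> List[Node]:
    # Depth-first search from the output; each node is emitted only after
    # all of its inputs, so leaves come first and `out` comes last.
    seen: Set[int] = set()
    order: List[Node] = []

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node)

    visit(out)
    return order


leaf = Node(10.0)
neg1 = Node(-10.0, "Neg", (leaf,))
neg2 = Node(10.0, "Neg", (neg1,))
prod = Node(-100.0, "Mul", (leaf, neg1))
print([n.fn_name or "leaf" for n in topological_order(neg2)])  # ['leaf', 'Neg', 'Neg']
print([n.fn_name or "leaf" for n in topological_order(prod)])  # ['leaf', 'Neg', 'Mul']
```

The shared leaf in prod is visited only once, even though two paths lead to it.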
draw_boxes(["$x$", "$g(x)$", "$f(g(x))$"], [1, 1])
Minitorch includes a library to allow you to draw these box diagrams for arbitrarily complex functions.
out.name = "out"
SVG(make_graph(out, lr=True))
out2.name = "out"
SVG(make_graph(out2, lr=True))
Operators
There is still one minor issue. This is what our code looks like when we use Mul:
out2 = Mul.apply(x, y)
It is a bit annoying to write code this way. Also, we promised that a Scalar would behave just like a number, which means it should work with the standard Python operators we are used to writing.
To get around this issue, we need to augment the Scalar class so that it behaves normally under standard mathematical operations. Instead of calling regular multiplication, Python will call our __mul__ method. Once this is achieved, we will be able to record and track how $x$ is used in the function while still being able to write
out2 = x * y
To achieve this, the class needs to implement Python's special methods so that it behaves like a number when used in expressions. The "Emulating numeric types" section of the Python data model documentation describes how this can be done.
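The sketch below shows this idea. It assumes a hypothetical Num class in place of Scalar, and hypothetical MulFn, AddFn, and NegFn classes in place of ScalarFunction subclasses. The ctx argument is dropped for brevity. The special methods __mul__, __add__, __neg__, and the reflected __rmul__ and __radd__ are the standard hooks from the Python data model. Here each one routes the operation through apply, so every result records its history.

```python
from typing import Optional, Tuple, Type, Union


class Fn:
    # Hypothetical stand-in for ScalarFunction (ctx omitted for brevity):
    # apply unwraps the inputs, calls forward, and wraps the result with history.
    @staticmethod
    def forward(*vals: float) -> float:
        raise NotImplementedError

    @classmethod
    def apply(cls, *vals: "Num") -> "Num":
        out = cls.forward(*(v.data for v in vals))
        return Num(out, last_fn=cls, inputs=vals)


class MulFn(Fn):
    @staticmethod
    def forward(x: float, y: float) -> float:
        return x * y


class AddFn(Fn):
    @staticmethod
    def forward(x: float, y: float) -> float:
        return x + y


class NegFn(Fn):
    @staticmethod
    def forward(x: float) -> float:
        return -x


def as_num(v: Union["Num", float]) -> "Num":
    # Lets plain Python numbers mix with proxies, e.g. 2 * a.
    return v if isinstance(v, Num) else Num(float(v))


class Num:
    # Hypothetical Scalar-like proxy whose special methods send every
    # arithmetic operator through Fn.apply.
    def __init__(self, data: float, last_fn: Optional[Type[Fn]] = None,
                 inputs: Tuple["Num", ...] = ()):
        self.data = data
        self.last_fn = last_fn
        self.inputs = inputs

    def __mul__(self, other: Union["Num", float]) -> "Num":
        return MulFn.apply(self, as_num(other))

    def __rmul__(self, other: float) -> "Num":
        # Called for `other * self` when `other` is a plain number.
        return MulFn.apply(as_num(other), self)

    def __add__(self, other: Union["Num", float]) -> "Num":
        return AddFn.apply(self, as_num(other))

    def __radd__(self, other: float) -> "Num":
        return AddFn.apply(as_num(other), self)

    def __neg__(self) -> "Num":
        return NegFn.apply(self)

    def __repr__(self) -> str:
        return f"Num({self.data})"


a = Num(10.0)
b = Num(30.0)
c = a * b + 2 * a  # ordinary syntax, but every step is recorded
print(c, c.last_fn.__name__)  # Num(320.0) AddFn
```

The reflected methods matter for expressions such as 2 * a, where the left operand is a plain int. Python first tries int.__mul__, which returns NotImplemented, and then falls back to Num.__rmul__.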