====================
Autodifferentiation
====================

.. jupyter-execute::
   :hide-code:

   import minitorch
   from minitorch import ScalarFunction
   import sys
   sys.path.append("../project/")
   from datasets_plotly import plot_function

In :doc:`scalar`, we saw that :class:`ScalarFunction` can be used to
implement Functions on Variables. We do this by calling `apply` on
Variables, which in turn calls the underlying `forward` function. (For
notation, we still use capital letters to refer to our new Functions and
Variables to distinguish them from Python functions and variables, and
depending on the context we also use "functions" to refer to
mathematical functions.) Now we include additional information on the
class that gives the derivative of the individual function.

The trick behind autodifferentiation is to use a `chain` of function
calls to compute a derivative. Just as `forward` computes the function
:math:`f(x)`, we need a `backward` method to provide this local
derivative information.

Backward
-----------------------------

For every Function, we need to provide a `backward` method that computes
its derivative information. Specifically, `backward` computes
:math:`f'(x) \times d_{out}`, where :math:`d_{out}` is an argument
passed in (discussed below).

For the simple function :math:`f(x) = x \times 5`, we can consult our
derivative rules and get :math:`f'(x) = 5`. Therefore the `backward` is

.. jupyter-execute::

   class TimesFive(ScalarFunction):
       @staticmethod
       def forward(ctx, x):
           return x * 5

       @staticmethod
       def backward(ctx, d_out):
           f_prime = 5
           return f_prime * d_out

.. image:: figs/Autograd/autograd3.png
   :align: center

For functions that take multiple arguments, `backward` returns multiple
derivatives, one with respect to each input argument. For example, if
the function computes :math:`f(x, y)`, we need to return
:math:`f'_x(x, y)` and :math:`f'_y(x, y)`:

.. jupyter-execute::

   class Mul(ScalarFunction):
       @staticmethod
       def forward(ctx, x, y):
           # Compute f(x, y)
           pass

       @staticmethod
       def backward(ctx, d_out):
           # Compute f'_x(x, y) * d_out, f'_y(x, y) * d_out
           pass

.. image:: figs/Autograd/autograd4.png
   :align: center

For :math:`f(x, y) = x + 2 \times y`, we can consult our derivative
rules again and get :math:`f'_x(x, y) = 1` and :math:`f'_y(x, y) = 2`.
Therefore the `backward` is

.. jupyter-execute::

   class AddTimes2(ScalarFunction):
       @staticmethod
       def forward(ctx, x, y):
           return x + 2 * y

       @staticmethod
       def backward(ctx, d_out):
           return d_out, 2 * d_out

.. jupyter-execute::

   def call(y):
       return AddTimes2.forward(None, 10, y)

   plot_function("Add Times 2: x=10", call)

.. jupyter-execute::

   def d_call(y):
       return AddTimes2.backward(None, 1)[1]

   plot_function("d Add Times 2: x=10", d_call)

Note that `backward` works a bit differently from the mathematical
notation. Sometimes the derivative :math:`f'(x)` depends directly on
:math:`x`; however, `backward` does not take :math:`x` as an argument.
This is not a problem for the examples above, whose derivatives are
constants, but things get a bit more interesting when the derivative
depends on :math:`x` itself. This is where the context argument `ctx`
comes in.

Consider a function `Square`, :math:`f(x) = x^2`, that squares :math:`x`
and has derivative :math:`f'(x) = 2 \times x`. We write it in code as

.. jupyter-execute::

   class Square(ScalarFunction):
       @staticmethod
       def forward(ctx, x):
           ctx.save_for_backward(x)
           return x * x

       @staticmethod
       def backward(ctx, d_out):
           x = ctx.saved_values
           f_prime = 2 * x
           return f_prime * d_out
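As a quick sanity check (a minimal sketch; the input value 3.0 is just
for illustration), we can build a `Context` by hand, run `forward` so
that :math:`x` is saved, and confirm that `backward` returns
:math:`2 \times x \times d_{out}`:

.. jupyter-execute::

   ctx = minitorch.Context()
   out = Square.forward(ctx, 3.0)      # f(3) = 9.0; saves x = 3.0 in ctx
   d_x = Square.backward(ctx, 1.0)     # f'(3) * d_out = 2 * 3.0 * 1.0 = 6.0
   print(out, d_x)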
.. jupyter-execute::

   def call(x):
       # You don't need to do this manually.
       # It gets called automatically in `apply`.
       ctx = minitorch.Context()
       return Square.forward(ctx, x)

   plot_function("Square", call)

.. jupyter-execute::

   def d_call(x):
       ctx = minitorch.Context()
       Square.forward(ctx, x)
       return Square.backward(ctx, 1)

   plot_function("d Square", d_call)

This type of function requires us to explicitly save, in `forward`,
anything we might need in `backward`. This is a code optimization that
limits the amount of storage required by the computation.

Chain Rule
-------------

.. note::
   This section discusses the implementation of the chain rule for
   univariate differentiation. Before reading, review the mathematical
   definition of the `Chain Rule <https://en.wikipedia.org/wiki/Chain_rule>`_.

The section above gives the formula for running `backward` on one
function, but what if we need to run `backward` on two functions in
sequence?

.. image:: figs/Autograd/chain1.png
   :align: center

We can do so using the univariate chain rule:

.. math::

   f'_x(g(x)) = g'(x) \times f'_{g(x)}(g(x))

If the notation gets a bit hard to follow here, naming each part may
make it easier to understand:

.. math::

   \begin{eqnarray*}
   y &=& g(x) \\
   d_{out} &=& f'(y) \\
   f'_x(g(x)) &=& g'(x) \times d_{out} \\
   \end{eqnarray*}

The derivative rule above tells us to compute the derivative of the
first function (:math:`g`), and then multiply it by the derivative of
the second function (:math:`f`) with respect to the output of the first
function. Here is where the perspective of thinking of functions as
boxes comes in handy:

.. image:: figs/Autograd/chain2.png
   :align: center

It shows that the :math:`d_{out}` multiplier passed to `backward` of
the first box (left) should be the value returned by `backward` of the
second box.

A similar approach works for two-argument functions:

.. math::

   \begin{eqnarray*}
   f'_x(g(x, y)) &=& g_x'(x, y) \times f'_{g(x, y)}(g(x, y)) \\
   f'_y(g(x, y)) &=& g_y'(x, y) \times f'_{g(x, y)}(g(x, y))
   \end{eqnarray*}

Or, with each part named:

.. math::

   \begin{eqnarray*}
   z &=& g(x, y) \\
   d_{out} &=& f'(z) \\
   f'_x(g(x, y)) &=& g_x'(x, y) \times d_{out} \\
   f'_y(g(x, y)) &=& g_y'(x, y) \times d_{out}
   \end{eqnarray*}

This shows that the second box (:math:`f`) does not need to care how
many arguments the first box (:math:`g`) has: passing back
:math:`d_{out}` is enough for the chain rule to work.
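To make the box picture concrete, here is a minimal sketch that chains
the earlier `AddTimes2` and `Square` Functions by hand, mirroring what
the autodifferentiation machinery will do for us: run the `forward`
calls left to right, then run the `backward` calls in reverse, passing
each result along as :math:`d_{out}` (the input values 3.0 and 2.0 are
just for illustration).

.. jupyter-execute::

   # Compose f(g(x, y)) with g = AddTimes2 and f = Square,
   # so f(g(x, y)) = (x + 2y)^2.
   x, y = 3.0, 2.0

   ctx_g = minitorch.Context()
   z = AddTimes2.forward(ctx_g, x, y)        # g(x, y) = x + 2y = 7.0

   ctx_f = minitorch.Context()
   out = Square.forward(ctx_f, z)            # f(z) = z * z = 49.0

   # Backward in reverse order: the value returned by the second box
   # becomes the d_out passed to the first box.
   d_out = Square.backward(ctx_f, 1.0)       # f'(z) * 1 = 2z = 14.0
   d_x, d_y = AddTimes2.backward(ctx_g, d_out)
   print(d_x, d_y)                           # 1 * 14.0 = 14.0, 2 * 14.0 = 28.0

These values match the derivatives of :math:`(x + 2y)^2` taken
directly, namely :math:`2(x + 2y)` and :math:`4(x + 2y)`.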