Module 1.3 - Backprop

Functions

  • Function \(f(x) = x \times 5\)

  • Implementation

    class TimesFive(ScalarFunction):
    
        @staticmethod
        def forward(ctx, x):
            return x * 5

  • \(x\) is unwrapped (a plain Python number) and the return value is also a number

Multi-arg Functions

Code

class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx, x, y):
        return x * y

[Figure: autograd2.png]

Context

Any values that backward needs must be saved in the context during forward.

class Square(ScalarFunction):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # stash x for use in backward
        return x * x

    @staticmethod
    def backward(ctx, d_out):
        x = ctx.saved_values       # retrieve the saved input
        f_prime = 2 * x            # local derivative of x * x
        return f_prime * d_out     # chain rule: scale by the incoming derivative
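
If Mul from the earlier slide also needs a backward, both arguments must be saved; a possible sketch (the exact minitorch version may differ):

class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)  # backward needs both inputs
        return x * y

    @staticmethod
    def backward(ctx, d_out):
        x, y = ctx.saved_values
        # d/dx (x*y) = y and d/dy (x*y) = x, each scaled by d_out
        return y * d_out, x * d_out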

Picture

[Figure: autograd2.png] [Figure: autograd3.png]

Lecture Quiz

Outline

  • Variables and Functions

  • Backward

  • Chain Rule

How do we get derivatives?

  • Base case: compute derivatives for single functions

  • Inductive case: define how to propagate a derivative

Chain Rule

Python Details

  • Use apply for the above Functions

    x = minitorch.Scalar(10.)
    z = TimesFive.apply(x)
    out = TimesFive.apply(z)

  • apply unwraps the Scalars, calls forward, and wraps the result in a new Scalar (sketched below)
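
Roughly what apply does internally; a sketch with assumed helper names (Context, ScalarHistory), not minitorch's exact code:

class ScalarFunction:
    @classmethod
    def apply(cls, *vals):
        ctx = Context()                  # fresh context for this call
        # unwrap: pull the raw Python number out of any Scalar argument
        raw = [v.data if isinstance(v, Scalar) else v for v in vals]
        out = cls.forward(ctx, *raw)     # call: forward sees plain numbers
        # wrap: the result records the Function, context, and inputs
        return Scalar(out, ScalarHistory(cls, ctx, vals))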

Chaining Boxes

Chaining

x = minitorch.Scalar(10., name="x")
g_x = G.apply(x)
f_g_x = F.apply(g_x)

[Figure: chain1.png]

Chain Rule

  • Compute derivative from chain

\[f'_x(g(x)) = g'(x) \times f'_{g(x)}(g(x))\]

[Figure: chain2.png]

Chain Rule

\[\begin{split}\begin{eqnarray*} z &=& g(x) \\ d_{out} &=& f'(z) \\ f'_x(g(x)) &=& g'(x) \times d_{out} \\ \end{eqnarray*}\end{split}\]

[Figure: chain2.png]

Example: Chain Rule

Chaining

\[\log(x)^2\]

[Figure: chain1.png]

\[\begin{split}f(z) = z^2\\ g(x) = \log(x)\\\end{split}\]

Example: Chain Rule

\[\begin{split}f'(z) = 2z \\ g'(x) = 1 / x\end{split}\]

[Figure: chain2.png]

What is the combination?

\[f'_x(g(x)) = g'(x) \times f'_{g(x)}(g(x)) = \frac{1}{x} \times 2 \log(x) = \frac{2 \log(x)}{x}\]
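
A quick numeric sanity check of that combination in plain Python (not part of minitorch):

import math

def f_of_g(x):
    return math.log(x) ** 2

x = 2.0
analytic = 2 * math.log(x) / x    # from the chain rule above
eps = 1e-6
numeric = (f_of_g(x + eps) - f_of_g(x - eps)) / (2 * eps)
print(analytic, numeric)          # agree to roughly 1e-9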

Two Arguments: Chain

\[f(g(x, y))\]

Two Arguments: Chain

\[\begin{split} \begin{eqnarray*} f'_x(g(x, y)) &=& g_x'(x, y) \times f'_{g(x, y)}(g(x, y)) \\ f'_y(g(x, y)) &=& g_y'(x, y) \times f'_{g(x, y)}(g(x, y)) \end{eqnarray*}\end{split}\]

[Figure: chain3.png]

Two Arguments: Chain

\[\begin{split}\begin{eqnarray*} z &=& g(x, y) \\ d_{out} &=& f'(z) \\ f'_x(g(x, y)) &=& g_x'(x, y) \times d_{out} \\ f'_y(g(x, y)) &=& g_y'(x, y) \times d_{out} \end{eqnarray*}\end{split}\]

[Figure: chain3.png]

Example: Chain Rule

Chaining

\[(x * y)^2\]

[Figure: chain3.png]

\[\begin{split}f(z) = z^2\\ g(x, y) = (x * y)\\\end{split}\]

Example: Chain Rule

\[\begin{split}f'(z) = 2z\\ g'_x(x, y) = y \\ g'_y(x, y) = x \\\end{split}\]

What is the combination?

\[\begin{split}f'_x(g(x, y)) = 2 z y\\ f'_y(g(x, y)) = 2 z x \\\end{split}\]
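
With Mul-style Functions in place, this can be checked end to end; a sketch assuming Scalar supports * and backward():

x = minitorch.Scalar(3.0)
y = minitorch.Scalar(4.0)
z = x * y            # z = 12
out = z * z          # (x * y)^2
out.backward()
# expect 2*z*y = 96 for x and 2*z*x = 72 for y
print(x.derivative, y.derivative)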

Multivariable Chain

\[f(g(x), h(x))\]

Multivariable Chain

\[\begin{split}\begin{eqnarray*} z_1 &=& g(x) \\ z_2 &=& h(x) \\ f'_x(g(x), h(x)) &=& g'(x) \times f'_{z_1}(z_1, z_2) + h'(x) \times f'_{z_2}(z_1, z_2) \end{eqnarray*}\end{split}\]
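
A numeric check of this sum rule with concrete (assumed) choices \(f(z_1, z_2) = z_1 z_2\), \(g(x) = x^2\), \(h(x) = \sin x\):

import math

x = 1.5
z1, z2 = x ** 2, math.sin(x)               # g(x), h(x)
# f(z1, z2) = z1 * z2, so f'_{z1} = z2 and f'_{z2} = z1
analytic = 2 * x * z2 + math.cos(x) * z1   # g'(x)*f'_{z1} + h'(x)*f'_{z2}

eps = 1e-6
F = lambda t: (t ** 2) * math.sin(t)       # f(g(x), h(x)) computed directly
numeric = (F(x + eps) - F(x - eps)) / (2 * eps)
print(analytic, numeric)                   # should match closely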

Coding Derivatives

  • For each \(f\) or \(g\) we must also provide \(f'\) and \(g'\)

  • These derivatives can be worked out locally, either symbolically or through numeric differentiation (a sketch follows the figure)

[Figure: autograd3.png]
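
For the numeric option, a central-difference approximation is enough to test hand-written derivatives; a minimal sketch:

def central_difference(f, x, eps=1e-6):
    # numeric approximation of f'(x), handy for checking backward
    return (f(x + eps) - f(x - eps)) / (2 * eps)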

Picture

[Figure: autograd3.png]

Backpropagation

Goal

  • Efficient implementation of the chain rule

  • Assume access to the computation graph.

  • Goal: call backward only once per Variable, with its fully accumulated \(d_{out}\)

Full Graph

\[\begin{split}\begin{eqnarray*} z &=& x \times y \\ h(x, y) &=& \log(z) + \exp(z) \end{eqnarray*}\end{split}\]

[Figure: backprop1.png]

Tool

If we have:

  • the derivative with respect to a Variable

  • the Function that created the Variable

We can apply the chain rule through that function.

Step

[Figure: backprop3.png] [Figure: backprop4.png]

Issue

Order matters!

  • If we apply the chain rule through a Variable before all of its derivatives have arrived, we may have to apply it multiple times

Desired property: collect all derivatives for a Variable before moving backward past it.

Ordering Step

  • Do not process any Variable until all downstream Variables are done.

  • Collect a list of the Variables first.

Topological Sorting

Pseudocode

visit(last)

function visit(node n)
  if n has a mark then return

  for each node m with an edge from n to m do
      visit(m)

  mark n with a permanent mark
  add n to list
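
A Python sketch of this routine (the parents attribute naming the input Variables is an assumption, not minitorch's exact API):

def topological_sort(last):
    order = []
    seen = set()

    def visit(n):
        if id(n) in seen:        # "has a mark"
            return
        seen.add(id(n))
        for m in n.parents:      # nodes with an edge from n
            visit(m)
        order.append(n)          # added only after everything n points to

    visit(last)
    return order                 # ends with last; walk it in reverse for backprop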

Backpropagation

  • Propagate derivatives backward through the graph

  • Ensure derivatives flow all the way back to the original (leaf) Variables.

Terminology

  • Leaf: a Variable created from scratch (e.g. an input or parameter)

  • Non-Leaf: a Variable created by applying a Function

  • Constant: a term passed in that is not a Variable, so no derivative is needed

Algorithm: Outer Loop

  1. Call topological sort

  2. Create a dict mapping each Variable to its current derivative

  3. For each node in backward order, run the inner loop below:

Algorithm: Inner Loop

  1. if the Variable is a leaf, accumulate its final derivative

  2. if the Variable is not a leaf,

    1. call backward on its creating Function with \(d_{out}\)

    2. loop through the input Variables and the derivatives backward returns

    3. accumulate each derivative onto its Variable in the dict
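
Both loops together in one sketch, using topological_sort from earlier; the method names (is_leaf, accumulate_derivative, chain_rule) are assumptions, not minitorch's exact API:

def backpropagate(variable, d_out=1.0):
    order = topological_sort(variable)      # from the earlier sketch
    derivatives = {id(variable): d_out}     # dict: Variable id -> derivative
    for var in reversed(order):             # backward order: output first
        d = derivatives.pop(id(var), 0.0)
        if var.is_leaf():
            var.accumulate_derivative(d)    # leaf: record the final derivative
        else:
            # chain rule through the creating Function with d_out,
            # then accumulate each returned derivative onto its input
            for inp, d_inp in var.chain_rule(d):
                derivatives[id(inp)] = derivatives.get(id(inp), 0.0) + d_inp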

Example

[Figures: backprop1.png through backprop7.png, stepping through backpropagation on the example graph]

Neural Networks

  • A new model class

  • Uses repeated linear splits of the data

  • Produces non-linear separators

  • The loss function does not change

Training

model = Network()
...
model.named_parameters()

  • All the parameters in model are leaf Variables

  • Calling backward on the loss fills in their derivatives (see the sketch below)
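
A sketch of that flow, assuming a Scalar-valued squared-error loss and minitorch-style Parameters (whose value holds the Scalar):

err = model.forward(x) - y     # forward pass builds the graph
loss = err * err               # assumed loss for illustration
loss.backward()                # backprop fills each leaf's derivative
for name, p in model.named_parameters():
    print(name, p.value.derivative)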

Math View

\[\begin{split}\begin{eqnarray*} h_1 &=& \text{ReLU}(x_1 \times w^0_1 + x_2 \times w^0_2 + b^0) \\ h_2 &=& \text{ReLU}(x_1 \times w^1_1 + x_2 \times w^1_2 + b^1)\\ m(x_1, x_2) &=& h_1 \times w_1 + h_2 \times w_2 + b \end{eqnarray*}\end{split}\]
Parameters:

\(w_1, w_2, w^0_1, w^0_2, w^1_1, w^1_2, b, b^0, b^1\)

Math View (Alt)

\[\begin{split}\begin{eqnarray*} \text{lin}(x; w, b) &=& x_1 \times w_1 + x_2 \times w_2 + b \\ h_1 &=& \text{ReLU}(\text{lin}(x; w^0, b^0)) \\ h_2 &=& \text{ReLU}(\text{lin}(x; w^1, b^1))\\ m(x_1, x_2) &=& \text{lin}(h; w, b) \end{eqnarray*}\end{split}\]
Parameters:

\(w_1, w_2, w^0_1, w^0_2, w^1_1, w^1_2, b, b^0, b^1\)

Code

  • Code in run_scalar.py

Code

  • Use an optimizer to move the parameters using their derivatives (sketched below).
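
A generic sketch of such a training loop; the optimizer name and loss are assumptions, not run_scalar.py verbatim:

optim = minitorch.SGD(model.parameters(), 0.1)   # assumed optimizer and lr
for epoch in range(500):
    optim.zero_grad()             # clear stale derivatives on the parameters
    err = model.forward(x) - y    # forward pass builds the graph
    loss = err * err              # assumed squared-error loss
    loss.backward()               # backpropagation fills parameter derivatives
    optim.step()                  # move each parameter against its derivative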

Q&A