Module 1.3 - Backprop¶

Functions¶

  • Function $f(x) = x \times 5$
In [2]:
class TimesFive(ScalarFunction):
    @staticmethod
    def forward(ctx, x: float) -> float:
        return x * 5
In [3]:
draw_boxes(["", ""], [1], lr=True)
Out[3]:

Multi-arg Functions¶

  • Function $f(x, y) = x \times y$
In [4]:
class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx, x: float, y: float) -> float:
        return x * y
In [5]:
draw_boxes([("", ""), ""], [1], lr=True)
Out[5]:

Context¶

$$f(x) = x^2$$ $$f'(x) = 2 \times x$$

In [6]:
class Square(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Context, d: float) -> float:
        (x,) = ctx.saved_values
        f_prime = 2 * x
        return f_prime * d
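
To make the role of the context concrete, here is a minimal self-contained sketch. The `Context` class below is a hypothetical stand-in written for this example (it is not the library's actual class), and `forward` and `backward` are called by hand rather than through `apply`.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class Context:
    """Hypothetical stand-in: holds values saved during forward."""

    saved_values: Tuple[Any, ...] = ()

    def save_for_backward(self, *values: Any) -> None:
        # Remember whatever forward will need again in backward.
        self.saved_values = values


class Square:
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Context, d: float) -> float:
        (x,) = ctx.saved_values
        # f'(x) = 2x, scaled by the incoming derivative d.
        return 2 * x * d


ctx = Context()
out = Square.forward(ctx, 3.0)    # forward value at x = 3
grad = Square.backward(ctx, 1.0)  # derivative at x = 3 with d = 1
```

The same `ctx` object is passed to both calls, which is how `backward` recovers `x` without recomputing anything.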

Box for Function¶

In [7]:
draw_boxes(["", ""], [1], lr=False)
Out[7]:

Computational Graph¶

In [8]:
draw_boxes([("", ""), "", ("", ""), ""], [1, 2, 1], lr=True)
Out[8]:

Forward Graph¶

In [9]:
def expression():
    x = Scalar(1.0)
    y = Scalar(1.0)
    z = (sum([1, x, x * x, 65]) * x * y + 6 + x) * y + 10.0 * x

    return z
In [10]:
SVG(make_graph(expression(), lr=True))
Out[10]:

Lecture Quiz¶

Outline¶

  • Chain Rule
  • Backpropagation

Chain Rule¶

Graph Structure¶

In [11]:
x = Scalar(2.0)
x_2 = Square.apply(x)
print(x_2.history)
ScalarHistory(last_fn=<class '__main__.Square'>, ctx=Context(no_grad=False, saved_values=(2.0,)), inputs=[Scalar(2.0)])
In [12]:
print(x_2.history.inputs[0].history)
ScalarHistory(last_fn=None, ctx=None, inputs=())

Derivative¶

In [13]:
x = Scalar(2.0)
x_2 = Square.apply(x)
x_3 = Square.apply(x_2)
x_3.backward()
print(x.derivative)
32.0
In [14]:
draw_boxes(["", "", ""], [1, 1], lr=True).center_xy().scale_uniform_to_x(
    1
).with_envelope(chalk.rectangle(1, 0.3))
Out[14]:

Chain Rule¶

Compute derivative from chain

$$f(g(x)) = f(z)$$

In [15]:
draw_boxes(["x", "z", "f(z)"], [1, 1], lr=True).center_xy().scale_uniform_to_x(
    1
).with_envelope(chalk.rectangle(1, 0.3))
Out[15]:

Chain Rule¶

Compute derivative from chain

$$f'_x(g(x)) = g'(x) \times f'_{g(x)}(g(x))$$

Chain Rule¶

$$ \begin{eqnarray*} z &=& g(x) \\ d &=& f'(z) \\ f'_x(g(x)) &=& g'(x) \times d \\ \end{eqnarray*} $$

In [16]:
draw_boxes(
    [r"$d\cdot g'(x)$", "$f'(z)$", "$1$"], [1, 1], lr=False
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Out[16]:

Example: Chain Rule¶

$$\log(x)^2$$

$$\begin{eqnarray*} f(z) &= z^2\\ g(x) &= \log(x)\\ \end{eqnarray*} $$

Example: Chain Rule¶

$$ \begin{eqnarray*} f'(z) &= 2z \times 1 \\ g'(x) &= 1 / x \end{eqnarray*} $$

What is the combination?

$$f'_x(g(x))$$
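
One way to work out the combination, using $z = g(x) = \log(x)$ and $d = f'(z)$ as in the chain rule slides:

$$ \begin{eqnarray*} d &=& f'(z) = 2 \log(x) \\ f'_x(g(x)) &=& g'(x) \times d = \frac{1}{x} \times 2 \log(x) = \frac{2 \log(x)}{x} \end{eqnarray*} $$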

Example: Chain Rule¶

$$((x)^2)^2$$

$$ \begin{eqnarray*} f(z) &= z^2\\ g(x) &= x^2\\ \\ f'(z) &= 2\times z\\ g'(x) &= 2 \times x \\ \end{eqnarray*} $$

Example: Chain Rule¶

$$ \begin{eqnarray*} f'_x(g(x)) &= 2 \times x \times 2 \times x^2 = 4 x^3\\ \end{eqnarray*} $$
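
As a sanity check, the result can be compared against a numerical estimate. This is a small sketch in plain Python; the helper names (`chain_rule_derivative`, `central_difference`) are made up for this example.

```python
def g(x: float) -> float:
    return x * x


def f(z: float) -> float:
    return z * z


def chain_rule_derivative(x: float) -> float:
    z = g(x)
    d = 2 * z         # d = f'(z)
    return 2 * x * d  # g'(x) * d


def central_difference(fn, x: float, eps: float = 1e-6) -> float:
    # Numerical estimate of fn'(x).
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


x = 2.0
analytic = chain_rule_derivative(x)
numeric = central_difference(lambda v: f(g(v)), x)
```

At $x = 2$ the analytic value is $4 \times 2^3 = 32$, matching the `32.0` printed by `x.derivative` earlier, and the numerical estimate agrees to within rounding error.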

Two Arguments: Chain¶

$$f(g(x, y))$$

In [17]:
draw_boxes(
    [("", ""), "", ("", ""), ""], [1, 2, 1], lr=True
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Out[17]:

Two Arguments: Chain¶

$$ \begin{eqnarray*} f'_x(g(x, y)) &=& g_x'(x, y) \times f'_{g(x, y)}(g(x, y)) \\ \\ \\ f'_y(g(x, y)) &=& g_y'(x, y) \times f'_{g(x, y)}(g(x, y)) \end{eqnarray*} $$

Two Arguments: Chain¶

$$ \begin{eqnarray*} z &=& g(x, y) \\ d &=& f'(z) \\ f'_x(g(x, y)) &=& g_x'(x, y) \times d \\ f'_y(g(x, y)) &=& g_y'(x, y) \times d \end{eqnarray*} $$

In [18]:
draw_boxes(
    [(r"$d \cdot  g'_x(x, y)$", r"$d \cdot g'_y(x, y)$"), "$f'(z)$", "$1$"],
    [1, 1],
    lr=False,
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Out[18]:

Example: Chain Rule¶

$$(x \times y)^2$$

$$ \begin{eqnarray*} f(z) &= z^2\\ g(x, y) &= (x \times y)\\ \end{eqnarray*} $$

Example: Chain Rule¶

$$ \begin{eqnarray*} f'(z) &= 2z \times 1\\ g'_x(x, y) &= y \\ g'_y(x, y) &= x \\ \end{eqnarray*} $$

What is the combination?

$$ \begin{eqnarray*} f'_x(g(x, y)) &= 2 z y\\ f'_y(g(x, y)) &= 2 z x \\ \end{eqnarray*} $$
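
The same computation can be sketched with explicit backward functions. The function names here are illustrative only; each backward takes the incoming derivative `d` and returns one derivative per input.

```python
from typing import Tuple


def mul_backward(x: float, y: float, d: float) -> Tuple[float, float]:
    # g(x, y) = x * y, so g'_x = y and g'_y = x.
    return y * d, x * d


def square_backward(z: float, d: float) -> float:
    # f(z) = z^2, so f'(z) = 2z.
    return 2 * z * d


x, y = 3.0, 4.0
z = x * y                       # forward through g
d = square_backward(z, 1.0)     # d = f'(z) = 2z
dx, dy = mul_backward(x, y, d)  # 2zy and 2zx
```

With $x = 3$, $y = 4$ we get $z = 12$, so $2zy = 96$ and $2zx = 72$.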

Multivariable Chain¶

$$f(g(x), g(x))$$

In [19]:
draw_boxes(
    ["$x$", ("$z_1 = g(x)$", "$z_2 = g(x)$"), ""], [1, 1]
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Out[19]:

Multivariable Chain¶

$$ \begin{eqnarray*} h(x) &=& f(g(x), g(x)) \\ d &=& 1 \times f'_{z_1}(z_1, z_2) + 1 \times f'_{z_2}(z_1, z_2) \\ h'_x(x) &=& d \times g'_x(x) \\ \end{eqnarray*} $$
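
A concrete instance may help. The sketch below assumes $g(x) = x^2$ and $f(z_1, z_2) = z_1 \times z_2$ (choices made for this example, not from the slides), so that $h(x) = f(g(x), g(x)) = x^4$ and the expected derivative is $4x^3$.

```python
def g(x: float) -> float:
    return x * x


def g_prime(x: float) -> float:
    return 2 * x


def f_prime_z1(z1: float, z2: float) -> float:
    # f(z1, z2) = z1 * z2, derivative with respect to z1.
    return z2


def f_prime_z2(z1: float, z2: float) -> float:
    # Derivative with respect to z2.
    return z1


def h_prime(x: float) -> float:
    z1 = g(x)
    z2 = g(x)
    # Sum the derivatives arriving along both paths into g's output.
    d = 1 * f_prime_z1(z1, z2) + 1 * f_prime_z2(z1, z2)
    return d * g_prime(x)


result = h_prime(3.0)
```

At $x = 3$ this gives $d = 9 + 9 = 18$ and $h'(3) = 18 \times 6 = 108 = 4 \times 3^3$.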

In [20]:
draw_boxes(
    [r"$d \cdot g'_x(x)$", ("$f'_{z_1}(z_1, z_2)$", "$f'_{z_2}(z_1, z_2)$"), "$h1$"],
    [1, 1],
    lr=False,
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Out[20]:

Backpropagation¶

Complex Graphs¶

In [21]:
def expression():
    x = Scalar(1.0, name="x")
    y = Scalar(1.0, name="y")
    z = -y * sum([x, x, x]) * y + 10.0 * x
    return z + z
In [22]:
SVG(make_graph(expression(), lr=True))
Out[22]:

Goal¶

  • Efficient implementation of the chain rule

  • Assume access to the graph.

  • Goal: Call backward once per variable

Full Graph¶

$$ \begin{eqnarray*} z &=& x \times y \\ h(x, y) &=& \log(z) + \exp(z) \end{eqnarray*} $$

In [23]:
backprop(1)
Out[23]:

Tool¶

If we have:

  • the derivative with respect to a scalar
  • the function last called on the scalar

We can apply the chain rule through that function.

Step¶

In [24]:
backprop(3)
Out[24]:
In [25]:
backprop(4)
Out[25]:

Issue¶

Order matters!

  • If we proceed without finishing a variable, we may need to apply chain rule multiple times

Desired property: collect all derivatives for a variable before calling backward on it.

Ordering Step¶

  • Do not process any Variable until all downstream Variables are done.

  • Collect a list of the Variables first.

Topological Sorting¶

  • Topological Sorting
  • High-level: run depth-first search and mark nodes.

Topological Sorting¶

  visit(last)

  function visit(node n)
    if n has a mark then return

    for each node m with an edge from n to m do
        visit(m)

    mark n with a permanent mark
    add n to list
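
The pseudocode above could be sketched in Python as follows. The graph is a plain dict mapping each node name to the nodes it was computed from (an assumption made for this example), and the finished list is reversed so that the output node comes first, which is the order needed for the backward pass.

```python
from typing import Dict, List, Set


def topological_sort(last: str, edges: Dict[str, List[str]]) -> List[str]:
    marked: Set[str] = set()
    order: List[str] = []

    def visit(n: str) -> None:
        if n in marked:
            return
        for m in edges.get(n, []):
            visit(m)
        marked.add(n)    # permanent mark
        order.append(n)  # n is added only after everything it depends on

    visit(last)
    # Reverse so the output comes first and the inputs come last.
    return order[::-1]


# h = log(z) + exp(z), z = x * y
edges = {
    "h": ["log_z", "exp_z"],
    "log_z": ["z"],
    "exp_z": ["z"],
    "z": ["x", "y"],
}
result = topological_sort("h", edges)
```

In the resulting order `z` appears after both `log_z` and `exp_z`, so both of its incoming derivatives are available before it is processed.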

Topological Sorting¶

In [26]:
def expression():
    x = Scalar(1.0, name="x")
    y = Scalar(1.0, name="y")
    z = sum([x, x, x]) * y + 10.0 * x
    return z + z
In [27]:
SVG(make_graph(expression(), lr=True))
Out[27]:

Backpropagation¶

  • Graph propagation
  • Ensure flow to original Variables.

Terminology¶

  • Leaf: Variable created from scratch
  • Non-Leaf: Variable created with a Function
  • Constant: Term passed in that is not a variable

Algorithm: Outer Loop¶

  1. Call topological sort
  2. Create dict of Variables and derivatives
  3. For each node in backward order:

Algorithm: Inner Loop¶

1. If the Variable is a leaf, add its final derivative.
2. If the Variable is not a leaf:
  A. call backward with $d$
  B. loop through the returned (Variable, derivative) pairs
  C. accumulate derivatives for each Variable
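
The outer and inner loops can be put together in a short self-contained sketch. Everything here (`Node`, `backpropagate`, the helper constructors) is hypothetical and written for this example, not the library's API; it runs on the graph $h(x, y) = \log(z) + \exp(z)$ with $z = x \times y$.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence, Tuple


class Node:
    def __init__(self, name: str, value: float,
                 inputs: Sequence["Node"] = (),
                 backward: Optional[Callable[[float], Tuple[float, ...]]] = None):
        self.name, self.value = name, value
        self.inputs = list(inputs)
        self.backward = backward  # maps d to one derivative per input

    def is_leaf(self) -> bool:
        return self.backward is None


def mul(name: str, a: Node, b: Node) -> Node:
    return Node(name, a.value * b.value, (a, b),
                lambda d: (b.value * d, a.value * d))


def add(name: str, a: Node, b: Node) -> Node:
    return Node(name, a.value + b.value, (a, b), lambda d: (d, d))


def log(name: str, a: Node) -> Node:
    return Node(name, math.log(a.value), (a,), lambda d: (d / a.value,))


def exp(name: str, a: Node) -> Node:
    out = math.exp(a.value)
    return Node(name, out, (a,), lambda d: (out * d,))


def topological_sort(last: Node) -> List[Node]:
    seen, order = set(), []

    def visit(n: Node) -> None:
        if id(n) in seen:
            return
        for m in n.inputs:
            visit(m)
        seen.add(id(n))
        order.append(n)

    visit(last)
    return order[::-1]  # output first


def backpropagate(last: Node, d_out: float = 1.0) -> Dict[str, float]:
    derivs: Dict[int, float] = {id(last): d_out}
    leaf_grads: Dict[str, float] = {}
    for node in topological_sort(last):
        d = derivs.get(id(node), 0.0)
        if node.is_leaf():
            leaf_grads[node.name] = d  # final derivative for a leaf
        else:
            # backward is called once per node, with the fully summed d
            for inp, d_in in zip(node.inputs, node.backward(d)):
                derivs[id(inp)] = derivs.get(id(inp), 0.0) + d_in
    return leaf_grads


x, y = Node("x", 2.0), Node("y", 0.5)
z = mul("z", x, y)
h = add("h", log("log_z", z), exp("exp_z", z))
grads = backpropagate(h)
```

Here $z = 1$, so $\partial h / \partial z = 1/z + e^z = 1 + e$, and the leaves receive $y (1 + e)$ for `x` and $x (1 + e)$ for `y`. Because of the topological order, `z.backward` is called once with the sum of the derivatives from `log_z` and `exp_z`.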

Example¶

In [28]:
backprop(1)
Out[28]:

Example¶

In [29]:
backprop(2)
Out[29]:

Example¶

In [30]:
backprop(3)
Out[30]:

Example¶

In [31]:
backprop(4)
Out[31]:

Example¶

In [32]:
backprop(5)
Out[32]:

Example¶

In [33]:
backprop(6)
Out[33]:

Example¶

In [34]:
backprop(7)
Out[34]:

QA¶