Module 1.3 - Backprop¶
Functions¶
- Function $f(x) = x \times 5$
class TimesFive(ScalarFunction):
    @staticmethod
    def forward(ctx, x: float) -> float:
        return x * 5
draw_boxes(["", ""], [1], lr=True)
Multi-arg Functions¶
- Function $f(x, y) = x \times y$
class Mul(ScalarFunction):
    @staticmethod
    def forward(ctx, x: float, y: float) -> float:
        return x * y
draw_boxes([("", ""), ""], [1], lr=True)
Context¶
$$f(x) = x^2$$ $$f'(x) = 2 \times x$$
class Square(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        # Save x so that backward can compute f'(x).
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Context, d: float) -> float:
        (x,) = ctx.saved_values
        f_prime = 2 * x
        # One input, so one derivative: f'(x) times the incoming d.
        return f_prime * d
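To make the role of the context concrete, here is a minimal sketch that runs without the course library. The `Context` class below is a hypothetical stand-in that only stores values, and this `Square` does not subclass `ScalarFunction`; both are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class Context:
    """Hypothetical stand-in for the course's Context: it only stores values."""

    saved_values: Tuple[Any, ...] = ()

    def save_for_backward(self, *values: Any) -> None:
        self.saved_values = values


class Square:
    @staticmethod
    def forward(ctx: Context, x: float) -> float:
        ctx.save_for_backward(x)  # remember x for the backward pass
        return x * x

    @staticmethod
    def backward(ctx: Context, d: float) -> float:
        (x,) = ctx.saved_values
        return 2 * x * d  # f'(x) times the incoming derivative d


ctx = Context()
out = Square.forward(ctx, 3.0)    # 3.0 * 3.0
grad = Square.backward(ctx, 1.0)  # 2 * 3.0 * 1.0
print(out, grad)
```

The forward call is the only place that sees `x`, so anything backward needs must be saved there.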
Box for Function¶
draw_boxes(["", ""], [1], lr=False)
Computational Graph¶
draw_boxes([("", ""), "", ("", ""), ""], [1, 2, 1], lr=True)
Forward Graph¶
def expression():
    x = Scalar(1.0)
    y = Scalar(1.0)
    z = (sum([1, x, x * x, 65]) * x * y + 6 + x) * y + 10.0 * x
    return z
SVG(make_graph(expression(), lr=True))
Lecture Quiz¶
Outline¶
- Chain Rule
- Backpropagation
Chain Rule¶
Graph Structure¶
x = Scalar(2.0)
x_2 = Square.apply(x)
print(x_2.history)
ScalarHistory(last_fn=<class '__main__.Square'>, ctx=Context(no_grad=False, saved_values=(2.0,)), inputs=[Scalar(2.0)])
print(x_2.history.inputs[0].history)
ScalarHistory(last_fn=None, ctx=None, inputs=())
Derivative¶
x = Scalar(2.0)
x_2 = Square.apply(x)
x_3 = Square.apply(x_2)
x_3.backward()
print(x.derivative)
32.0
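The value 32.0 can be sanity-checked without any autodiff machinery. The sketch below uses plain floats and a central difference; the helper names `central_difference` and `composed` are hypothetical and not part of the course code.

```python
def square(x: float) -> float:
    return x * x


def composed(x: float) -> float:
    # Same computation as Square.apply(Square.apply(x)), on plain floats.
    return square(square(x))


def central_difference(f, x: float, eps: float = 1e-6) -> float:
    # Numerical estimate of f'(x).
    return (f(x + eps) - f(x - eps)) / (2 * eps)


approx = central_difference(composed, 2.0)
exact = 4 * 2.0**3  # derivative of x^4 is 4 x^3
print(approx, exact)
```

The numerical estimate should agree with the exact value to several decimal places.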
draw_boxes(["", "", ""], [1, 1], lr=True).center_xy().scale_uniform_to_x(
1
).with_envelope(chalk.rectangle(1, 0.3))
draw_boxes(["x", "z", "f(z)"], [1, 1], lr=True).center_xy().scale_uniform_to_x(
1
).with_envelope(chalk.rectangle(1, 0.3))
Chain Rule¶
$$ \begin{eqnarray*} z &=& g(x) \\ d &=& f'(z) \\ f'_x(g(x)) &=& g'(x) \times d \\ \end{eqnarray*} $$
draw_boxes(
[r"$d\cdot g'(x)$", "$f'(z)$", "$1$"], [1, 1], lr=False
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Example: Chain Rule¶
$$\log(x)^2$$
$$\begin{eqnarray*} f(z) &= z^2\\ g(x) &= \log(x)\\ \end{eqnarray*} $$
Example: Chain Rule¶
$$ \begin{eqnarray*} f'(z) &= 2z \times 1 \\ g'(x) &= 1 / x \end{eqnarray*} $$
What is the combination?
$$f'_x(g(x))$$
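Following the chain-rule recipe (compute $z = g(x)$, then $d = f'(z)$, then $g'(x) \times d$), the combination works out to $2 \log(x) / x$. A small sketch under that reading, with hypothetical function names:

```python
import math


def g(x: float) -> float:
    return math.log(x)


def g_prime(x: float) -> float:
    return 1.0 / x


def f_prime(z: float) -> float:
    return 2.0 * z


def chain_derivative(x: float) -> float:
    z = g(x)               # forward value of the inner function
    d = f_prime(z)         # derivative of the outer function at z
    return g_prime(x) * d  # chain rule


x = 3.0
combined = chain_derivative(x)
closed_form = 2.0 * math.log(x) / x
print(combined, closed_form)
```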
Example: Chain Rule¶
$$((x)^2)^2$$
$$ \begin{eqnarray*} f(z) &= z^2\\ g(x) &= x^2\\ \\ f'(z) &= 2\times z\\ g'(x) &= 2 \times x \\ \end{eqnarray*} $$
Example: Chain Rule¶
$$ \begin{eqnarray*} f'_x(g(x)) &= 2 \times x \times 2 \times x^2 = 4 x^3\\ \end{eqnarray*} $$
Two Arguments: Chain¶
$$f(g(x, y))$$
draw_boxes(
[("", ""), "", ("", ""), ""], [1, 2, 1], lr=True
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Two Arguments: Chain¶
$$ \begin{eqnarray*} f'_x(g(x, y)) &=& g_x'(x, y) \times f'_{g(x, y)}(g(x, y)) \\ \\ \\ f'_y(g(x, y)) &=& g_y'(x, y) \times f'_{g(x, y)}(g(x, y)) \end{eqnarray*} $$
Two Arguments: Chain¶
$$ \begin{eqnarray*} z &=& g(x, y) \\ d &=& f'(z) \\ f'_x(g(x, y)) &=& g_x'(x, y) \times d \\ f'_y(g(x, y)) &=& g_y'(x, y) \times d \end{eqnarray*} $$
draw_boxes(
[(r"$d \cdot g'_x(x, y)$", r"$d \cdot g'_y(x, y)$"), "$f'(z)$", "$1$"],
[1, 1],
lr=False,
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Example: Chain Rule¶
$$(x \times y)^2$$
$$ \begin{eqnarray*} f(z) &= z^2\\ g(x, y) &= (x \times y)\\ \end{eqnarray*} $$
Example: Chain Rule¶
$$ \begin{eqnarray*} f'(z) &= 2z \times 1\\ g'_x(x, y) &= y \\ g'_y(x, y) &= x \\ \end{eqnarray*} $$
What is the combination?
$$ \begin{eqnarray*} f'_x(g(x, y)) &= 2 z y\\ f'_y(g(x, y)) &= 2 z x \\ \end{eqnarray*} $$
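The same result can be traced step by step with two backward functions, one per box. This is a sketch on plain floats; `mul_backward` and `square_backward` are hypothetical names, not the course API.

```python
from typing import Tuple


def mul_backward(x: float, y: float, d: float) -> Tuple[float, float]:
    # g(x, y) = x * y, so g'_x = y and g'_y = x; one derivative per input.
    return d * y, d * x


def square_backward(z: float, d: float) -> float:
    # f(z) = z^2, so f'(z) = 2z.
    return 2.0 * z * d


x, y = 3.0, 4.0
z = x * y                       # forward pass through g
d = square_backward(z, 1.0)     # d = f'(z) = 2z
dx, dy = mul_backward(x, y, d)  # 2 z y and 2 z x
print(dx, dy)
```

With $x = 3$ and $y = 4$ this gives $2zy = 96$ and $2zx = 72$, matching the closed forms $2xy^2$ and $2x^2y$.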
Multivariable Chain¶
$$f(g(x), g(x))$$
draw_boxes(
["$x$", ("$z_1 = g(x)$", "$z_2 = g(x)$"), ""], [1, 1]
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
Multivariable Chain¶
$$ \begin{eqnarray*} h(x) &=& f(z_1, z_2), \quad z_1 = z_2 = g(x) \\ d &=& 1 \times f'_{z_1}(z_1, z_2) + 1 \times f'_{z_2}(z_1, z_2) \\ h'_x(x) &=& d \times g'_x(x) \\ \end{eqnarray*} $$
draw_boxes(
[r"$d \cdot g'_x(x)$", ("$f'_{z_1}(z_1, z_2)$", "$f'_{z_2}(z_1, z_2)$"), "$h1$"],
[1, 1],
lr=False,
).center_xy().scale_uniform_to_x(1).with_envelope(chalk.rectangle(1, 0.3))
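As an illustration of summing derivatives when a value is used twice, the sketch below picks $f(z_1, z_2) = z_1 \times z_2$ and $g(x) = x^2$. These choices are assumptions made for the example, so that $h(x) = x^4$ and the answer can be checked by hand.

```python
from typing import Tuple


def g(x: float) -> float:
    return x * x


def g_prime(x: float) -> float:
    return 2.0 * x


def f_partials(z1: float, z2: float) -> Tuple[float, float]:
    # f(z1, z2) = z1 * z2, so f'_{z1} = z2 and f'_{z2} = z1.
    return z2, z1


x = 2.0
z1, z2 = g(x), g(x)          # the same value feeds both arguments
dz1, dz2 = f_partials(z1, z2)
d = 1.0 * dz1 + 1.0 * dz2    # sum the derivative from each use
h_prime = d * g_prime(x)     # then apply g's chain rule once
print(h_prime)
```

Here $d = 2x^2$ and $h'(x) = 4x^3$, so the printed value at $x = 2$ is 32.0.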
Backpropagation¶
Complex Graphs¶
def expression():
    x = Scalar(1.0, name="x")
    y = Scalar(1.0, name="y")
    z = -y * sum([x, x, x]) * y + 10.0 * x
    return z + z
SVG(make_graph(expression(), lr=True))
Goal¶
Efficient implementation of the chain rule
Assume access to the graph.
Goal: Call backward once per variable
Full Graph¶
$$ \begin{eqnarray*} z &=& x \times y \\ h(x, y) &=& \log(z) + \exp(z) \end{eqnarray*} $$
backprop(1)
Tool¶
If we have:
- the derivative with respect to a scalar
- the function last called on the scalar
We can apply the chain rule through that function.
Step¶
backprop(3)
backprop(4)
Issue¶
Order matters!
- If we proceed without finishing a variable, we may need to apply the chain rule to it multiple times
Desired property: all derivative contributions for a variable are collected before calling backward on it.
Ordering Step¶
Do not process any Variable until all downstream Variables are done.
Collect a list of the Variables first.
Topological Sorting¶
- Topological Sorting
- High level: run depth-first search and mark visited nodes.
Topological Sorting¶
visit(last)
function visit(node n)
    if n has a mark then return
    for each node m with an edge from n to m do
        visit(m)
    mark n with a permanent mark
    add n to list
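The pseudocode above can be sketched in Python as follows. The graph representation (a dict from a node name to the names of its inputs) and the final reversal, which puts the output first, are assumptions made for this example rather than the course's interface.

```python
from typing import Dict, List, Set


def topological_sort(last: str, edges: Dict[str, List[str]]) -> List[str]:
    """Depth-first topological sort starting from the output node `last`.

    `edges[n]` lists the nodes that n was computed from (its inputs).
    The result starts at `last` and ends at the leaves.
    """
    marked: Set[str] = set()
    order: List[str] = []

    def visit(n: str) -> None:
        if n in marked:
            return
        for m in edges.get(n, []):
            visit(m)
        marked.add(n)    # permanent mark
        order.append(n)  # n is added only after all of its inputs

    visit(last)
    return order[::-1]   # reverse so the output comes first


# Graph for h = log(z) + exp(z) with z = x * y; "a" and "b" name the two terms.
edges = {"h": ["a", "b"], "a": ["z"], "b": ["z"], "z": ["x", "y"]}
order = topological_sort("h", edges)
print(order)  # ['h', 'b', 'a', 'z', 'y', 'x']
```

In this order `z` appears only after both `a` and `b`, so both of its derivative contributions are available before it is processed.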
Topological Sorting¶
def expression():
    x = Scalar(1.0, name="x")
    y = Scalar(1.0, name="y")
    z = sum([x, x, x]) * y + 10.0 * x
    return z + z
SVG(make_graph(expression(), lr=True))
Backpropagation¶
- Graph propagation
- Ensure flow to original Variables.
Terminology¶
- Leaf: Variable created from scratch
- Non-Leaf: Variable created with a Function
- Constant: Term passed in that is not a variable
Algorithm: Outer Loop¶
- Call topological sort
- Create dict of Variables and derivatives
- For each node in backward order:
Algorithm: Inner Loop¶
1. If the Variable is a leaf, add its final derivative.
2. If the Variable is not a leaf:
A. call backward with $d$
B. loop through the (Variable, derivative) pairs it returns
C. accumulate each derivative for its Variable
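The outer and inner loops can be put together in a short, self-contained sketch. The `Node` class and the helper functions below are hypothetical stand-ins for the course's Variable and Function classes; the example graph is $h(x, y) = \log(z) + \exp(z)$ with $z = x \times y$.

```python
import math


class Node:
    """Hypothetical graph node, standing in for the course's Variable."""

    def __init__(self, name, value, inputs=(), backward=None):
        self.name = name
        self.value = value
        self.inputs = inputs      # nodes this one was computed from
        self.backward = backward  # maps d to one derivative per input
        self.derivative = 0.0     # filled in for leaves only

    def is_leaf(self):
        return len(self.inputs) == 0


def mul(a, b):
    return Node("mul", a.value * b.value, (a, b),
                lambda d: (d * b.value, d * a.value))


def log(a):
    return Node("log", math.log(a.value), (a,), lambda d: (d / a.value,))


def exp(a):
    out = math.exp(a.value)
    return Node("exp", out, (a,), lambda d: (d * out,))


def add(a, b):
    return Node("add", a.value + b.value, (a, b), lambda d: (d, d))


def topological_sort(last):
    """Return nodes ordered from the output back toward the leaves."""
    seen, order = set(), []

    def visit(n):
        if id(n) in seen:
            return
        for m in n.inputs:
            visit(m)
        seen.add(id(n))
        order.append(n)

    visit(last)
    return order[::-1]


def backpropagate(last, d_out=1.0):
    derivs = {id(last): d_out}           # accumulated derivative per node
    for node in topological_sort(last):  # backward order
        d = derivs.get(id(node), 0.0)
        if node.is_leaf():
            node.derivative += d         # 1. leaf: store its final derivative
        else:
            # 2. non-leaf: call backward once, then accumulate per input
            for inp, d_in in zip(node.inputs, node.backward(d)):
                derivs[id(inp)] = derivs.get(id(inp), 0.0) + d_in


x, y = Node("x", 2.0), Node("y", 0.5)
z = mul(x, y)
h = add(log(z), exp(z))
backpropagate(h)
print(x.derivative, y.derivative)
```

Because nodes are processed in topological order, `backward` is called exactly once per non-leaf node, and `z` receives both of its contributions before it is processed.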
Example¶
backprop(1)
Example¶
backprop(2)
Example¶
backprop(3)
Example¶
backprop(4)
Example¶
backprop(5)
Example¶
backprop(6)
Example¶
backprop(7)