opts = ArrowOpts(arc_height=0.5, shaft_style=astyle)
opts2 = ArrowOpts(arc_height=0.2, shaft_style=astyle)
d = hcat([matrix(3, 2, "a"), matrix(3, 2, "b"), right_arrow, matrix(3, 2, "c")], 1)
d.connect(("a", 0, 0), ("c", 0, 0), opts).connect(
("a", 1, 0), ("c", 1, 0), opts
).connect(("b", 0, 0), ("c", 0, 0), opts2).connect(("b", 1, 0), ("c", 1, 0), opts2)
opts = ArrowOpts(arc_height=0.5, shaft_style=astyle)
opts2 = ArrowOpts(arc_height=0.2, shaft_style=astyle)
d = hcat([matrix(3, 1, "a"), matrix(1, 2, "b"), right_arrow, matrix(3, 2, "c")], 1)
d.connect(("a", 0, 0), ("c", 0, 0), opts).connect(
("a", 1, 0), ("c", 1, 1), opts
).connect(("b", 0, 0), ("c", 0, 0), opts2).connect(("b", 0, 1), ("c", 1, 1), opts2)
def col(c):
    return (
        matrix(3, 2).line_color(c).align_t().align_l()
        + matrix(3, 1).align_t().align_l()
    ).center_xy()
hcat(
    [
        matrix(3, 2),
        col(drawing.white),
        right_arrow,
        matrix(3, 2),
        col(drawing.papaya),
        right_arrow,
        matrix(3, 2),
    ],
    0.4,
)
d, r, c = 2, 3, 2
d = draw_equation(
    [
        t(1, r, c),
        t(d, 1, c),
        None,
        t(d, r, c, highlight=True) + t(1, r, c),
        t(d, r, c, highlight=True) + t(d, 1, c),
        None,
        t(d, r, c, n="s"),
        None,
        t(d, r, 1),
        None,
        matrix(3, 2),
    ]
)
connect(
    d,
    [("s", 0, 0, 1), ("s", 0, 0, 0)],
    [("s", 0, 1, 1), ("s", 0, 1, 0)],
    [("s", 0, 2, 1), ("s", 0, 2, 0)],
)
What is backward?
x = minitorch.rand((4, 5), requires_grad=True)
y = minitorch.rand((4, 5), requires_grad=True)
z = x * y
z.sum().backward()
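A quick way to see what `backward` produced here (a minimal sketch, assuming minitorch tensors expose a PyTorch-style `.grad` attribute): since we reduce $z = x \times y$ with `sum`, the derivative of the result with respect to each entry of `x` is the matching entry of `y`, and vice versa.
# Hypothetical check of the gradients filled in by backward() above.
# d sum(x * y) / dx = y and d sum(x * y) / dy = x, elementwise.
print(x.grad)  # expect the same values as y
print(y.grad)  # expect the same values as x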
draw_boxes(["", "$G(x)$"], [1])
$$G([x_1, x_2, x_3]) = x_1 x_2 x_3$$
$$G'_{x_1}([x_1, x_2, x_3]) = x_2 x_3$$ $$G'_{x_2}([x_1, x_2, x_3]) = x_1 x_3$$ $$G'_{x_3}([x_1, x_2, x_3]) = x_1 x_2$$
The gradient is a tensor of derivatives.
$$G'([x_1, x_2, x_3]) = [z/x_1, z/x_2, z/x_3]$$ $$z = x_1 x_2 x_3 $$
The original $G$ maps a tensor to a scalar. The gradient $G'$ maps a tensor to a tensor.
$$f(G([x_1, x_2, x_3]))$$ $$d = f'(z)$$
$$f'_{x_1}(G([x_1, x_2, x_3])) = x_2 x_3 d$$ $$f'_{x_2}(G([x_1, x_2, x_3])) = x_1 x_3 d$$ $$f'_{x_3}(G([x_1, x_2, x_3])) = x_1 x_2 d$$
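For concreteness, a made-up numeric instance of these formulas with $x = [2, 3, 4]$ and incoming derivative $d = f'(z) = 1$:
$$z = 2 \cdot 3 \cdot 4 = 24$$ $$f'_{x_1} = x_2 x_3 d = 12 \qquad f'_{x_2} = x_1 x_3 d = 8 \qquad f'_{x_3} = x_1 x_2 d = 6$$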
class Prod3(minitorch.Function):
    def forward(ctx, x: Tensor) -> Tensor:
        prod = x[0] * x[1] * x[2]
        ctx.save_for_backward(prod, x)
        return prod
    def backward(ctx, d: Tensor) -> Tensor:
        prod, x = ctx.saved_values
        return d * prod / x
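The `d * prod / x` shortcut works because $z / x_i = \prod_{j \neq i} x_j$ whenever $x_i \neq 0$. A tiny plain-Python check of that identity, using the same made-up numbers as above:
# Sketch only: verifies the d * prod / x formula on plain floats.
x = [2.0, 3.0, 4.0]
d = 1.0
prod = x[0] * x[1] * x[2]                # 24.0
grads = [d * prod / xi for xi in x]      # [12.0, 8.0, 6.0]
assert grads == [x[1] * x[2] * d, x[0] * x[2] * d, x[0] * x[1] * d]
# Note: this trick fails if any entry of x is zero.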
Trick: pretend $G$ is actually many different scalar functions.
$$ G(x) = [G^1(x), G^2(x), \ldots, G^N(x)] $$
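With that view, the backward rule is just the scalar chain rule summed over the outputs, where $d_i = f'(z_i)$ is the incoming derivative for output $i$:
$$f'_{x_j}(G(x)) = \sum_{i=1}^{N} d_i \, G'^i_{x_j}(x)$$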
draw_boxes(["$x$", ("", ""), "$f(G(x))$"], [2, 1])
$f(G(x))$
$$G([x_1, x_2]) = [x_1, x_1 x_2]$$
$$G([x_1, x_2]) = [x_1, x_1 x_2]$$ $$G'^1_{x_1}([x_1, x_2]) = 1$$ $$G'^1_{x_2}([x_1, x_2]) = 0$$ $$G'^2_{x_1}([x_1, x_2]) = x_2$$ $$G'^2_{x_2}([x_1, x_2]) = x_1$$
$$f'_x(G(x))$$ $$d_1 = f'(z_1)$$ $$d_2 = f'(z_2)$$
$$f'_{x_1}(G([x_1, x_2])) = d_1 \times 1 + d_2 \times x_2 $$ $$f'_{x_2}(G([x_1, x_2])) = d_2 \times x_1 $$
class MyFun(minitorch.Function):
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return minitorch.tensor([x[0], x[0] * x[1]])
    def backward(ctx, d: Tensor) -> Tensor:
        (x,) = ctx.saved_values
        return minitorch.tensor([d[0] * 1 + d[1] * x[1], d[1] * x[0]])
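A quick plain-Python check of the backward formula above, with made-up values $x = [2, 3]$ and incoming derivatives $d = [10, 100]$:
# Sketch only: the two chain-rule sums from the formulas above.
x = [2.0, 3.0]
d = [10.0, 100.0]
grad_x1 = d[0] * 1 + d[1] * x[1]   # 10 + 300 = 310.0
grad_x2 = d[1] * x[0]              # 200.0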
All of this is just notation for scalars.
We can often reason about it with scalars directly.
$G'^i_{x_j}([x_1, \ldots, x_N])$?
def label(t, d):
    return vstrut(0.5) // latex(t) // vstrut(0.5) // d
opts = ArrowOpts(arc_height=-0.5)
opts2 = ArrowOpts(arc_height=-0.2)
d = hcat(
    [
        label("$$", matrix(3, 2, "a")),
        left_arrow,
        label("$g'(x)$", matrix(3, 2, "b")),
        label("$d$", matrix(3, 2, "g", colormap=lambda i, j: drawing.papaya)),
    ],
    1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts2).connect(
("b", 1, 0), ("a", 1, 0), opts2
).connect(("g", 0, 0), ("b", 0, 0), opts).connect(("g", 1, 0), ("b", 1, 0), opts)
class Neg(minitorch.ScalarFunction):
    @staticmethod
    def forward(ctx, a: float) -> float:
        return -a
    @staticmethod
    def backward(ctx, d: float) -> float:
        return -d
class Neg(minitorch.Function):
    @staticmethod
    def forward(ctx, t1: Tensor) -> Tensor:
        return t1.f.neg_map(t1)
    @staticmethod
    def backward(ctx, d: Tensor) -> Tensor:
        return d.f.neg_map(d)
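In the per-output notation, negation has $G'^i_{x_j} = -1$ when $i = j$ and $0$ otherwise, so both the scalar and tensor versions simply negate the incoming derivative:
$$G(x) = -x \qquad f'_{x_j}(G(x)) = -d_j$$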
class Inv(minitorch.Function):
    @staticmethod
    def forward(ctx, t1: Tensor) -> Tensor:
        ctx.save_for_backward(t1)
        return t1.f.inv_map(t1)
    @staticmethod
    def backward(ctx, d: Tensor) -> Tensor:
        (t1,) = ctx.saved_values
        return d.f.inv_back_zip(t1, d)
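Here the saved input is needed because the elementwise rule is the scalar derivative of $1/x$ scaled by the incoming gradient, which is what `inv_back_zip` applies at each position:
$$G(x) = 1/x \qquad G'^j_{x_j}(x) = -\frac{1}{x_j^2} \qquad f'_{x_j}(G(x)) = -\frac{d_j}{x_j^2}$$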
$G'^i_{x_j}(x, y)$?
d = hcat(
    [
        matrix(3, 2, "a1"),
        matrix(3, 2, "a"),
        left_arrow,
        matrix(3, 2, "b"),
        matrix(3, 2, "g", colormap=lambda i, j: drawing.papaya),
    ],
    1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts).connect(
("b", 1, 0), ("a", 1, 0), opts
).connect(("g", 0, 0), ("b", 0, 0), opts)
class Add(minitorch.Function):
    @staticmethod
    def forward(ctx, t1: Tensor, t2: Tensor) -> Tensor:
        return t1.f.add_zip(t1, t2)
    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
        return grad_output, grad_output
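Addition has derivative $1$ with respect to both arguments at every position, so the chain rule passes the incoming gradient through unchanged to each input:
$$G(a, b) = a + b \qquad f'_{a_j} = d_j \qquad f'_{b_j} = d_j$$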
d = hcat(
    [
        matrix(3, 2, "a"),
        left_arrow,
        matrix(3, 1, "g", colormap=lambda i, j: drawing.papaya),
    ],
    1,
)
d.connect(("g", 0, 0), ("a", 0, 0), opts).connect(("g", 0, 0), ("a", 0, 1), opts)
class Sum(minitorch.Function):
    @staticmethod
    def forward(ctx, a: Tensor, dim: Tensor) -> Tensor:
        ctx.save_for_backward(a.shape, dim)
        return a.f.add_reduce(a, int(dim.item()))
    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, float]:
        a_shape, dim = ctx.saved_values
        return grad_output, 0.0
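Every element along the reduced dimension enters the sum with coefficient $1$, so each one receives the same incoming derivative. Returning `grad_output` works on the assumption (illustrated by the arrows above) that the autodiff machinery broadcasts it back out to the saved `a_shape`; the `0.0` is a placeholder gradient for the non-differentiable `dim` argument.
$$G(a) = \sum_k a_k \qquad f'_{a_k}(G(a)) = d$$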