draw_boxes(["", "$G(x)$"], [1])
Trick: pretend $G$ is actually $N$ separate scalar functions, one per output.
$$ G(x) = [G^1(x), G^2(x), \ldots, G^N(x)] $$
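To make the trick concrete, here is a minimal plain-Python sketch (no minitorch) that wraps a vector-valued `G` as a list of scalar functions $G^1, \ldots, G^N$. The helper `split_components` is hypothetical, chosen only for illustration, and indices are 0-based in code where the slides count from 1.

```python
from typing import Callable, List

Vector = List[float]


def split_components(G: Callable[[Vector], Vector], n: int) -> List[Callable[[Vector], float]]:
    """Return [G^1, ..., G^n], where G^i(x) is entry i of G(x) (hypothetical helper)."""
    # Bind i as a default argument so each closure keeps its own index.
    return [lambda x, i=i: G(x)[i] for i in range(n)]


def G(x: Vector) -> Vector:
    # Example vector function G([x1, x2]) = [x1, x1 * x2]
    return [x[0], x[0] * x[1]]


G1, G2 = split_components(G, 2)
print(G1([3.0, 4.0]), G2([3.0, 4.0]))  # 3.0 12.0
```

Each scalar component can now be differentiated on its own, which is exactly what the per-output partials that follow do.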
$$G'^1_{x_1}([x_1, x_2]) $$ $$G'^1_{x_2}([x_1, x_2]) $$ $$G'^2_{x_1}([x_1, x_2]) $$ $$G'^2_{x_2}([x_1, x_2]) $$
$$G([x_1, x_2]) = [x_1, x_1 x_2]$$ $$G'^1_{x_1}([x_1, x_2]) = 1$$ $$G'^1_{x_2}([x_1, x_2]) = 0$$ $$G'^2_{x_1}([x_1, x_2]) = x_2$$ $$G'^2_{x_2}([x_1, x_2]) = x_1$$
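The four closed-form partials can be checked numerically. This sketch uses central differences (a standard approximation, swapped in here for illustration) to estimate each $G'^i_{x_j}$ at a sample point; `partial` is a hypothetical helper and indices are 0-based.

```python
from typing import Callable, List

Vector = List[float]


def G(x: Vector) -> Vector:
    # G([x1, x2]) = [x1, x1 * x2]
    return [x[0], x[0] * x[1]]


def partial(G: Callable[[Vector], Vector], i: int, j: int, x: Vector, eps: float = 1e-6) -> float:
    """Central-difference estimate of the derivative of output i w.r.t. input j."""
    up, down = list(x), list(x)
    up[j] += eps
    down[j] -= eps
    return (G(up)[i] - G(down)[i]) / (2 * eps)


x = [3.0, 4.0]
# Closed forms: G'^1_{x1} = 1, G'^1_{x2} = 0, G'^2_{x1} = x2, G'^2_{x2} = x1
expected = {(0, 0): 1.0, (0, 1): 0.0, (1, 0): x[1], (1, 1): x[0]}
for (i, j), value in expected.items():
    assert abs(partial(G, i, j, x) - value) < 1e-4
```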
$f(G(x))$
Derivatives
$$G'^1_{x_1}([x_1, x_2]) = 1$$ $$G'^1_{x_2}([x_1, x_2]) = 0$$ $$G'^2_{x_1}([x_1, x_2]) = x_2$$ $$G'^2_{x_2}([x_1, x_2]) = x_1$$
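$$f'_{x_1}(G([x_1, x_2])) = d_1 \times 1 + d_2 \times x_2 $$ $$f'_{x_2}(G([x_1, x_2])) = d_1 \times 0 + d_2 \times x_1 $$
where $d_1$ and $d_2$ are the derivatives of $f$ with respect to its first and second arguments, evaluated at $G([x_1, x_2])$.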
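Both lines are instances of the multivariate chain rule: the derivative with respect to input $x_j$ sums, over every output $G^i$, the upstream derivative $d_i$ times the local partial $G'^i_{x_j}$. Setting $N = 2$ and $j = 1$ recovers $d_1 \times 1 + d_2 \times x_2$; setting $j = 2$ recovers $d_1 \times 0 + d_2 \times x_1$.

$$f'_{x_j}(G(x)) = \sum_{i=1}^{N} d_i \, G'^i_{x_j}(x), \qquad d_i = \frac{\partial f}{\partial z_i}\Big|_{z = G(x)}$$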
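import minitorch
from minitorch import Tensor
class MyFun(minitorch.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # G([x1, x2]) = [x1, x1 * x2]; save x for the backward pass
        ctx.save_for_backward(x)
        return minitorch.tensor([x[0], x[0] * x[1]])

    @staticmethod
    def backward(ctx, d: Tensor) -> Tensor:
        # d[j] is the derivative of f with respect to output G^j
        (x,) = ctx.saved_values
        return minitorch.tensor([d[0] * 1 + d[1] * x[1], d[1] * x[0]])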
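To sanity-check the `backward` rule without minitorch, this sketch mirrors `MyFun` on plain Python lists and compares its output with a finite-difference gradient of $f(G(x))$. The downstream function `f(z) = z1**2 + 3 * z2` is a hypothetical choice made only for this check.

```python
from typing import List

Vector = List[float]


def forward(x: Vector) -> Vector:
    # Same G as MyFun.forward: [x1, x1 * x2]
    return [x[0], x[0] * x[1]]


def backward(x: Vector, d: Vector) -> Vector:
    # Same rule as MyFun.backward: sum of d_i * G'^i_{x_j} over outputs i
    return [d[0] * 1 + d[1] * x[1], d[1] * x[0]]


def f(z: Vector) -> float:
    # Hypothetical downstream scalar function f(z) = z1**2 + 3 * z2
    return z[0] ** 2 + 3 * z[1]


def df(z: Vector) -> Vector:
    # d_i: derivative of f with respect to z_i
    return [2 * z[0], 3.0]


x = [3.0, 4.0]
grad = backward(x, df(forward(x)))

# Central-difference gradient of f(G(x)) for comparison
eps = 1e-6
numeric = []
for j in range(2):
    up, down = list(x), list(x)
    up[j] += eps
    down[j] -= eps
    numeric.append((f(forward(up)) - f(forward(down))) / (2 * eps))

print(grad)  # [18.0, 9.0]
```

Analytically $f(G(x)) = x_1^2 + 3 x_1 x_2$, whose gradient at $[3, 4]$ is $[2 \cdot 3 + 3 \cdot 4, \; 3 \cdot 3] = [18, 9]$, matching the backward rule.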
$$G'^1_{x_1}([x_1, x_2])$$ $$G'^2_{x_1}([x_1, x_2])$$ $$G'^1_{x_2}([x_1, x_2])$$ $$G'^2_{x_2}([x_1, x_2])$$
$$f'_{x_1}(G([x_1, x_2]))$$ $$f'_{x_2}(G([x_1, x_2]))$$
def label(t, d):
    return vstrut(0.5) // latex(t) // vstrut(0.5) // d
opts = ArrowOpts(arc_height=-0.5)
opts2 = ArrowOpts(arc_height=-0.2)
d = hcat(
[
label("$$", matrix(3, 2, "a")),
left_arrow,
label("$g'(x)$", matrix(3, 2, "b")),
label("$d$", matrix(3, 2, "g", colormap=lambda i, j: drawing.papaya)),
],
1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts2).connect(
("b", 1, 0), ("a", 1, 0), opts2
).connect(("g", 0, 0), ("b", 0, 0), opts).connect(("g", 1, 0), ("b", 1, 0), opts)
Let $G$ be a zip of multiplication, i.e. the elementwise product of two vectors:
$$G([x_1, x_2], [y_1, y_2]) = [x_1 y_1, x_2 y_2]$$
$$G'^1_{x_1}(x, y)$$ $$G'^2_{x_1}(x, y)$$ $$G'^1_{y_1}(x, y)$$ $$G'^2_{y_1}(x, y)$$
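Taking the zip of multiplication to mean the elementwise product $G([x_1, x_2], [y_1, y_2]) = [x_1 y_1, x_2 y_2]$, these four partials have simple closed forms. Output $i$ touches only $x_i$ and $y_i$, so the cross terms vanish:

$$G'^1_{x_1}(x, y) = y_1 \qquad G'^2_{x_1}(x, y) = 0 \qquad G'^1_{y_1}(x, y) = x_1 \qquad G'^2_{y_1}(x, y) = 0$$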
$$f'_{x_1}(G(x, y))$$ $$f'_{y_1}(G(x, y))$$
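With $G$ the elementwise product $G(x, y)_i = x_i y_i$, each chain-rule sum over outputs collapses to a single term: $f'_{x_i} = d_i y_i$ and $f'_{y_i} = d_i x_i$. A minimal plain-Python sketch of that backward pass follows; the names `zip_mul` and `zip_mul_backward` are hypothetical, not minitorch API.

```python
from typing import List, Tuple

Vector = List[float]


def zip_mul(x: Vector, y: Vector) -> Vector:
    # Forward: G(x, y)[i] = x[i] * y[i]
    return [a * b for a, b in zip(x, y)]


def zip_mul_backward(x: Vector, y: Vector, d: Vector) -> Tuple[Vector, Vector]:
    # Output i depends only on x[i] and y[i], so only one term of the
    # chain-rule sum survives: f'_{x_i} = d_i * y_i, f'_{y_i} = d_i * x_i.
    grad_x = [di * yi for di, yi in zip(d, y)]
    grad_y = [di * xi for di, xi in zip(d, x)]
    return grad_x, grad_y


gx, gy = zip_mul_backward([1.0, 2.0], [3.0, 4.0], [10.0, 100.0])
print(gx, gy)  # [30.0, 400.0] [10.0, 200.0]
```

Because the Jacobian of a zip is diagonal, the backward pass is itself a zip, so it costs the same as the forward pass rather than building a full $N \times N$ matrix.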
d = hcat(
[
matrix(3, 2, "a1"),
matrix(3, 2, "a"),
left_arrow,
matrix(3, 2, "b"),
matrix(3, 2, "g", colormap=lambda i, j: drawing.papaya),
],
1,
)
d.connect(("b", 0, 0), ("a", 0, 0), opts).connect(
("b", 1, 0), ("a", 1, 0), opts
).connect(("g", 0, 0), ("b", 0, 0), opts)