Fusing Operations
import chalk
from mt_diagrams.tensor_draw import (
color,
matrix,
image_matmul_simple,
image_matmul_full,
)
chalk.set_svg_draw_height(400)
chalk.set_svg_height(200)
Another approach we can use to help improve tensor computations is to specialize commonly used combinations of operators. Combining operations can eliminate unnecessary intermediate tensors. This is particularly helpful for saving memory.
There is a lot of ongoing work on how to do this operator fusion automatically. However, for very common operators, it is worth just writing these operators directly and speeding them up.
In minitorch we will do this fusion by writing specialized operations by hand. If we find it useful, we will implement these operations in Numba.
Example: Matrix Multiplication
Let's consider a matrix multiplication example. Recall that the rows of the first matrix interact with the columns of the second.
chalk.hcat(
[
matrix(2, 4, colormap=lambda i, j: color(1, 4)(0, j)),
matrix(4, 3, colormap=lambda i, j: color(1, 4)(0, i)),
],
0.5,
)
matmul = image_matmul_simple()
matmul
In past modules we have done matrix multiplication by applying a broadcasted zip and then a reduce. For example, consider a tensor of size (2, 4) and a tensor of size (4, 3). We first zip these together with broadcasting to produce a tensor of size (2, 4, 3):
image_matmul_full()
And then reduce it to a tensor of size (2, 1, 3), which we can view as (2, 3):
matrix(2, 3)
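The two steps above can be sketched in plain Python to make the intermediate tensor visible. This is only an illustration with nested lists, not minitorch's tensor API; the names a, b, zipped, and out are made up for the example.

```python
# Plain-Python sketch of "broadcasted zip, then reduce" (not minitorch's API).
a = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]            # shape (2, 4)
b = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 2.0],
     [1.0, 1.0, 0.0],
     [2.0, 0.0, 1.0]]                 # shape (4, 3)

n, k, m = len(a), len(b), len(b[0])

# Zip with broadcasting: a is viewed as (2, 4, 1) and b as (1, 4, 3).
# Every product is stored, so the result has shape (2, 4, 3).
zipped = [[[a[i][p] * b[p][j] for j in range(m)]
           for p in range(k)]
          for i in range(n)]

# Reduce over the middle dimension: (2, 1, 3), viewed here as (2, 3).
out = [[sum(zipped[i][p][j] for p in range(k)) for j in range(m)]
       for i in range(n)]

print(out)
```

Note that zipped holds 24 values even though the final result only needs 6.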
This computation is correct, but it has two issues. First, it needs to create a tensor of size (2, 4, 3) in the intermediate step. This tensor can be an order of magnitude larger than either of the tensors that were passed in. Another, more subtle problem is that the reduce operator may need to call save_for_backward on this tensor, which means it stays in memory throughout all of our forward computations.
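To get a feel for how quickly the intermediate tensor grows, here is a small counting sketch. The helper element_counts is hypothetical and only does arithmetic on shapes; it is not part of minitorch.

```python
def element_counts(n, k, m):
    """Element counts when multiplying an (n, k) tensor by a (k, m) tensor."""
    return {
        "left": n * k,
        "right": k * m,
        "intermediate": n * k * m,  # result of the broadcasted zip
        "output": n * m,            # what we actually want to keep
    }


small = element_counts(2, 4, 3)      # the running example
large = element_counts(64, 64, 64)   # a modest square case

# How many times larger the intermediate is than one input.
ratio = large["intermediate"] // large["left"]
print(small, large, ratio)
```

For the (2, 4) by (4, 3) example the intermediate has 24 elements against inputs of 8 and 12. For 64 by 64 matrices it has 262,144 elements, 64 times the size of either input.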
An alternative option is to fuse these two operations together. We will create a single matmul operation, exposed through Python's @ operator. We can then skip computing the intermediate value and directly compute the output. We can do this by writing a specialized tensor Function with a forward function that directly produces the output and a backward that produces the required gradient.
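As a rough sketch of how this could be wired up, the example below routes @ through Python's __matmul__ hook to a fused forward. The classes Context, MatMul, and Matrix are simplified stand-ins invented for this illustration and do not match minitorch's real classes; the backward is left out here.

```python
class Context:
    """Hypothetical stand-in for an autodiff context that stores saved values."""

    def __init__(self):
        self.saved_values = ()

    def save_for_backward(self, *values):
        self.saved_values = values


class MatMul:
    """Fused matrix multiply: the (n, k, m) intermediate is never built."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)  # only the two inputs are kept
        n, k, m = len(a), len(b), len(b[0])
        return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
                for i in range(n)]


class Matrix:
    """Tiny wrapper so that `x @ y` dispatches to the fused operation."""

    def __init__(self, rows):
        self.rows = rows
        self.ctx = None

    def __matmul__(self, other):
        ctx = Context()
        out = Matrix(MatMul.forward(ctx, self.rows, other.rows))
        out.ctx = ctx  # keep the context around for a later backward pass
        return out


x = Matrix([[1.0, 2.0], [3.0, 4.0]])
y = Matrix([[0.0, 1.0], [1.0, 0.0]])
z = x @ y
print(z.rows)
```

The point of the design is that the context holds only the two inputs, so nothing of size (n, k, m) has to stay in memory.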
This allows us to specialize our implementation of matrix multiplication. We do this by 1) walking over the output indices, 2) seeing which indices are reduced, and 3) then seeing which were part of the zip.
Given how important this operator is, it is worth spending the time to think about how to make each of these steps fast. Once we know the output index we can very efficiently walk through the input indices and accumulate their sums.
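The three steps can be sketched over flat, row-major storage as follows. This is a minimal 2D illustration assuming contiguous inputs; the function name matmul_fused and its signature are made up for the example, and a real implementation would also need to handle arbitrary strides and batch dimensions.

```python
def matmul_fused(a, a_shape, b, b_shape):
    """Multiply flat row-major storages a (n, k) and b (k, m) into out (n, m)."""
    n, k = a_shape
    k2, m = b_shape
    assert k == k2, "inner dimensions must match"
    # Assumption: both inputs are contiguous, so strides follow from the shapes.
    a_strides = (k, 1)
    b_strides = (m, 1)
    out = [0.0] * (n * m)
    for i in range(n):                  # 1) walk over the output indices
        for j in range(m):
            a_pos = i * a_strides[0]    # start of row i in a
            b_pos = j * b_strides[1]    # start of column j in b
            acc = 0.0
            for _ in range(k):          # 2) step along the reduced dimension
                acc += a[a_pos] * b[b_pos]  # 3) pair up the zipped elements
                a_pos += a_strides[1]
                b_pos += b_strides[0]
            out[i * m + j] = acc        # a single write per output position
    return out


a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]                       # (2, 4)
b = [1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 1.0, 1.0, 0.0, 2.0, 0.0, 1.0]   # (4, 3)
result = matmul_fused(a, (2, 4), b, (4, 3))
print(result)
```

Because the inner loop only increments two positions and accumulates into a local variable, it is the kind of loop that a compiler such as Numba can typically speed up well.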
A fast matrix multiplication operator can be used to compute the forward step. To compute backward we need to reverse the operations that we computed. In particular, this requires zipping grad out with the other input matrix. We can see the backward step by tracing the forward arrows:
matmul
It would be a bit annoying to optimize this code as well, but luckily, to compute the backward step, we can reuse our forward optimization. In fact, we can simply use the following identity from matrix calculus, which tells us that backward can be computed as a transpose and another matrix multiply. Here d denotes grad out, the gradient arriving from the output:
\begin{eqnarray*} f(M, N) &=& M N \\ g'_M(f(M, N)) &=& d N^T \\ g'_N(f(M, N)) &=& M^T d \end{eqnarray*}
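The identity can be checked on a small case. The sketch below uses nested lists and hypothetical helpers (matmul, transpose, matmul_backward) written only for this illustration; d plays the role of the gradient arriving from the output.

```python
def matmul(a, b):
    """Plain matrix multiply on nested lists."""
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


def transpose(a):
    """Swap rows and columns of a nested-list matrix."""
    return [list(col) for col in zip(*a)]


def matmul_backward(M, N, d):
    """Gradients of f(M, N) = M N given the incoming gradient d."""
    grad_M = matmul(d, transpose(N))   # d N^T, same shape as M
    grad_N = matmul(transpose(M), d)   # M^T d, same shape as N
    return grad_M, grad_N


M = [[1.0, 2.0], [3.0, 4.0]]                # shape (2, 2)
N = [[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]     # shape (2, 3)
d = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]      # shape (2, 3), like M N

grad_M, grad_N = matmul_backward(M, N, d)
print(grad_M, grad_N)
```

Both gradients come out with the same shapes as their inputs, and each is computed with one transpose and one call to the same forward multiply.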