Tensors
import chalk
import mt_diagrams.drawing as drawing
from chalk import hcat, vcat
from mt_diagrams.tensor_draw import color, matrix, tensor
chalk.set_svg_draw_height(300)
Tensor is a fancy name for a simple concept. A tensor is a multi-dimensional array of arbitrary dimensions. It is a convenient and efficient way to hold data, which becomes much more powerful when paired with fast operators and autodifferentiation.
Tensor Shapes
So far we have focused on scalars, which correspond to 0-dimensional tensors. Next, we consider a 1-dimensional tensor (vector):
matrix(5, 1)
Then a 2-dimensional tensor (matrix):
matrix(5, 4)
In addition to its dimension (dims), other critical aspects of a tensor are its shape and size. The shape of the above vector is (5,) and its size (i.e. the number of squares in the graph) is 5. The shape of the above matrix is (5, 4) and its size is 20.
A three-dimensional tensor with shape (2, 3, 3) and size 18 looks like this:
tensor(0.75, 2, 3, 3)
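As a quick sanity check on these terms, here is a small sketch using PyTorch (an assumption for illustration; minitorch exposes analogous properties). The size is always the product of the entries in the shape.
import torch

t = torch.zeros(2, 3, 3)   # a 3-dimensional tensor like the one drawn above
print(t.dim())             # 3: the number of dimensions
print(t.shape)             # torch.Size([2, 3, 3])
print(t.numel())           # 18: size, i.e. the product 2 * 3 * 3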
We access an element of the tensor by index notation: tensor[i] for 1 dimension, tensor[i, j] for 2 dimensions, tensor[i, j, k] for 3 dimensions, and so forth. For example, tensor[0, 1, 2] would give this blue cube:
tensor(
0.75,
2,
3,
3,
colormap=lambda i, j, k: drawing.aqua if (i, j, k) == (0, 1, 2) else drawing.white,
)
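For illustration (again using PyTorch only as a stand-in), index notation picks out a single element when one index is given per dimension, and a sub-tensor when fewer are given:
import torch

t = torch.arange(18).view(2, 3, 3)   # shape (2, 3, 3), size 18
print(t[0, 1, 2])                    # one scalar element, like the highlighted cube
print(t[0].shape)                    # torch.Size([3, 3]): fewer indices give a sub-tensor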
Typically, we access tensors just like multi-dimensional arrays, but there are some special geometric properties that make tensors different.
First, tensors make it easy to change the order of the dimensions. For example, we can transpose the dimensions of a matrix. For a general tensor, we refer to this operation as permute. Calling permute arbitrarily reorders the dimensions of the input tensor. For example, as shown below, calling permute(1, 0) on a matrix of shape (2, 5) gives a matrix of shape (5, 2). For indexing into the permuted matrix, we access elements using tensor[j, i] instead of tensor[i, j].
hcat(
[
matrix(2, 5, colormap=color(2, 5)),
matrix(5, 2, colormap=lambda i, j: color(2, 5)(j, i)),
],
1,
)
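A small sketch of this behavior, assuming a PyTorch-style permute (minitorch's version behaves analogously): the dimensions are reordered and the indices swap accordingly.
import torch

a = torch.arange(10).view(2, 5)   # shape (2, 5)
b = a.permute(1, 0)               # shape (5, 2): dimensions reordered
print(b.shape)                    # torch.Size([5, 2])
print(a[0, 3] == b[3, 0])         # True: tensor[i, j] becomes tensor[j, i]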
Second, tensors make it really easy to add or remove additional dimensions. Note that a matrix of shape (5, 2) can store the same amount of data as a tensor of shape (1, 5, 2), so they have the same size, as shown below:
vcat(
[
matrix(2, 5, "", color(2, 5)).center_xy(),
tensor(0.75, 1, 2, 5, "", lambda i, j, k: color(2, 5)(j, k)).center_xy(),
],
1,
)
We would like to easily increase or decrease the dimension of a tensor without changing the data. We will do this with a view function: use view(1, 5, 2) for the above example. Element tensor[i, j] in the (5, 2) matrix is now tensor[0, i, j] in the 3-dimensional tensor.
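A sketch of the same idea with a PyTorch-style view (an assumption for illustration): adding a leading dimension of 1 changes the shape but neither the size nor the element lookup.
import torch

a = torch.arange(10).view(5, 2)   # shape (5, 2), size 10
b = a.view(1, 5, 2)               # shape (1, 5, 2), size still 10
print(b.numel() == a.numel())     # True
print(a[3, 1] == b[0, 3, 1])      # True: tensor[i, j] becomes tensor[0, i, j]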
Critically, neither of these operations changes anything about the input tensor itself. Both view and permute are tensor tricks, i.e. operations that only modify how we look at the tensor, but not any of its data. Another way to say this is that they do not move or copy the data in any way, but only change the external tensor wrapper.
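In PyTorch terms (used here only as an analogy), we can check that neither trick copies the data: the results share the input's storage.
import torch

a = torch.arange(10).view(5, 2)
print(a.view(1, 5, 2).data_ptr() == a.data_ptr())   # True: same underlying storage
print(a.permute(1, 0).data_ptr() == a.data_ptr())   # True: same underlying storage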
Tensor Strides
Users of a tensor library only have to be aware of the shape and size of a tensor. However, there are important implementation details that we need to keep track of. To make our code a bit cleaner, we need to separate out the internal tensor data from the user-facing tensor. In addition to the shape, minitorch.TensorData manages tensor storage and strides:
- Storage is where the core data of the tensor is kept. It is always a 1-D array of numbers of length size, no matter the dimensionality or shape of the tensor. Keeping a 1-D storage allows us to have tensors with different shapes point to the same type of underlying data.
- Strides is a tuple that provides the mapping from user indexing to the position in the 1-D storage.
Strides can get a bit confusing to think about, so let's go over an example. Consider a matrix of shape (5, 2). The standard mapping is to walk left-to-right, top-to-bottom to order this matrix into the 1-D storage:
d = vcat(
[
matrix(5, 2, "n", colormap=color(5, 2)),
matrix(1, 10, "s", colormap=color(1, 10)),
],
1,
)
d.connect(("n", 3, 0), ("s", 0, 6)).connect(("n", 3, 1), ("s", 0, 7))
We call this a contiguous mapping, since it follows the natural counting order (bigger strides on the left). Here the strides are $(2, 1)$. We read this as: each column moves 1 step in storage and each row moves 2 steps. We can have different strides for the same shape. For instance, if we were walking top-to-bottom, left-to-right, we would have the following stride map:
d = vcat(
[
matrix(5, 2, "n", colormap=color(5, 2)),
matrix(1, 10, "s", colormap=lambda i, j: color(5, 2)(j % 5, j // 5)),
]
)
d.connect(("n", 3, 0), ("s", 0, 3)).connect(("n", 3, 1), ("s", 0, 8))
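As a minimal sketch of how such stride maps arise, contiguous strides can be computed from a shape by a running product from the right (the helper name here is illustrative, not necessarily minitorch's):
def strides_from_shape(shape):
    # Contiguous (row-major) strides: the last dimension moves 1 step,
    # and each earlier dimension moves by the product of the later ones.
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.append(strides[-1] * dim)
    return tuple(reversed(strides))

print(strides_from_shape((5, 2)))      # (2, 1), as in the example above
print(strides_from_shape((2, 3, 3)))   # (9, 3, 1)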
Contiguous strides are generally preferred, but non-contiguous strides can be quite useful as well. Consider transposing the above matrix and using strides (1,2):
d = vcat(
[
matrix(2, 5, "n", colormap=lambda i, j: color(5, 2)(j, i)),
matrix(1, 10, "s", colormap=color(1, 10)),
]
)
d.connect(("n", 0, 3), ("s", 0, 6)).connect(("n", 1, 3), ("s", 0, 7))
It has new strides (1, 2) and new shape (2, 5), in contrast to the previous (2, 1) stride map on the (5, 2) matrix, but notably there is no change in the storage. This is one of the superpowers of tensors mentioned above: we can easily manipulate how we view the same underlying storage.
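The same observation in PyTorch terms (as an analogy; the attribute names are PyTorch's): transposing changes the shape and strides but leaves the storage untouched.
import torch

a = torch.arange(10).view(5, 2)       # shape (5, 2), strides (2, 1)
b = a.t()                             # shape (2, 5), strides (1, 2)
print(a.stride(), b.stride())         # (2, 1) (1, 2)
print(b.data_ptr() == a.data_ptr())   # True: no change in the storage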
Strides naturally extend to higher-dimensional tensors.
d = vcat(
[
tensor(0.5, 2, 2, 3, "n", colormap=lambda i, j, k: color(4, 3)(i * 2 + j, k)),
matrix(1, 12, "s", colormap=color(1, 12)),
],
1,
)
d.connect(("n", 0, 1, 1), ("s", 0, 4)).connect(("n", 1, 0, 2), ("s", 0, 2 + 6))
Finally, strides can be used to implement indexing into the tensor. Assuming the strides are $(s_1, s_2)$ and we want to look up tensor[i, j], we can directly use the strides to find its position in the storage:
storage[s1 * i + s2 * j]
Or in general:
storage[s1 * index1 + s2 * index2 + s3 * index3 + ...]
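A minimal sketch of this lookup in plain Python (the helper name is illustrative, not necessarily minitorch's):
def index_to_position(index, strides):
    # Dot product of the multi-dimensional index with the strides.
    return sum(i * s for i, s in zip(index, strides))

storage = list(range(10))     # storage for the (5, 2) example above
strides = (2, 1)              # contiguous strides for shape (5, 2)
print(index_to_position((3, 1), strides))   # 7, so tensor[3, 1] is storage[7]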