Module 2 - Tensors

[Figure: diagram of tensor strides]

We now have a fully developed autodifferentiation system built around scalars. This system is correct, but you saw during training that it is inefficient. Every scalar number requires building an object, and each operation requires storing a graph of all the values that we have previously created. Training requires repeating the above operations, and running models, such as a linear model, requires a for loop over each of the terms in the network.

This module introduces and implements a tensor object that will solve these problems. Tensors group together many repeated operations to save Python overhead and to pass off grouped operations to faster implementations.

All starter code is available in https://github.com/minitorch/Module-2 .

To begin, remember to activate your virtual environment first, and then clone your assignment:

>>> git clone {{STUDENT_ASSIGNMENT2_URL}}
>>> cd {{STUDENT_ASSIGNMENT_NAME}}
>>> pip install -Ue .

You need the files from previous assignments, so make sure to pull them over to your new repo.

Tasks

For this module we have implemented the skeleton tensor.py file for you. This is a subclass of Variable that is very similar to scalar.py from the last assignment. Before starting, it is worth reading through this file to get a sense of what a Tensor Variable does. Each of the following tasks asks you to implement the methods this file relies on:

  • tensor_data.py : Indexing, strides, and storage

  • tensor_ops.py : Higher-order tensor operations

  • tensor_functions.py : Autodifferentiation-ready functions

Task 2.1: Tensor Data - Indexing

Note

This task requires familiarity with tensor indexing. Be sure to first carefully read the Guide on Tensors. You may also find it helpful to read tutorials on using tensors/arrays in Torch or NumPy.

The MiniTorch library implements the core tensor backend as minitorch.TensorData. This class handles indexing, storage, transposition, and low-level details such as strides. You will first implement these core functions before turning to the user-facing class minitorch.Tensor.

Todo

Complete the following functions in minitorch/tensor_data.py, and pass tests marked as task2_1.

minitorch.index_to_position(index, strides)

Converts a multidimensional tensor index into a single-dimensional position in storage based on strides.

Parameters
  • index (array-like) -- index tuple of ints

  • strides (array-like) -- tensor strides

Returns

position in storage

Return type

int
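
As a concrete, hedged sketch of the stride arithmetic (not necessarily the exact code the tests expect), the position is just the dot product of the index with the strides:

# Sketch only: position is the dot product of index and strides.
def index_to_position_sketch(index, strides):
    position = 0
    for i, s in zip(index, strides):
        position += i * s
    return position

# Example: a contiguous (2, 3) tensor has strides (3, 1), so the
# index (1, 2) lands at storage position 1 * 3 + 2 * 1 = 5.
assert index_to_position_sketch((1, 2), (3, 1)) == 5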

minitorch.TensorData.permute(self, *order)

Permute the dimensions of the tensor.

Parameters

order (list) -- a permutation of the dimensions

Returns

a new TensorData with the same storage and a new dimension order.

Return type

TensorData
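
A useful mental model, sketched below: permute never moves storage; it only reorders the shape and stride tuples together (the variable names here are illustrative):

# Sketch only: permuting dimensions reorders shape and strides in
# lockstep; the underlying storage is untouched.
shape, strides = (2, 3), (3, 1)
order = (1, 0)  # transpose

new_shape = tuple(shape[o] for o in order)      # (3, 2)
new_strides = tuple(strides[o] for o in order)  # (1, 3)
assert (new_shape, new_strides) == ((3, 2), (1, 3))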

Task 2.2: Tensor Operations

Note

This task requires familiarity with higher-order tensor operations. Be sure to first carefully read the Guide on Operations. You may also find it helpful to go back to Module 0 and make sure you understand higher-order functions and currying in Python.

Tensor operations apply high-level, higher-order operations to all elements in a tensor simultaneously. In particular, you can map, zip, and reduce tensor data objects together. On top of this foundation, we can build up a Function class for Tensor, similar to what we did for the ScalarFunction. In this task, you will first implement generic tensor operations and then use them to implement forward for specific operations.

We have built a debugging tool for you to observe how the computation graph of your expressions is built. You can run it in project/show_expression.py. You can alter the expression at the top of the file and then run the code to create a graph in Visdom:

## Run your tensor expression here
def expression():
    x = minitorch.tensor([10, 12], (2,), requires_grad=True)
    x.name = "x"

    z = minitorch.tensor([10, 12], (2,), requires_grad=True)
    z.name = "z"

    y = x * z + 10.0
    y.name = "y"
    return y

>>> python project/show_expression.py
[Figure: computation graph rendered by show_expression.py]

Todo

Add functions in minitorch/tensor_ops.py and minitorch/tensor_functions.py for each of the following, and pass tests marked as task2_2.

minitorch.tensor_map(fn)

Higher-order tensor map function.

fn_map = tensor_map(fn)
fn_map(out, ...)

Parameters
  • fn -- function from float to float to apply.

  • out (array) -- storage for out tensor.

  • out_shape (array) -- shape for out tensor.

  • out_strides (array) -- strides for out tensor.

  • out_size (int) -- size for out tensor.

  • in_storage (array) -- storage for in tensor.

  • in_shape (array) -- shape for in tensor.

  • in_strides (array) -- strides for in tensor.

Returns

Fills in out

Return type

None
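
To make the control flow concrete, here is a hedged sketch of the non-broadcasting case. The helper names follow task 2.1 but are illustrative, not the graded signatures (the task 2.1 to_index fills a buffer in place; this sketch returns a list):

# Illustrative helper: convert a flat ordinal into a multidimensional index.
def to_index(ordinal, shape):
    index = [0] * len(shape)
    for i in range(len(shape) - 1, -1, -1):
        index[i] = ordinal % shape[i]
        ordinal //= shape[i]
    return index

# Sketch of a simple tensor_map, ignoring broadcasting (added in 2.4).
def tensor_map_sketch(fn):
    def _map(out, out_shape, out_strides, out_size,
             in_storage, in_shape, in_strides):
        for ordinal in range(out_size):
            # Convert the loop counter into an index, then into a
            # storage position for each tensor via its strides.
            index = to_index(ordinal, out_shape)
            out[index_to_position(index, out_strides)] = fn(
                in_storage[index_to_position(index, in_strides)])
    return _map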

minitorch.tensor_zip(fn)

Higher-order tensor zipWith (or map2) function.

fn_zip = tensor_zip(fn)
fn_zip(out, ...)

Parameters
  • fn -- function from two floats to a float to apply.

  • out (array) -- storage for out tensor.

  • out_shape (array) -- shape for out tensor.

  • out_strides (array) -- strides for out tensor.

  • out_size (int) -- size for out tensor.

  • a_storage (array) -- storage for a tensor.

  • a_shape (array) -- shape for a tensor.

  • a_strides (array) -- strides for a tensor.

  • b_storage (array) -- storage for b tensor.

  • b_shape (array) -- shape for b tensor.

  • b_strides (array) -- strides for b tensor.

Returns

Fills in out

Return type

None
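
The zip version follows the same loop, reading from two storages; a sketch under the same assumptions and helpers as the map sketch above:

# Sketch of a simple tensor_zip, ignoring broadcasting (added in 2.4).
def tensor_zip_sketch(fn):
    def _zip(out, out_shape, out_strides, out_size,
             a_storage, a_shape, a_strides,
             b_storage, b_shape, b_strides):
        for ordinal in range(out_size):
            index = to_index(ordinal, out_shape)
            out[index_to_position(index, out_strides)] = fn(
                a_storage[index_to_position(index, a_strides)],
                b_storage[index_to_position(index, b_strides)])
    return _zip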

minitorch.tensor_reduce(fn)

Higher-order tensor reduce function.

Parameters
  • fn -- reduction function from two floats to a float.

  • out (array) -- storage for out tensor.

  • out_shape (array) -- shape for out tensor.

  • out_strides (array) -- strides for out tensor.

  • out_size (int) -- size for out tensor.

  • a_storage (array) -- storage for a tensor.

  • a_shape (array) -- shape for a tensor.

  • a_strides (array) -- strides for a tensor.

  • reduce_dim (int) -- dimension to reduce over

Returns

Fills in out

Return type

None
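
For reduce, out has the same shape as a except that out_shape[reduce_dim] is 1; each output cell folds fn over that dimension. A sketch, using the same illustrative helpers as above and assuming out is pre-initialized with the reduction's start value:

# Sketch of tensor_reduce under the assumptions above.
def tensor_reduce_sketch(fn):
    def _reduce(out, out_shape, out_strides, out_size,
                a_storage, a_shape, a_strides, reduce_dim):
        for ordinal in range(out_size):
            out_index = to_index(ordinal, out_shape)
            out_pos = index_to_position(out_index, out_strides)
            # Fold fn over every element of a along reduce_dim.
            for j in range(a_shape[reduce_dim]):
                a_index = list(out_index)
                a_index[reduce_dim] = j
                out[out_pos] = fn(
                    out[out_pos],
                    a_storage[index_to_position(a_index, a_strides)])
    return _reduce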

minitorch.TensorFunctions.Mul.forward(ctx, a, b)
minitorch.TensorFunctions.Sigmoid.forward(ctx, a)
minitorch.TensorFunctions.ReLU.forward(ctx, a)
minitorch.TensorFunctions.Log.forward(ctx, a)
minitorch.TensorFunctions.Exp.forward(ctx, a)
minitorch.TensorFunctions.LT.forward(ctx, a, b)
minitorch.TensorFunctions.EQ.forward(ctx, a, b)
minitorch.TensorFunctions.Permute.forward(ctx, a, order)
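
Each forward should be a thin wrapper around the ops above, saving whatever backward will need. For example, a hedged sketch of Mul.forward, where mul_zip is assumed to have been built as tensor_zip(operators.mul) (the helper name is illustrative, not the graded API):

class Mul(Function):
    @staticmethod
    def forward(ctx, a, b):
        # Save the inputs for the product rule in task 2.3.
        ctx.save_for_backward(a, b)
        return mul_zip(a, b)  # assumed: mul_zip = tensor_zip(operators.mul)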

Task 2.3: Gradients and Autograd

Note

This task requires familiarity with tensor backward operations. Be sure to first carefully read the Guide on Tensor Variables. You may also find it helpful to go back to Module 1 and review Variables and Functions.

Similar to minitorch.Scalar, minitorch.Tensor is a Variable that supports autodifferentiation. In this task, you will implement backward functions for tensor operations.

Todo

Complete the following functions in minitorch/tensor_functions.py, and pass tests marked as task2_3.

minitorch.TensorFunctions.Mul.backward(ctx, grad_output)
minitorch.TensorFunctions.Sigmoid.backward(ctx, grad_output)
minitorch.TensorFunctions.ReLU.backward(ctx, grad_output)
minitorch.TensorFunctions.Log.backward(ctx, grad_output)
minitorch.TensorFunctions.Exp.backward(ctx, grad_output)
minitorch.TensorFunctions.LT.backward(ctx, grad_output)
minitorch.TensorFunctions.EQ.backward(ctx, grad_output)
minitorch.TensorFunctions.Permute.backward(ctx, grad_output)
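
These follow the same chain-rule patterns as the scalar versions, only applied elementwise over tensors. For example, a sketch of Sigmoid.backward using the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), under the assumption that forward saved its own output in ctx:

class Sigmoid(Function):
    @staticmethod
    def backward(ctx, grad_output):
        sig = ctx.saved_values  # the output saved by forward (assumed)
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), elementwise.
        return sig * (-sig + 1.0) * grad_output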

Task 2.4: Tensor Broadcasting

Note

This task requires familiarity with tensor broadcasting. Be sure to first carefully read the Guide on Broadcasting. You may also find it helpful to go through some broadcasting tutorials for Torch or NumPy, as the behavior is identical.

Todo

Complete the following functions in minitorch/tensor_data.py and minitorch/tensor_ops.py, and pass tests marked as task2_4.

minitorch.shape_broadcast(shape1, shape2)

Broadcast two shapes to create a new union shape.

Parameters
  • shape1 (tuple) -- first shape

  • shape2 (tuple) -- second shape

Returns

broadcasted shape

Return type

tuple

Raises

IndexingError -- if the shapes cannot be broadcast
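
The rules match NumPy: align shapes from the right, treat missing dimensions as 1, and stretch size-1 dimensions. Once implemented, cases like these should hold (illustrative calls):

assert minitorch.shape_broadcast((2, 3), (3,)) == (2, 3)    # missing dim acts like 1
assert minitorch.shape_broadcast((1, 3), (5, 1)) == (5, 3)  # size-1 dims stretch
# shape_broadcast((4, 3), (2, 3)) should raise IndexingError: 4 vs 2 cannot align.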

minitorch.broadcast_index(big_index, big_shape, shape, out_index)

Convert a big_index into big_shape to a smaller out_index into shape, following broadcasting rules. big_shape may be larger or have more dimensions than the given shape; additional dimensions may need to be mapped to 0 or removed.

Parameters
  • big_index (array-like) -- multidimensional index of bigger tensor

  • big_shape (array-like) -- tensor shape of bigger tensor

  • shape (array-like) -- tensor shape of smaller tensor

  • out_index (array-like) -- multidimensional index of smaller tensor

Returns

Fills in out_index.

Return type

None
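
A worked example of the expected behavior (illustrative; assumes out_index is a pre-allocated buffer, as the in-place signature suggests). Broadcasting a (3,) tensor against a (2, 3) result means every big index (i, j) reads the smaller tensor at (j,):

out_index = [0]
minitorch.broadcast_index((1, 2), (2, 3), (3,), out_index)
# The leading dimension of big_shape is dropped; j = 2 is kept.
assert out_index == [2]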

You need to revisit the following functions implemented in task 2.2 to make sure that broadcast_index is used:

Also note that in our implementation of Function, backward is allowed to return a tensor whose shape is smaller than the input's. It will automatically be broadcast to the larger shape.

  • minitorch.tensor_map(fn)

  • minitorch.tensor_zip(fn)

  • minitorch.tensor_reduce(fn)

Their full parameter lists are documented in task 2.2 above.

Task 2.5: Training

If your code works, you should now be able to move on to the tensor training script in project/run_tensor.py. This code runs the same basic training setup as in Module 1 - Auto-Differentiation, but now utilizes your tensor code.

Todo

Implement the missing forward functions in project/run_tensor.py. They should do exactly the same thing as the corresponding functions in project/run_scalar.py, but now use the tensor code base.
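
One common pattern for the Linear forward, sketched here under the assumption that, as in the scalar version, self.weights.value has shape (in_size, out_size) and self.bias.value has shape (out_size,). It uses only view, broadcasting, and sum, since Module 2 has no matmul yet:

def forward(self, x):
    batch, in_size = x.shape
    _, out_size = self.weights.value.shape
    # (batch, in_size, 1) * (1, in_size, out_size) broadcasts to
    # (batch, in_size, out_size); summing out dim 1 gives the matmul.
    prod = x.view(batch, in_size, 1) * self.weights.value.view(1, in_size, out_size)
    return prod.sum(1).view(batch, out_size) + self.bias.value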

  • Train a tensor model and add your results for all datasets to the README.

  • Record the time per epoch reported by the trainer. (It is okay if it is slow).