Module 3 - Efficiency¶
In addition to helping simplify code, tensors provide a basis for speeding up computation. In fact, they are really the only way to efficiently write deep learning code in a slow language like Python. However, nothing we have done so far really makes anything faster than Module 0 - Fundamentals. This module is focused on taking advantage of tensors to write fast code, first on standard CPUs and then using GPUs.
All starter code is available in https://github.com/minitorch/Module-3 .
To begin, remember to activate your virtual environment first, and then clone your assignment:
>>> git clone {{STUDENT_ASSIGNMENT3_URL}}
>>> cd {{STUDENT_ASSIGNMENT_NAME}}
>>> pip install -Ue .
You need the files from previous assignments, so make sure to pull them over to your new repo.
Be sure to continue to follow the Contributing guidelines.
Tasks¶
For this assignment you will need to run your commands in the Google Colab virtual environment. Follow these instructions for Colab setup.
Task 3.1: Parallelization¶
Note
This task requires basic familiarity with Numba prange. Be sure to very carefully read the section on Parallel Computation, Numba and review Module 2 - Tensors.
The main backend of our codebase consists of three functions: map, zip, and reduce. If we can speed up these three, everything we have built so far will get faster. This exercise asks you to use Numba and its njit function to speed them up. In particular, if you can parallelize loops with prange, you can get some big wins. Be careful though! Parallelization can lead to funny bugs, for example when two loop iterations write to the same location.
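As a rough illustration of the kind of loop that parallelizes safely, here is a minimal sketch. The function `scale_and_shift` is a hypothetical example and not part of minitorch, and the `ImportError` fallback is an assumption added only so the sketch still runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba.
    prange = range

    def njit(*args, **kwargs):
        def wrap(fn):
            return fn
        return wrap


@njit(parallel=True)
def scale_and_shift(out, x, a, b):
    # Hypothetical example function, not part of minitorch.
    # Iteration i writes only out[i], so iterations may run in any order.
    for i in prange(out.size):
        out[i] = a * x[i] + b


x = np.arange(8, dtype=np.float64)
out = np.zeros_like(x)
scale_and_shift(out, x, 2.0, 1.0)
```

The loop is safe to parallelize because no two iterations touch the same output element. A loop in which several iterations update a shared index buffer or the same `out` position would not have this property.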
To help debug this code, we have created a parallel analytics script for you:
python project/parallel_test.py
Running this script will run NUMBA diagnostics on your functions.
Todo
Complete the following in minitorch/fast_ops.py and pass the tests marked as task3_1. Also include the diagnostics output from the script above in your README. Your code should have at least one parallelized loop.
- minitorch.fast_ops.tensor_map(fn)
NUMBA low-level tensor map function.
- Parameters
fn -- function mapping floats to floats to apply.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
in_storage (array) -- storage for in tensor.
in_shape (array) -- shape for in tensor.
in_strides (array) -- strides for in tensor.
- Returns
Fills in out
- Return type
None
- minitorch.fast_ops.tensor_zip(fn)
NUMBA higher-order tensor zip function.
fn_zip = tensor_zip(fn)
fn_zip(out, ...)
Fill in the out array by applying fn to each value of a_storage and b_storage, assuming a_shape and b_shape broadcast to out_shape.
- Parameters
fn -- function mapping two floats to a float to apply.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
a_storage (array) -- storage for a tensor.
a_shape (array) -- shape for a tensor.
a_strides (array) -- strides for a tensor.
b_storage (array) -- storage for b tensor.
b_shape (array) -- shape for b tensor.
b_strides (array) -- strides for b tensor.
- Returns
Fills in out
- Return type
None
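The broadcasting assumption above can be sketched in plain Python. The helpers below are simplified, illustrative stand-ins (they return new lists and are not the signatures from your codebase), and the sketch assumes the output is stored contiguously.

```python
def to_index(ordinal, shape):
    """Convert a flat ordinal into a multidimensional index (last dim fastest)."""
    index = [0] * len(shape)
    for dim in range(len(shape) - 1, -1, -1):
        index[dim] = ordinal % shape[dim]
        ordinal //= shape[dim]
    return index


def broadcast_index(big_index, big_shape, shape):
    """Map an index into big_shape onto the smaller shape that broadcasts to it."""
    offset = len(big_shape) - len(shape)
    # A dimension of size 1 is broadcast, so it is always read at index 0.
    return [big_index[offset + d] if shape[d] > 1 else 0 for d in range(len(shape))]


def index_to_position(index, strides):
    """Convert a multidimensional index into a position in flat storage."""
    return sum(i * s for i, s in zip(index, strides))


def zip_reference(fn, out_shape, a, a_shape, a_strides, b, b_shape, b_strides):
    size = 1
    for dim in out_shape:
        size *= dim
    out = [0.0] * size  # assumption: contiguous output storage
    for ordinal in range(size):
        out_index = to_index(ordinal, out_shape)
        a_pos = index_to_position(broadcast_index(out_index, out_shape, a_shape), a_strides)
        b_pos = index_to_position(broadcast_index(out_index, out_shape, b_shape), b_strides)
        out[ordinal] = fn(a[a_pos], b[b_pos])
    return out


a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # shape (2, 3), strides (3, 1)
b = [10.0, 20.0, 30.0]              # shape (3,), strides (1,)
result = zip_reference(lambda x, y: x + y, (2, 3), a, (2, 3), (3, 1), b, (3,), (1,))
```

Each output position is computed independently of the others, which is what makes the outer loop a candidate for parallelization.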
- minitorch.fast_ops.tensor_reduce(fn)
NUMBA higher-order tensor reduce function.
- Parameters
fn -- reduction function mapping two floats to float.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
a_storage (array) -- storage for a tensor.
a_shape (array) -- shape for a tensor.
a_strides (array) -- strides for a tensor.
reduce_shape (array) -- shape of reduction (1 for dimension kept, shape value for dimensions summed out)
reduce_size (int) -- size of reduce shape
- Returns
Fills in out
- Return type
None
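To make the roles of out_shape and reduce_shape concrete, here is a slow reference sketch in plain Python. It assumes contiguous output storage and takes an explicit `start` value, which is a choice made for this illustration rather than part of the documented signature.

```python
from math import prod


def to_index(ordinal, shape):
    """Convert a flat ordinal into a multidimensional index (last dim fastest)."""
    index = [0] * len(shape)
    for dim in range(len(shape) - 1, -1, -1):
        index[dim] = ordinal % shape[dim]
        ordinal //= shape[dim]
    return index


def index_to_position(index, strides):
    return sum(i * s for i, s in zip(index, strides))


def reduce_reference(fn, start, out_shape, a, a_strides, reduce_shape):
    # out_shape has 1 in every reduced dimension; reduce_shape has 1 in every
    # kept dimension. Adding the two indices therefore yields an index into a.
    out = [start] * prod(out_shape)
    for o in range(len(out)):
        out_index = to_index(o, out_shape)
        for r in range(prod(reduce_shape)):
            r_index = to_index(r, reduce_shape)
            a_index = [oi + ri for oi, ri in zip(out_index, r_index)]
            out[o] = fn(out[o], a[index_to_position(a_index, a_strides)])
    return out


a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # shape (2, 3), strides (3, 1)
add = lambda x, y: x + y
row_sums = reduce_reference(add, 0.0, (2, 1), a, (3, 1), (1, 3))
col_sums = reduce_reference(add, 0.0, (1, 3), a, (3, 1), (2, 1))
```

The outer loop over output positions is independent across iterations, while the inner loop accumulates into a single position and so must stay sequential for each output.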
Task 3.2: Matrix Multiplication¶
Note
This task requires basic familiarity with matrix multiplication. Be sure to read the Guide on Fusing Operations.
Matrix multiplication is key to all of the models that we have trained so far. In the last module, we computed matrix multiplication using broadcasting. In this task, we ask you to implement it directly as a function. Do your best to make the function efficient, but for now all that matters is that you correctly produce a multiply function that passes our tests and has some parallelism.
In order to use this function, you will also need to add a new MatMul Function to tensor_functions.py. We have added a version in the starter code you can copy. You might also find it useful to add a slow broadcasted matrix_multiply to tensor_ops.py for debugging.
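For debugging, it can help to see that the broadcasted formulation and the direct loop compute the same thing. The sketch below uses NumPy rather than minitorch tensors, and both function names are hypothetical.

```python
import numpy as np


def broadcast_matmul(a, b):
    # a: (n, k) -> (n, k, 1); b: (k, m) -> (1, k, m).
    # The product has shape (n, k, m); summing over k gives (n, m).
    return (a[:, :, None] * b[None, :, :]).sum(axis=1)


def direct_matmul(a, b):
    n, k = a.shape
    _, m = b.shape
    out = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            acc = 0.0  # local accumulator, written to out once
            for p in range(k):
                acc += a[i, p] * b[p, j]
            out[i, j] = acc
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((4, 5))
via_broadcast = broadcast_matmul(a, b)
via_loops = direct_matmul(a, b)
```

The broadcasted version materializes an intermediate of shape (n, k, m), while the direct version only ever holds one accumulator per output element. That difference in memory traffic is the main reason to implement matrix multiplication as its own function.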
To help debug this code, we have created a parallel analytics script for you:
python project/parallel_test.py
Running this script will run NUMBA diagnostics on your functions.
After you finish this task, you may want to skip to 3.5 and experiment with training on the real task under speed conditions.
Todo
Complete the following function in minitorch/fast_ops.py and copy the MatMul Function to tensor_functions.py. Pass tests marked as task3_2.
- minitorch.fast_ops.tensor_matrix_multiply(out, out_shape, out_strides, a_storage, a_shape, a_strides, b_storage, b_shape, b_strides)
NUMBA tensor matrix multiply function.
Should work for any tensor shapes that broadcast as long as
assert a_shape[-1] == b_shape[-2]
- Parameters
out (array) -- storage for out tensor
out_shape (array) -- shape for out tensor
out_strides (array) -- strides for out tensor
a_storage (array) -- storage for a tensor
a_shape (array) -- shape for a tensor
a_strides (array) -- strides for a tensor
b_storage (array) -- storage for b tensor
b_shape (array) -- shape for b tensor
b_strides (array) -- strides for b tensor
- Returns
Fills in out
- Return type
None
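A slow reference version over flat storage might look like the sketch below. It assumes 3-D tensors of shape (batch, rows, cols) with strides measured in elements, and it handles a broadcast batch dimension by using a batch stride of 0. The function name `matmul_reference` is hypothetical, and this is not the parallel implementation the task asks for.

```python
import numpy as np


def matmul_reference(out, out_shape, out_strides,
                     a, a_shape, a_strides, b, b_shape, b_strides):
    assert a_shape[-1] == b_shape[-2]
    # A batch dimension of size 1 is broadcast: never advance along it.
    a_batch = a_strides[0] if a_shape[0] > 1 else 0
    b_batch = b_strides[0] if b_shape[0] > 1 else 0
    for n in range(out_shape[0]):
        for i in range(out_shape[1]):
            for j in range(out_shape[2]):
                acc = 0.0
                for k in range(a_shape[-1]):
                    a_pos = n * a_batch + i * a_strides[1] + k * a_strides[2]
                    b_pos = n * b_batch + k * b_strides[1] + j * b_strides[2]
                    acc += a[a_pos] * b[b_pos]
                out[n * out_strides[0] + i * out_strides[1] + j * out_strides[2]] = acc


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((1, 4, 5))  # batch dimension broadcasts against A
out = np.zeros(2 * 3 * 5)
matmul_reference(out, (2, 3, 5), (15, 5, 1),
                 A.ravel(), (2, 3, 4), (12, 4, 1),
                 B.ravel(), (1, 4, 5), (20, 5, 1))
```

Because positions are computed directly from strides, the inner loop needs no index buffers, and each output element is written exactly once.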
Task 3.3: CUDA Operations¶
Note
This task requires basic familiarity with CUDA. Be sure to read the Guide on GPU Programming and the Numba CUDA guide.
We can do even better than parallelization if we have access to specialized hardware. This task asks you to build a GPU implementation of the backend operations. It will be hard to equal what PyTorch does, but if you are clever you can make these computations really fast (aim for a 2x speedup over Task 3.1).
Todo
Complete the following functions in minitorch/cuda_ops.py, and pass the tests marked as task3_3.
- minitorch.cuda_ops.tensor_map(fn)
CUDA higher-order tensor map function.
fn_map = tensor_map(fn)
fn_map(out, ...)
- Parameters
fn -- function mapping floats to floats to apply.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
out_size (array) -- size for out tensor.
in_storage (array) -- storage for in tensor.
in_shape (array) -- shape for in tensor.
in_strides (array) -- strides for in tensor.
- Returns
Fills in out
- Return type
None
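The out_size parameter exists because a CUDA launch usually starts more threads than there are output elements. The sketch below simulates that launch geometry on the CPU in plain Python; it is not a CUDA kernel. `THREADS_PER_BLOCK` and `launch_map` are hypothetical names, and the sketch assumes contiguous input and output of the same shape.

```python
THREADS_PER_BLOCK = 32  # hypothetical block size


def launch_map(fn, out, in_storage, out_size):
    # Round up so that every element is covered by some thread.
    blocks = (out_size + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    for block_idx in range(blocks):
        for thread_idx in range(THREADS_PER_BLOCK):
            # In a Numba CUDA kernel this would be
            # cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x.
            i = block_idx * THREADS_PER_BLOCK + thread_idx
            if i < out_size:  # guard: threads past the end do nothing
                out[i] = fn(in_storage[i])
    return blocks


data = [float(v) for v in range(70)]
out = [0.0] * 70
blocks_used = launch_map(lambda v: -v, out, data, 70)
```

With 70 elements and 32 threads per block, three blocks are launched and the last 26 threads of the final block fall outside the guard.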
- minitorch.cuda_ops.tensor_zip(fn)
CUDA higher-order tensor zipWith (or map2) function.
fn_zip = tensor_zip(fn)
fn_zip(out, ...)
- Parameters
fn -- function mapping two floats to a float to apply.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
out_size (array) -- size for out tensor.
a_storage (array) -- storage for a tensor.
a_shape (array) -- shape for a tensor.
a_strides (array) -- strides for a tensor.
b_storage (array) -- storage for b tensor.
b_shape (array) -- shape for b tensor.
b_strides (array) -- strides for b tensor.
- Returns
Fills in out
- Return type
None
- minitorch.cuda_ops.tensor_reduce(fn)
CUDA higher-order tensor reduce function.
- Parameters
fn -- reduction function mapping two floats to a float.
out (array) -- storage for out tensor.
out_shape (array) -- shape for out tensor.
out_strides (array) -- strides for out tensor.
out_size (array) -- size for out tensor.
a_storage (array) -- storage for a tensor.
a_shape (array) -- shape for a tensor.
a_strides (array) -- strides for a tensor.
reduce_dim (int) -- dimension to reduce out
- Returns
Fills in out
- Return type
None
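One common way to reduce on a GPU is a tree reduction inside each block. The sketch below simulates a single block on the CPU in plain Python, with a list standing in for shared memory. `BLOCK_DIM` and `block_reduce` are hypothetical names, and the sketch assumes `fn` is associative and commutative, since the order of combination differs from a left-to-right loop.

```python
BLOCK_DIM = 8  # hypothetical threads per block; must be a power of two here


def block_reduce(fn, start, values):
    """Reduce up to BLOCK_DIM values the way one thread block might."""
    assert len(values) <= BLOCK_DIM
    cache = [start] * BLOCK_DIM  # stand-in for shared memory, padded with start
    for t in range(len(values)):
        cache[t] = values[t]
    steps = 0
    stride = BLOCK_DIM // 2
    while stride > 0:
        # On a GPU, the threads would synchronize (cuda.syncthreads()) here.
        for t in range(stride):  # only the first `stride` threads are active
            cache[t] = fn(cache[t], cache[t + stride])
        stride //= 2
        steps += 1
    return cache[0], steps


total, steps = block_reduce(lambda x, y: x + y, 0.0,
                            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
largest, _ = block_reduce(max, float("-inf"), [3.0, 9.0, 2.0])
```

Eight values are combined in three steps rather than seven, because each step halves the number of active threads.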
Task 3.4: CUDA Matrix Multiplication¶
Note
This task requires basic familiarity with CUDA. Be sure to read the Guide on GPU Programming and the Numba CUDA guide.
Finally, we can combine both of these approaches and implement CUDA matmul. This operation is probably the most important in all of deep learning and is central to making models fast. Again, we first strive for accuracy, but the faster you can make it, the better.
Todo
Implement the following function in minitorch/cuda_ops.py with CUDA, and pass the tests marked as task3_4.
- minitorch.cuda_ops.tensor_matrix_multiply(*args)
CUDA Kernel object. When called, the kernel object will specialize itself for the given arguments (if no suitable specialized version already exists) and compute capability, and launch on the device associated with the current context.
Kernel objects are not to be constructed by the user, but instead are created using the numba.cuda.jit() decorator.
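A frequently used strategy for CUDA matmul is tiling: each block loads small tiles of both inputs into shared memory and accumulates partial products from them. The sketch below simulates that loop structure on the CPU with NumPy; it is not a CUDA kernel. `TILE` and `tiled_matmul` are hypothetical names, and the array slices stand in for what `cuda.shared.array` buffers would hold on the device.

```python
import numpy as np

TILE = 4  # hypothetical tile width (one block computes a TILE x TILE output patch)


def tiled_matmul(a, b):
    n, k = a.shape
    _, m = b.shape
    out = np.zeros((n, m))
    for i0 in range(0, n, TILE):          # one "block" per output patch
        for j0 in range(0, m, TILE):
            rows = min(TILE, n - i0)      # edge patches may be smaller
            cols = min(TILE, m - j0)
            acc = np.zeros((rows, cols))
            for k0 in range(0, k, TILE):  # walk tiles along the shared dimension
                a_tile = a[i0:i0 + TILE, k0:k0 + TILE]  # stand-in for shared memory
                b_tile = b[k0:k0 + TILE, j0:j0 + TILE]
                acc += a_tile @ b_tile
            out[i0:i0 + TILE, j0:j0 + TILE] = acc
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 7))
b = rng.standard_normal((7, 6))
result = tiled_matmul(a, b)
```

The benefit on a GPU is that each element of a tile is read from global memory once per block and then reused by every thread in that block. On the device, threads would also need to synchronize after loading a tile and before overwriting it.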
Task 3.4b: Extra Credit¶
Implementing matrix multiplication and reduction efficiently is hugely important for many deep learning tasks. We have seen one method for implementing these functions, but there are many more optimizations that you could apply.
For extra credit, first read these two tutorials:
Then implement a version of CUDA tensor_reduce or tensor_matrix_multiply that takes advantage of other aspects of the GPU.
To get full credit, you should document your code to show us that you understand each line. Prove to us that these lead to speed-ups on large matrix operations by making a graph comparing them to naive operations.
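One possible way to collect the numbers for such a graph is a small timing helper like the sketch below. The names `best_time`, `naive_matmul`, and `square` are hypothetical, and the pure-Python matmul is only a placeholder for whichever naive and optimized operations you compare.

```python
import time


def best_time(fn, *args, repeats=3):
    """Return the fastest of several timed calls, in seconds."""
    fn(*args)  # warm-up call, so one-time costs such as JIT compilation are excluded
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


def naive_matmul(a, b):
    # Placeholder workload on nested lists.
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


def square(n):
    return [[float(i + j) for j in range(n)] for i in range(n)]


timings = [(n, best_time(naive_matmul, square(n), square(n))) for n in (8, 16, 32)]
for n, seconds in timings:
    print(f"n={n:3d}  {seconds:.6f} s")
```

When timing GPU code launched through Numba, kernel launches are asynchronous, so call numba.cuda.synchronize() before reading the clock. Otherwise the measurement may only capture the launch overhead.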
Task 3.5: Training¶
If your code works, you should now be able to move on to the tensor training script in project/run_fast_tensor.py. This code is the same basic training setup as Module 2 - Tensors, but now utilizes your fast tensor code. We have left the matmul layer blank for you to implement with your tensor code.
Todo
Implement the missing functions in project/run_fast_tensor.py. These should do exactly the same thing as the corresponding functions in project/run_tensor.py, but now use the faster backend.
Train a tensor model and add your results for all datasets to the README.
Run a bigger model and record the time per epoch reported by the trainer. Here are the commands:
python run_fast_tensor.py --BACKEND gpu --HIDDEN 100 --DATASET split --RATE 0.05
python run_fast_tensor.py --BACKEND cpu --HIDDEN 100 --DATASET split --RATE 0.05
Train a tensor model and add your results for all three datasets to the README. Also record the time per epoch reported by the trainer. (As a reference, our parallel implementation gave a 10x speedup.)
On a standard Colab GPU setup, aim for your CPU backend to get below 2 seconds per epoch and your GPU backend to get below 1 second per epoch. (With some cleverness you can do much better.)