Modules
Researchers disagree on exactly what the term deep learning means,
but one aspect that everyone agrees on is that deep models are big
and complex. Models can include hundreds of millions of learned
parameters. In order to work with such complex systems, it is
important to have data structures which abstract away the complexity
so that it is easier to access and manipulate specific components
and group together shared regions. These structures are not
rigorous mathematical objects, but a convenient way of managing
complex systems.
On the programming side, Modules have become a popular paradigm for
grouping parameters together so that they are easy to manage, access,
and address. There is nothing specific to machine learning about this
setup (and everything in MiniTorch could be done without modules), but they
make life easier and code more organized.
For now, do not worry about what parameters are for; just note that we
would like to group and name them in a convenient way.
In Torch, modules are a hierarchical data structure. Each module
stores three things: 1) parameters, 2) user data, 3) other
modules. Internally, the user interacts with each of these on
self, but under the hood the module sorts everything into the
three types.
Let's work through an example module.
from minitorch import Module, Parameter

class OtherModule(Module):
    pass

class MyModule(Module):
    def __init__(self):
        # Must initialize the super class!
        super().__init__()

        # Type 1, a parameter.
        self.parameter1 = Parameter(15)

        # Type 2, user data.
        self.data = 25

        # Type 3, another Module.
        self.sub_module = OtherModule()
Internally, Module partitions these elements. Parameters (type 1)
are stored in a parameters dictionary, user data (type 2) is stored on
self, and modules (type 3) are stored in a modules dictionary.
This is a bit tricky: Python is a very dynamic language and allows a
class to override simple operations like attribute assignment, which
is how the module performs this sorting.
Be careful. All subclasses must begin their initialization by calling
super().__init__().
This line allows the module to capture any members of type Module
or Parameter.
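To make this capture concrete, here is a minimal sketch of how it could work, assuming internal _parameters and _modules dictionaries and an override of attribute assignment. This is an illustration of the idea, not MiniTorch's exact implementation.

class Parameter:
    def __init__(self, value):
        self.value = value

class Module:
    def __init__(self):
        # Create the internal storage before any assignment happens.
        # A subclass that skips super().__init__() hits an AttributeError
        # on its first Parameter or Module assignment.
        self._parameters = {}
        self._modules = {}
        self.training = True

    def __setattr__(self, key, value):
        # Sort each assignment into one of the three types.
        if isinstance(value, Parameter):
            self._parameters[key] = value          # type 1: parameter
        elif isinstance(value, Module):
            self._modules[key] = value             # type 3: sub-module
        else:
            object.__setattr__(self, key, value)   # type 2: user data

With this routing in place, ordinary values like self.data = 25 land on self as usual, while parameters and sub-modules are collected for later traversal.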
The benefit of this behavior is that it allows us to easily extract
all parameters and sub-parameters from a module. Specifically, we can
get all of a module's parameters using the named_parameters
function. This function returns the name and value of every
parameter in the module and in all descendant sub-modules.
MyModule().named_parameters()
[('parameter1', 15)]
The names here give the path to each parameter in the tree, using Python-style dot notation. Critically, this function does not just return the current module's parameters; it recursively collects parameters from all the modules below as well.
Here is an example of how you can create a tree of modules and then extract the flattened parameters:
class Module1(Module):
    def __init__(self):
        super().__init__()
        self.p1 = Parameter(5)
        self.a = Module2()
        self.b = Module3()

class Module2(Module):
    def __init__(self):
        super().__init__()
        self.p2 = Parameter(10)

class Module3(Module):
    def __init__(self):
        super().__init__()
        self.c = Module4()

class Module4(Module):
    def __init__(self):
        super().__init__()
        self.p3 = Parameter(15)
Module1().named_parameters()
[('p1', 5), ('a.p2', 10), ('b.c.p3', 15)]
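Under the hood, this flattening is a short recursion over the module tree. The sketch below shows one way it could be written, assuming the _parameters and _modules dictionaries from the earlier sketch; MiniTorch's actual signature and return type may differ.

def named_parameters(module, prefix=""):
    # Collect (path, parameter) pairs from this module and, recursively,
    # from every descendant, extending the dotted path at each level.
    pairs = []
    for name, param in module._parameters.items():
        pairs.append((prefix + name, param))
    for name, child in module._modules.items():
        pairs.extend(named_parameters(child, prefix + name + "."))
    return pairs

Running this on Module1() from above yields entries for p1, a.p2, and b.c.p3, matching the flattened list shown.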
Modules can also be used to find all submodules:
Module1().modules()
[Module2(), Module3( (c): Module4() )]
Additionally, a module has a mode indicating how it is currently
being operated. The mode should propagate to all of its child
modules. For simplicity, we only consider the train and eval modes.
module1 = Module1()
module1.train()
module1.training
True
module1.a.training
True
module1.eval()
module1.training
False
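Propagating the mode is the same kind of recursion over child modules. Here is a minimal, self-contained sketch of that behavior, stripped down to just the pieces needed for mode switching; it is not MiniTorch's actual implementation.

class Module:
    def __init__(self):
        self.training = True
        # In the full version, child modules are registered here by the
        # assignment capture shown earlier.
        self._modules = {}

    def train(self):
        # Switch this module and every descendant to training mode.
        self.training = True
        for child in self._modules.values():
            child.train()

    def eval(self):
        # Switch this module and every descendant to evaluation mode.
        self.training = False
        for child in self._modules.values():
            child.eval()

Because each call recurses through _modules, flipping the mode at the root flips it everywhere below, which is exactly what the module1.a.training check above relies on.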