Functional
Externally, MiniTorch supports the standard Torch API, which lets Python users develop in a standard Python coding style. Internally, the library uses a functional style. This approach is preferred for two reasons: first, it makes the code easy to test, and second, it makes the code easy to optimize. The style takes a bit of extra thought to understand at first.
Primarily, we will use the functional style to define higher-order functions. These are functions that take functions as arguments, return functions as results, or both. Python's typing module provides a special type for function values: Callable.
from typing import Callable, Iterable
Any function can be assigned to a variable of type Callable (although this in itself is not very interesting).
def add(a: float, b: float) -> float:
    return a + b
v: Callable[[float, float], float] = add
def mul(a: float, b: float) -> float:
    return a * b
v: Callable[[float, float], float] = mul
It is more interesting, though, to pass functions as arguments to other functions. For example, we can pass a callable to a function that uses it.
def combine3(
    fn: Callable[[float, float], float], a: float, b: float, c: float
) -> float:
    return fn(fn(a, b), c)
print(combine3(add, 1, 3, 5))
print(combine3(mul, 1, 3, 5))
9
15
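Because combine3 only requires something callable with two arguments, any such function can be passed in, including a lambda or a built-in such as max. The following is a minimal sketch that restates combine3 so it runs on its own; the name sub is introduced here purely for illustration.

```python
from typing import Callable


def combine3(
    fn: Callable[[float, float], float], a: float, b: float, c: float
) -> float:
    # Apply fn to the first two values, then combine that result with the third.
    return fn(fn(a, b), c)


# A lambda can be passed just like a named function (sub is a hypothetical name).
sub: Callable[[float, float], float] = lambda a, b: a - b

print(combine3(sub, 1, 3, 5))  # (1 - 3) - 5
print(combine3(max, 1, 3, 5))  # max(max(1, 3), 5)
```

The caller chooses the operation, while combine3 fixes only the order in which it is applied.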
We can also use this approach to create new functions as return values.
def combine3(
    fn: Callable[[float, float], float],
) -> Callable[[float, float, float], float]:
    def new_fn(a: float, b: float, c: float) -> float:
        return fn(fn(a, b), c)
    return new_fn
add3: Callable[[float, float, float], float] = combine3(add)
mul3: Callable[[float, float, float], float] = combine3(mul)
print(add3(1, 3, 5))
9
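The returned function keeps a reference to the fn it was built with, which is what lets one definition produce many specialized functions. As a sketch of the same idea in another setting, the block below builds a function from two others; compose, double, and inc are hypothetical names chosen for this illustration and are not claimed to be part of MiniTorch.

```python
from typing import Callable


def compose(
    f: Callable[[float], float], g: Callable[[float], float]
) -> Callable[[float], float]:
    # Build a new function that applies g first and then f.
    def composed(x: float) -> float:
        return f(g(x))

    return composed


def double(x: float) -> float:
    return 2 * x


def inc(x: float) -> float:
    return x + 1


double_then_inc = compose(inc, double)
inc_then_double = compose(double, inc)

print(double_then_inc(3))  # inc(double(3))
print(inc_then_double(3))  # double(inc(3))
```

Swapping the argument order yields a different function, so the two calls print 7 and 8 respectively.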
As an extended example, let's create a higher-order version of the filter function. It takes a function and returns a new function; that new function takes a list and returns only the values for which the original function is true.
# Note: this definition shadows Python's built-in filter.
def filter(fn: Callable[[float], bool]) -> Callable[[Iterable[float]], Iterable[float]]:
    def apply(ls: Iterable[float]) -> Iterable[float]:
        ret = []
        for x in ls:
            if fn(x):
                ret.append(x)
        return ret
    return apply
We then use this to create a new function.
def more_than_4(x: float) -> bool:
    return x > 4
filter_for_more_than_4 = filter(more_than_4)
filter_for_more_than_4([1, 10, 3, 5])
[10, 5]
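The same pattern extends to other list operations. As a minimal sketch, here is a map-style counterpart that builds a function applying fn to every element; the names map_each, square, and square_all are assumptions made for this illustration.

```python
from typing import Callable, Iterable, List


def map_each(fn: Callable[[float], float]) -> Callable[[Iterable[float]], List[float]]:
    # Return a function that applies fn to every element of its input.
    def apply(ls: Iterable[float]) -> List[float]:
        ret = []
        for x in ls:
            ret.append(fn(x))
        return ret

    return apply


def square(x: float) -> float:
    return x * x


square_all = map_each(square)
print(square_all([1, 10, 3, 5]))  # [1, 100, 9, 25]
```

As with the filter example, the function to apply is fixed once, and the resulting function can be reused on any number of lists.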
Functional programming can be elegant, but also hard. When in doubt, remember that you can always write things out in a simpler form first and then check that both versions give the same result.
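One way to follow this advice is to write the direct loop next to the higher-order version and assert that they agree. The sketch below does this for the filter example; keep_if is a hypothetical stand-in for the higher-order filter, renamed so that it does not shadow the built-in, and more_than_4_simple is likewise an illustrative name.

```python
from typing import Callable, Iterable, List


def keep_if(fn: Callable[[float], bool]) -> Callable[[Iterable[float]], List[float]]:
    # Higher-order version: returns a function that filters its input by fn.
    def apply(ls: Iterable[float]) -> List[float]:
        return [x for x in ls if fn(x)]

    return apply


def more_than_4_simple(ls: Iterable[float]) -> List[float]:
    # Simpler form: one explicit loop with the condition written inline.
    ret = []
    for x in ls:
        if x > 4:
            ret.append(x)
    return ret


data = [1, 10, 3, 5]
higher_order_result = keep_if(lambda x: x > 4)(data)
simple_result = more_than_4_simple(data)

# Both forms should produce the same list.
assert higher_order_result == simple_result
print(higher_order_result)  # [10, 5]
```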