```
import math
from dataclasses import dataclass
import chalk
from chalk import hcat
from colour import Color
from mt_diagrams.drawing import r
from mt_diagrams.mlprimer_draw import (
    compare,
    draw_graph,
    draw_nn_graph,
    draw_with_hard_points,
    graph,
    s,
    s1,
    s1_hard,
    s2,
    s2_hard,
    show,
    show_loss,
    split_graph,
    with_points,
)
import minitorch
chalk.set_svg_draw_height(300)
chalk.set_svg_height(300)
```

# ML Primer

This guide is a primer on the very basics of machine learning that are
necessary to complete the assignments and motivate the final
system. Machine learning is a rich and well-developed field with many
different models, goals, and learning settings. There are many great
texts that cover all aspects of the area in detail. This guide is not
one of them. Our goal is to explain the minimal details of *one* dataset
with *one* class of model. Specifically, this is an introduction to
supervised binary classification with neural networks. The goal of this
section is to learn how a basic neural network classifies simple points.

## Dataset

Supervised learning problems begin with a labeled `training` dataset.
We assume that we are given a set of labeled points. Each point has
two coordinates $x_1$ and $x_2$, and has a label $y$
corresponding to an O or X. For instance, here is one O-labeled point:

```
d = hcat([split_graph([s1[0]], []), split_graph([s1[1]], [])], 0.3)
r(d, "figs/Graphs/data1.svg")
```

And here is an X-labeled point:

```
d = hcat([split_graph([], [s2[0]]), split_graph([], [s2[1]])], 0.3)
r(d, "figs/Graphs/data2.svg")
```

It is often convenient to plot all of the points together on one set of axes.

```
d = split_graph(s1, s2)
r(d, "figs/Graphs/data3.svg")
```

Here we can see that all the X points are in the top-right and all the O points are in the bottom-left. Not all datasets are this simple. Here is another dataset where the points are split up a bit more.

```
d = split_graph(s1_hard, s2_hard)
r(d, "figs/Graphs/data4.svg")
```
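To make the idea of a labeled dataset concrete, here is a minimal sketch of how such points could be stored in plain Python. The names `LabeledPoint` and `make_simple_dataset` are hypothetical and not part of the guide's code, and the labeling rule (X when $x_1 + x_2 > 1$) is only an assumption chosen to mimic the simple dataset, where X points sit in the top-right.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class LabeledPoint:
    # Hypothetical container for one training example.
    x1: float
    x2: float
    y: str  # "O" or "X"


def make_simple_dataset(n: int, seed: int = 0) -> List[LabeledPoint]:
    """Sample n points in the unit square.

    Assumption for illustration: a point is labeled X when x1 + x2 > 1
    (top-right) and O otherwise (bottom-left).
    """
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        x1, x2 = rng.random(), rng.random()
        label = "X" if x1 + x2 > 1.0 else "O"
        points.append(LabeledPoint(x1, x2, label))
    return points


dataset = make_simple_dataset(10)
for p in dataset[:3]:
    print(f"({p.x1:.2f}, {p.x2:.2f}) -> {p.y}")
```

Each example pairs the two coordinates with its label, which is all that a supervised binary classification problem of this kind requires.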

Later in the class, we will consider datasets of different forms, e.g. a dataset of handwritten digits, where some are 8s and others are 2s.

Here is an example of what this dataset looks like.

## Model

Our ML system needs to specify a model that we want to fit to the data. A model is a function that assigns labels to data points. We can specify a model in Python through its parameters and its function.

```
@dataclass
class Linear:
    # Parameters
    w1: float
    w2: float
    b: float

    def forward(self, x1: float, x2: float) -> float:
        return self.w1 * x1 + self.w2 * x2 + self.b
```

This model can be written mathematically as

$$m(x_1, x_2; w_1, w_2, b) = x_1 \times w_1 + x_2 \times w_2 + b.$$

We call it a linear model because it divides the data points up based on a line. We can visualize this by computing the "decision boundary", i.e. the line separating the area where this function returns a positive value from the area where it returns a negative value.

```
model = Linear(1, 1, -0.9)
```

```
d = draw_graph(model)
r(d, "figs/Graphs/model1.svg")
```
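The decision boundary can also be computed by hand. The following is a minimal sketch, not the drawing code used for the figure: `side` and `boundary_x2` are hypothetical helpers that check the sign of the model output and solve $w_1 x_1 + w_2 x_2 + b = 0$ for $x_2$, assuming $w_2 \neq 0$.

```python
from dataclasses import dataclass


@dataclass
class Linear:
    w1: float
    w2: float
    b: float

    def forward(self, x1: float, x2: float) -> float:
        return self.w1 * x1 + self.w2 * x2 + self.b


def side(model: Linear, x1: float, x2: float) -> str:
    """Hypothetical helper: report the sign of the model output at a point."""
    return "positive" if model.forward(x1, x2) > 0 else "negative"


def boundary_x2(model: Linear, x1: float) -> float:
    """Solve w1*x1 + w2*x2 + b = 0 for x2 (assumes w2 != 0)."""
    return -(model.w1 * x1 + model.b) / model.w2


model = Linear(1, 1, -0.9)
print(side(model, 0.8, 0.8))    # a point above the line x1 + x2 = 0.9
print(side(model, 0.1, 0.1))    # a point below the line
print(boundary_x2(model, 0.4))  # x2 coordinate of the boundary at x1 = 0.4
```

Points on the boundary make the model output zero, and the output changes sign as a point crosses from one side to the other.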

We can overlay the simple dataset described earlier over this model. This tells us roughly how well the model fits this dataset.

```
d = show(model)
r(d, "figs/Graphs/incorrect.svg")
```
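One rough way to quantify how well a model fits is to count the points it labels correctly. The sketch below makes two assumptions that are not stated in the guide: a positive output is read as X and a non-positive output as O, and the five points are made-up illustrations rather than the `s1`/`s2` data. `fraction_correct` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Linear:
    w1: float
    w2: float
    b: float

    def forward(self, x1: float, x2: float) -> float:
        return self.w1 * x1 + self.w2 * x2 + self.b


def fraction_correct(model: Linear, points: List[Tuple[float, float, str]]) -> float:
    """Fraction of points whose predicted label matches the true label.

    Assumption: positive output means X, otherwise O.
    """
    correct = 0
    for x1, x2, label in points:
        predicted = "X" if model.forward(x1, x2) > 0 else "O"
        if predicted == label:
            correct += 1
    return correct / len(points)


# Made-up points: O in the bottom-left, X in the top-right.
points = [
    (0.2, 0.3, "O"),
    (0.1, 0.6, "O"),
    (0.45, 0.5, "O"),
    (0.9, 0.8, "X"),
    (0.7, 0.6, "X"),
]

print(fraction_correct(Linear(1, 1, -0.9), points))  # one O point lands on the wrong side
print(fraction_correct(Linear(1, 1, -1.0), points))  # every point lands on the right side
```

With these illustrative points, the model with `b = -0.9` puts the point `(0.45, 0.5)` on the wrong side of the line, while the model with `b = -1.0` separates all five.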

Models can take many different forms. Here is another model, which has a compound form. We will discuss these types of models more below. It splits its decision into three regions (Model B).

```
@dataclass
class Split:
    m1: Linear
    m2: Linear

    def forward(self, x1, x2):
        return self.m1.forward(x1, x2) * self.m2.forward(x1, x2)
```

```
model_b = Split(Linear(1, 1, -1.5), Linear(1, 1, -0.5))
```

```
d = draw_graph(model_b)
r(d, "figs/Graphs/model2.svg")
```
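To see where the three regions of Model B come from, it helps to evaluate the product at one point in each region. This is a small sketch with arbitrarily chosen probe points: the product is positive when both linear models agree in sign (below both lines or above both lines) and negative in the band between them.

```python
from dataclasses import dataclass


@dataclass
class Linear:
    w1: float
    w2: float
    b: float

    def forward(self, x1: float, x2: float) -> float:
        return self.w1 * x1 + self.w2 * x2 + self.b


@dataclass
class Split:
    m1: Linear
    m2: Linear

    def forward(self, x1: float, x2: float) -> float:
        return self.m1.forward(x1, x2) * self.m2.forward(x1, x2)


model_b = Split(Linear(1, 1, -1.5), Linear(1, 1, -0.5))

# Probe points (chosen for illustration): below both lines,
# between the two lines, and above both lines.
for x1, x2 in [(0.1, 0.1), (0.5, 0.5), (0.9, 0.9)]:
    value = model_b.forward(x1, x2)
    print((x1, x2), "positive" if value > 0 else "negative")
```

The two outer regions share a sign and the middle band has the opposite sign, which is a split that a single linear model cannot produce.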

Models may also have strange shapes and even disconnected regions. Any split of the plane into blue and red regions will do, for instance (Model C):

```
@dataclass
class Part:
    def forward(self, x1, x2):
        return 1 if (0.0 <= x1 < 0.5 and 0.0 <= x2 < 0.6) else 0
```

```
d = draw_graph(Part())
r(d, "figs/Graphs/model3.svg")
```

## Parameters

Once we have decided on the class of model that we are using, we need a way to move between models in that class. Ideally, we would have internal knobs that alter the properties of the model.

```
show(Linear(1, 1, -0.5))
```

```
show(Linear(1, 1, -1))
```

In the case of linear models, there are two knobs:

a. rotating the separator (changing the weights `w1` and `w2`)

```
model1 = Linear(1, 1, -1.0)
model2 = Linear(0.5, 1.5, -1.0)
d = compare(model1, model2)
r(d, "figs/Graphs/weight.svg")
```

b. changing the separator cutoff (changing the bias `b`)

```
model1 = Linear(1, 1, -1.0)
model2 = Linear(1, 1, -1.5)
d = compare(model1, model2)
r(d, "figs/Graphs/bias.svg")
```
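The effect of the two knobs can be checked numerically by rewriting the boundary $w_1 x_1 + w_2 x_2 + b = 0$ as $x_2 = \text{slope} \cdot x_1 + \text{intercept}$. The sketch below uses a hypothetical helper, `slope_and_intercept`, and assumes $w_2 \neq 0$. It is an illustration of the idea, not the code behind `compare`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Linear:
    w1: float
    w2: float
    b: float

    def forward(self, x1: float, x2: float) -> float:
        return self.w1 * x1 + self.w2 * x2 + self.b


def slope_and_intercept(model: Linear) -> Tuple[float, float]:
    """Express the boundary as x2 = slope * x1 + intercept (assumes w2 != 0)."""
    return -model.w1 / model.w2, -model.b / model.w2


base = Linear(1, 1, -1.0)
rotated = Linear(0.5, 1.5, -1.0)  # same bias, different weights
shifted = Linear(1, 1, -1.5)      # same weights, different bias

for name, m in [("base", base), ("rotated", rotated), ("shifted", shifted)]:
    slope, intercept = slope_and_intercept(m)
    print(f"{name}: slope={slope:.3f}, intercept={intercept:.3f}")
```

Changing only the bias keeps the slope fixed and moves the line, while changing the weights alters the slope. In this example the weight change also moves the intercept, because the intercept depends on both `b` and `w2`.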