{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"hide_input": false
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"from fastai.gen_doc.nbdoc import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "raw",
|
|
"metadata": {},
|
|
"source": [
|
|
"[[chapter_foundations]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# A neural net from the foundations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This chapter begins a journey where we will go from the very basics and dig inside what was hidden in the models we used in the previous chapters. We will be covering many of the same things we've seen before, but this time around we'll be looking much more closely at the implementation details, and much less closely at the practical issues of how and why things are as they are.\n",
|
|
"\n",
|
|
"We will build everything from scratch, only using basic indexing into a tensor. We write a neural net from the foundations, then we will implement our own backpropagation from scratch, so we'll know what is happening in PyTorch when we do `loss.backward()`. We'll also see how to extend PyTorch with custom *autograd* functions that allow you to specify your own forward and backward computations."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## A neural net from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's start by refreshing our understanding of how matrix multiplication is used in a basic neural network. Since we're building everything up from scratch, we'll use nothing but plain Python initially (except for indexing into PyTorch tensors), and then replace the plain Python with PyTorch functionality once we've seen how to create it."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Modeling a neuron"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"A neuron receives a given number of inputs and has an internal weight for each of them. It then sums those weighted inputs to produce an output and add an inner bias. In math, this can be written:\n",
|
|
"\n",
|
|
"$$ out = \\sum_{i=1}^{n} x_{i} w_{i} + b$$\n",
|
|
"\n",
|
|
"if we name our inputs $(x_{1},\\dots,x_{n})$, our weights $(w_{1},\\dots,w_{n})$ and our bias $b$. In code this translates into:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"output = sum([x*w for x,w in zip(inputs,weights)]) + bias\n",
|
|
"```\n",
|
|
"\n",
|
|
"This output is then fed into a non-linear function before being sent to another neuron called an *activation function*, and the most common function used in Deep Learning for this the *Rectified Linear Unit* or *ReLU*, which, as we've seen, is a fancy way of saying\n",
|
|
"```python\n",
|
|
"def relu(x): return x if x >= 0 else 0\n",
|
|
"```"
|
|
]
|
|
},
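{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this concrete, here is a minimal sketch of a single neuron in plain Python. The input, weight, and bias values are made up purely for illustration:\n",
"\n",
"```python\n",
"inputs  = [0.5, -1.2, 2.0]   # three inputs to the neuron\n",
"weights = [0.8,  0.3, -0.5]  # one weight per input\n",
"bias = 0.1\n",
"\n",
"def relu(x): return x if x >= 0 else 0\n",
"\n",
"output = sum([x*w for x,w in zip(inputs,weights)]) + bias  # -0.86\n",
"activation = relu(output)                                  # 0\n",
"```"
]
},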
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"A Deep Learning model is then built by stacking a lot of those neurons in successive layers. We create a first layer with a certain number of neurons (usually called *hidden size*) and link all the inputs to each of those neurons. Such a layer is often called *fully connected layer* or a *dense layer* (for densely connected) or a *linear layer*. \n",
|
|
"\n",
|
|
"If you have done a little bit of linear algebra, you may remember than when you have a lot of:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"sum([x*w for x,w in zip(input,weight)])\n",
|
|
"```\n",
|
|
"\n",
|
|
"...for each `input` in our batch and the `weight` of each neuron, it's the equivalent of one *matrix multiplication*. More precisely, if our inputs are in a matrix `x` which is `batch_size` by `n_inputs`, and if we have grouped the weights of our neurons in a matrix `w` which is `n_neurons` by `n_inputs` (each neuron must have the same number of weights as they have inputs) and all the biases in a vector `b` of size `n_neurons`, then the output of this fully connected layer is\n",
|
|
"\n",
|
|
"```python\n",
|
|
"y = x @ w.t() + b\n",
|
|
"```\n",
|
|
"\n",
|
|
"where `@` represents the matrix product and `w.t()` is the transpose matrix of `w`. The output `y` is then of size `batch_size` by `n_neurons` and in position `(i,j)`, we have (for the mathy folks out there):\n",
|
|
"\n",
|
|
"$$y_{i,j} = \\sum_{k=1}^{n} x_{i,k} w_{k,j} + b_{j}$$\n",
|
|
"\n",
|
|
"or in code:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"y[i,j] = sum([a * b for a,b in zip(x[i,:],w[j,:])]) + b[j]\n",
|
|
"```\n",
|
|
"\n",
|
|
"The transpose is necessary because in the mathematical definition of the matrix product `m @ n`, the coefficient `(i,j)` is:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"sum([a * b for a,b in zip(m[i,:],n[:,j])])\n",
|
|
"```\n",
|
|
"\n",
|
|
"So the very basic operation we need is a matrix multiplication, as it's what is hidden in the core of a neural net."
|
|
]
|
|
},
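{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before implementing our own matrix multiplication, we can sanity-check the formula above on a tiny example. This is just an illustrative sketch (the shapes and variable names are arbitrary), and it uses PyTorch's built-in `@` only to verify the elementwise definition:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"batch_size, n_inputs, n_neurons = 2, 4, 3\n",
"x = torch.randn(batch_size, n_inputs)\n",
"w = torch.randn(n_neurons, n_inputs)  # one row of weights per neuron\n",
"b = torch.randn(n_neurons)\n",
"\n",
"y = x @ w.t() + b\n",
"\n",
"# the same thing, element by element, following the formula above\n",
"y_manual = torch.zeros(batch_size, n_neurons)\n",
"for i in range(batch_size):\n",
"    for j in range(n_neurons):\n",
"        y_manual[i,j] = sum([xik*wjk for xik,wjk in zip(x[i,:],w[j,:])]) + b[j]\n",
"\n",
"assert torch.allclose(y, y_manual, atol=1e-6)\n",
"```"
]
},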
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Matrix multiplication from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's write a function that computes the matrix product of two tensors, before we allow ourselves to use the PyTorch version of it. We will only use the indexing in PyTorch tensors."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import tensor"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We'll need three nested for loops: one for the row indices, one for the column indices and one for the inner sum. `ac`, `ar` stand for number of columns of `a`, number of rows of `a` respectively (same convention for `b`) and we make sure the matrix product is possible by checking that `a` has as many columns as `b` has rows."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def matmul(a,b):\n",
|
|
" ar,ac = a.shape # n_rows * n_cols\n",
|
|
" br,bc = b.shape\n",
|
|
" assert ac==br\n",
|
|
" c = torch.zeros(ar, bc)\n",
|
|
" for i in range(ar):\n",
|
|
" for j in range(bc):\n",
|
|
" for k in range(ac): c[i,j] += a[i,k] * b[k,j]\n",
|
|
" return c"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To test this out, we'll pretend (using random matrices) that we're working with a small batch of 5 MNIST images, flattened into `28*28` vectors, and a linear model to turn them into 10 activations:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"m1 = torch.randn(5,28*28)\n",
|
|
"m2 = torch.randn(784,10)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's time our function, using the Jupyter \"magic\" `%time`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"CPU times: user 1.15 s, sys: 4.09 ms, total: 1.15 s\n",
|
|
"Wall time: 1.15 s\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%time t1=matmul(m1, m2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"...and how does that compare to PyTorch's builtin?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"14 µs ± 8.95 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 20 t2=m1@m2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As we can see, in Python three nested loops is a very bad idea! Python is a slow language, and this isn't going to be very efficient. We see here that PyTorch is around 100,000 times faster than Python--and that's before we even start using the GPU!\n",
|
|
"\n",
|
|
"Where does this difference come from? That's because PyTorch didn't write its matrix multiplication in Python but in C++ to make it fast. In general, whenever we do some computations on tensors, we will need to *vectorize* them so that we can take advantage of the speed of PyTorch, usually by using two techniques: elementwise arithmetic and broadcasting. \n",
|
|
"\n",
|
|
"We will show how to do this on our example of matrix multiplication."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"\n",
|
|
"### Elementwise arithmetic"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"All the basic operators (+,-,\\*,/,>,<,==) can be applied element-wise. That means if we write `a+b` for two tensors `a` and `b` that have the same shape, we will get a tensor with the sums of one element of `a` with one element of `b."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([12., 14., 3.])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"a = tensor([10., 6, -4])\n",
|
|
"b = tensor([2., 8, 7])\n",
|
|
"a + b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The booleans operators will return an array of booleans:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([False, True, True])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"a < b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"If we want to know if every element of `a` is less than the corresponding element in `b`, or if two tensors are equals, we need to combine those elementwise operations with `torch.all`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(False), tensor(False))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"(a < b).all(), (a==b).all()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that reduction operations (that returns only one element) like `all()`, `sum()` or `mean()` return tensors with only one element calles rank-0 tensors. If you want to convert it to a plain Python boolean or number, you need to call `.item()`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"9.666666984558105"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"(a + b).mean().item()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The elementwise operations work on tensors of any ranks, as long as they have the same shape."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[ 1., 4., 9.],\n",
|
|
" [16., 25., 36.],\n",
|
|
" [49., 64., 81.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
|
|
"m*m"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"However you can't have element-wise operations of tensors that don't have the same shape (unless they are broadcastable, see below)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "RuntimeError",
|
|
"evalue": "The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m<ipython-input-12-add73c4f74e0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"n = tensor([[1., 2, 3], [4,5,6]])\n",
|
|
"m*n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With element-wise arithmetic, we can remove one of our three nested loops: we can multiply the tensors that correspond to the `i`-th row of `a` and the `j`-th column of `b` before summing all the elements, which will speed up things because the inner loop will now be executed by PyTorch at C speed. \n",
|
|
"\n",
|
|
"To access one row/column, we can simply write `a[i,:]` or `b[:,j]`. The column means take everything in that dimension. We could restrict and only take a slice on this particular dimension by passing a range like `1:5` instead of just `:`. In that case, we would take the elements in column 1 to 4 (the last part is always excluded). \n",
|
|
"\n",
|
|
"One simplification is that we can always omit trailing columns, so `a[i,:]` can be abbreviated to `a[i]`. With all of that, we can write a new version of our matrix multiplication:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def matmul(a,b):\n",
|
|
" ar,ac = a.shape\n",
|
|
" br,bc = b.shape\n",
|
|
" assert ac==br\n",
|
|
" c = torch.zeros(ar, bc)\n",
|
|
" for i in range(ar):\n",
|
|
" for j in range(bc): c[i,j] = (a[i] * b[:,j]).sum()\n",
|
|
" return c"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.7 ms ± 88.1 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 20 t3 = matmul(m1,m2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We are already ~700 times faster, just by removing that inner for loop! And that is just the beginning. By combining this with broadcasting, we can remove another loop and get an even more important speed-up."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Broadcasting"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As we discussed in <<chapter_mnist_basics>>, broadcasting is a term introduced by the numpy library that describes how tensor of different ranks are treated during arithmetic operations. For instance, it's obvious there is no way to add a 3 by 3 matrix with a 4 by 5 matrix, but what if we want to add one scalar (which can be represented as a 1 by 1 tensor) with a matrix? Or a vector of size 3 with a 3 by 4 matrix? In both cases, we can find a way to make sense of what the operation could be.\n",
|
|
"\n",
|
|
"Broadcasting gives specific rules to codify when shapes are compatible when trying to do an element-wise operation, and how the tensor of the smaller shape is expanded to match the tensor of the bigger shape. It's essential to master those rules if you want to be able to write code that executes quickly. In this section, we'll expand our previous treatment of broadcasting to understand these rules."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Broadcasting with a scalar"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This is the easiest broadcating: when we have a tensor `a` and a scalar, we just imagine a tensor of the same shape as `a` filled with that scalar and perform the operation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([ True, True, False])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"a = tensor([10., 6, -4])\n",
|
|
"a > 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"How are we able to do this comparison? 0 is being *broadcast* to have the same dimensions as `a`. Note that this is done without creating a tensor full of zeros in memory (that would be very inefficient). \n",
|
|
"\n",
|
|
"This is very useful if you want to normalize your dataset by subtracting the mean (a scalar) from the entire data set (a matrix) and dividing by the standard deviation (another scalar):"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[-1.4652, -1.0989, -0.7326],\n",
|
|
" [-0.3663, 0.0000, 0.3663],\n",
|
|
" [ 0.7326, 1.0989, 1.4652]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
|
|
"(m - 5) / 2.73"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Broadcasting a vector to a matrix"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can also broadcast a vector to a matrix:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([3, 3]), torch.Size([3]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c = tensor([10.,20,30])\n",
|
|
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
|
|
"m.shape,c.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[11., 22., 33.],\n",
|
|
" [14., 25., 36.],\n",
|
|
" [17., 28., 39.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"m + c"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here the elements of `c` are expanded to make three rows that match, and this way the operation is possible. Again, behind the scenes PyTorch doesn't create three copies of `c` in memory. This is done by the `expand_as` method behind the scenes:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[10., 20., 30.],\n",
|
|
" [10., 20., 30.],\n",
|
|
" [10., 20., 30.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c.expand_as(m)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"If we look at the corresponding tensor, we can ask for its `storage` property (which shows the actual contents of the memory used for the tensor) to check there is no useless data stored:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
" 10.0\n",
|
|
" 20.0\n",
|
|
" 30.0\n",
|
|
"[torch.FloatStorage of size 3]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t = c.expand_as(m)\n",
|
|
"t.storage()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Even if it has officially 9 elements, the memory used is only 3 scalars. It's possible with a clever trick by giving a *stride* of 0 on that dimension (which means that when it looks for the next row by adding the stride, it doesn't move)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"((0, 1), torch.Size([3, 3]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t.stride(), t.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Since `m` is of size 3 by 3, there were two ways to do broadcasting. The fact it was done on the last dimension is a convention that comes from the rules of broadcasting and has nothing to do with the way we ordered our tensors:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[11., 22., 33.],\n",
|
|
" [14., 25., 36.],\n",
|
|
" [17., 28., 39.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c + m"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We get the same result. In fact it's only possible to broadcast a vector of size `n` with a matrix of size `m` by `n`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[11., 22., 33.],\n",
|
|
" [14., 25., 36.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c = tensor([10.,20,30])\n",
|
|
"m = tensor([[1., 2, 3], [4,5,6]])\n",
|
|
"c+m"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This won't work:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "RuntimeError",
|
|
"evalue": "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m<ipython-input-25-64bbbad4d99c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m10.\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mc\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"c = tensor([10.,20])\n",
|
|
"m = tensor([[1., 2, 3], [4,5,6]])\n",
|
|
"c+m"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"If we want to broadcast in the other dimension, we have to change the shape of our vector to make it a 3 by 1 matrix. This is done with the `unsqueeze` method in PyTorch."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([3, 3]), torch.Size([3, 1]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c = tensor([10.,20,30])\n",
|
|
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
|
|
"c = c.unsqueeze(1)\n",
|
|
"m.shape,c.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And this time, `c` is expanded on the columns side."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[11., 12., 13.],\n",
|
|
" [24., 25., 26.],\n",
|
|
" [37., 38., 39.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c+m"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Like before the corresponding storage contains only three scalars."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
" 10.0\n",
|
|
" 20.0\n",
|
|
" 30.0\n",
|
|
"[torch.FloatStorage of size 3]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t = c.expand_as(m)\n",
|
|
"t.storage()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And the expanded tensor has the right shape by giving it a stride of 0 on the column dimension."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"((1, 0), torch.Size([3, 3]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t.stride(), t.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The way broadcasting works is that if we need to add dimensions, the default is to add them at the beginning. When we were broadcasting before, it was doing `c.unsqueeze(0)` behind the scenes."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c = tensor([10.,20,30])\n",
|
|
"c.shape, c.unsqueeze(0).shape,c.unsqueeze(1).shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The `unsqueeze` command can be replaced by `None` indexing."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c.shape, c[None,:].shape,c[:,None].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"You can always omit traiiling columns, and `...` means all preceding dimensions:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([1, 3]), torch.Size([3, 1]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"c[None].shape,c[...,None].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"With this, we can remove another for loop in our matrix multiplication function: instead of multiplying `a[i]` with `b[:,j]`, we can multiply `a[i]` with the whole matrix `b` using broadcasting, then sum all the results."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def matmul(a,b):\n",
|
|
" ar,ac = a.shape\n",
|
|
" br,bc = b.shape\n",
|
|
" assert ac==br\n",
|
|
" c = torch.zeros(ar, bc)\n",
|
|
" for i in range(ar):\n",
|
|
"# c[i,j] = (a[i,:] * b[:,j]).sum() # previous\n",
|
|
" c[i] = (a[i ].unsqueeze(-1) * b).sum(dim=0)\n",
|
|
" return c"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"357 µs ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 20 t4 = matmul(m1,m2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We're now 3,700 times faster than our first implementation!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Broadcasting Rules"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"When operating on two tensors, PyTorch compares their shapes element-wise. It starts with the *trailing dimensions*, and works its way backward, adding 1 when it meets empty dimensions. Two dimensions are *compatible* when\n",
|
|
"\n",
|
|
"- they are equal, or\n",
|
|
"- one of them is 1, in which case that dimension is broadcasted to make it the same size\n",
|
|
"\n",
|
|
"Arrays do not need to have the same number of dimensions. For example, if you have a `256*256*3` array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules, shows that they are compatible:\n",
|
|
"\n",
|
|
"```\n",
|
|
"Image (3d tensor): 256 x 256 x 3\n",
|
|
"Scale (1d tensor): (1) (1) 3\n",
|
|
"Result (3d tensor): 256 x 256 x 3\n",
|
|
"```\n",
|
|
" \n",
|
|
"However, a 2d tensor of size 256 x 256 isn't compatible with our image.\n",
|
|
"\n",
|
|
"```\n",
|
|
"Image (3d tensor): 256 x 256 x 3\n",
|
|
"Scale (1d tensor): (1) 256 x 256\n",
|
|
"Error\n",
|
|
"```\n",
|
|
"\n",
|
|
"In the first examples we had with a `3x3` matrix and vector of size `3`, broadcast is done on the rows:\n",
|
|
"\n",
|
|
"```\n",
|
|
"Matrix (2d tensor): 3 x 3\n",
|
|
"Vector (1d tensor): (1) 3\n",
|
|
"Result (2d tensor): 3 x 3\n",
|
|
"```\n",
|
|
"\n",
|
|
"As a little exercice around those rules, try to determine what dimensions to add (and where) when you need to normalize a batch of images of size `64 x 3 x 256 x 256` with vectors of three elements (one for the mean and one for the standard deviation)."
|
|
]
|
|
},
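{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is one way to check your answer to that exercise. This is only a sketch: the per-channel mean and standard deviation values are made up, and in practice you would compute them from your data.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"imgs = torch.randn(64, 3, 256, 256)      # a fake batch of images\n",
"mean = torch.tensor([0.47, 0.48, 0.45])  # made-up per-channel statistics\n",
"std  = torch.tensor([0.29, 0.28, 0.30])\n",
"\n",
"# Add two trailing unit dimensions so the stats line up with the channel axis:\n",
"#   imgs: 64 x 3 x 256 x 256\n",
"#   mean:       3 x   1 x   1   (broadcast over batch, height and width)\n",
"normed = (imgs - mean[:,None,None]) / std[:,None,None]\n",
"normed.shape   # torch.Size([64, 3, 256, 256])\n",
"```"
]
},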
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Einstein summation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Before using the PyTorch operation @ or `torch.matmul`, there is a last way we can implement this matrix multiplication: einstein summation (einsum). This is a compact representation for combining products and sums in a general way. We write an equation like this\n",
|
|
"\n",
|
|
"```\n",
|
|
"ik,kj -> ij\n",
|
|
"```\n",
|
|
"\n",
|
|
"The left hand side represents the operands dimensions, separated by commas. Here we have two tensors taht each have two dimensions (i,k and k,j). The right hand side represents the result dimensions, so here we have a tensor with two dimensions i,j. \n",
|
|
"\n",
|
|
"There are essentially three rules of Einstein summation notation, namely:\n",
|
|
"\n",
|
|
"1. Repeated indices are implicitly summed over.\n",
|
|
"1. Each index can appear at most twice in any term.\n",
|
|
"1. Each term must contain identical non-repeated indices.\n",
|
|
"\n",
|
|
"So in the example above, since `k` is repeated, we sum over that index. In the end the above formula represents the matrix obtained when we put in (i,j) the sum of all the coefficients (i,k) in the first tensor multiplied by the coefficients (k,j) in the second tensor... which is the matrix product!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)"
|
|
]
|
|
},
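{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch reusing the `m1` and `m2` matrices defined earlier), this einsum version agrees with PyTorch's built-in matrix product:\n",
"\n",
"```python\n",
"t5 = matmul(m1, m2)  # the einsum version above\n",
"assert torch.allclose(t5, m1 @ m2, atol=1e-4)\n",
"```"
]
},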
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Einsteim summation is a very practical way of expressing operations involving indexing and sum of products. Note that you can have only one member in the left hand side. For instance\n",
|
|
"\n",
|
|
"```python\n",
|
|
"torch.einsum('ij->ji', a)\n",
|
|
"```\n",
|
|
"\n",
|
|
"returns the transpose of the matrix `a`. You can also have three or more members:\n",
|
|
"\n",
|
|
"```python\n",
|
|
"torch.einsum('bi,ij,bj->b', a, b, c)\n",
|
|
"```\n",
|
|
"\n",
|
|
"will return a vector of size `b` where the `k`-th coordinate is the sum of the `a[k,i] b[i,j] c[k,j]`. This notation is getting really convenient when you have more dimensions because of batches, for instance if you have two batches of matrices and want compute the matrix product per batch, you would go: \n",
|
|
"\n",
|
|
"```python\n",
|
|
"torch.einsum('bik,bkj->bij', a, b)\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"68.7 µs ± 4.06 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 20 t5 = matmul(m1,m2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As we see, not only is it practical, but it's *very* fast. `einsum` is often the fastest way to do custom operations in PyTorch, without diving into C++ and CUDA. (But it's generally not as fast as carefully optimized CUDA code, as you see in the matrix multiplication example)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The forward and backward passes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now that we have defined `matmul` from scratch, we are ready to define our first neural net. As we saw in <<chapter_mnist_basics>>, to train it, we will need to compute all the gradients of a given a loss with respect to its parameters, which is known as the *backward pass*. The *forward pass* is computing the output of the model on a given input, which is just based on the matrix products we saw. As we define our first neural net, we will also delve in the problem of properly initializing the weights, which is crucial to make training start properly."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Defining and initializing a layer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We will take the example of a two-layer neural net first. As we saw, one layer can be expressed as `y = x @ w + b` with `x` out inputs, `y` our outputs, `w` the weights of the layer (which is of size numbe of inputs by neuron of neurons if we don't transpose like before) and `b` is the bias vector. "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def lin(x, w, b): return x @ w + b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can stack two layers on top of the other, but since mathematically, the composition of two linear operations is another linear operation, this only makes sense if we put something non-linear in the middle called an activation function. The activation function most popularly used is a ReLU, which, as we saw, is just the maximum of `x` and `0`. \n",
|
|
"\n",
|
|
"We won't actually train our model in this chapter so we use random tensors for our inputs and targets. Let's say our inputs are 200 vectors of size 100, which we group in one batch, and our targets are 200 random floats."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"y = torch.randn(200)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For our two-layers model we will need two weight matrices and two bias vectors. Let's say we have a hidden size of 50 and the output size is 1 (for one of our input, the corresponding output is one float in this toy example). We initialize the weights randomly and the bias at zero. "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"w1 = torch.randn(100,50)\n",
|
|
"b1 = torch.zeros(50)\n",
|
|
"w2 = torch.randn(50,1)\n",
|
|
"b2 = torch.zeros(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then the result of our first layer is simply:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([200, 50])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"l1 = lin(x, w1, b1)\n",
|
|
"l1.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that this formula works with our batch of inputs, and returns a batch of hidden state: `l1` is a matrix of 200 (our batch size) by 50 (our hidden size).\n",
|
|
"\n",
|
|
"There is a problem with the way our model was initiliazed however. To understand it, we need to look at the mean and standard deviation (std) of `l1`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(0.0019), tensor(10.1058))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"l1.mean(), l1.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The mean is close to zero, which is understandable since both our input and weight matrix have a mean close to zero. However the standard deviation, which represents how far away our activation go from the mean, went from 1 to 10. This is a really big problem because that's with just one layer. Modern neural nets can have hundred of layers, so if each of them multiply the scale of our activations by 10, by the end of the last layer we won't have numbers representable by a computer.\n",
|
|
"\n",
|
|
"Indeed, if we make just 50 multiplications between x and random matrices of size 100 x 100:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[nan, nan, nan, nan, nan],\n",
|
|
" [nan, nan, nan, nan, nan],\n",
|
|
" [nan, nan, nan, nan, nan],\n",
|
|
" [nan, nan, nan, nan, nan],\n",
|
|
" [nan, nan, nan, nan, nan]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"for i in range(50): x = x @ torch.randn(100,100)\n",
|
|
"x[0:5,0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The result is nans everywhere. So maybe the scale of our matrix was too big, and we need to have smaller weights. But if we use too small weights we will have the opposite problem: the scale of our activations will get from 1 to 0.1 and after 100 layers, we'll be left with zeros everywhere."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0., 0., 0., 0., 0.],\n",
|
|
" [0., 0., 0., 0., 0.],\n",
|
|
" [0., 0., 0., 0., 0.],\n",
|
|
" [0., 0., 0., 0., 0.],\n",
|
|
" [0., 0., 0., 0., 0.]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"for i in range(50): x = x @ (torch.randn(100,100) * 0.01)\n",
|
|
"x[0:5,0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"So we have to scale our weights matrices exactly right so that the standard deviation of our activations stays at 1. We can compute the exact value mathematically, and this has been done by Xavier Glorot and Yoshua Bengio in [Understanding the difficulty of training deep feedforward neural networks](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf). The right scale for a given layer is $1/\\sqrt{n_{in}}$, where $n_{in}$ represents the number of inputs.\n",
|
|
"\n",
|
|
"In our case, if we have 100 inputs, we should scale our weight matrices by 0.1:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[ 0.7554, 0.6167, -0.1757, -1.5662, 0.5644],\n",
|
|
" [-0.1987, 0.6292, 0.3283, -1.1538, 0.5416],\n",
|
|
" [ 0.6106, 0.2556, -0.0618, -0.9463, 0.4445],\n",
|
|
" [ 0.4484, 0.7144, 0.1164, -0.8626, 0.4413],\n",
|
|
" [ 0.3463, 0.5930, 0.3375, -0.9486, 0.5643]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"for i in range(50): x = x @ (torch.randn(100,100) * 0.1)\n",
|
|
"x[0:5,0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally some numbers that are neither zeros nor infinity! Notice how stable the scale of our activations is, even after those 50 fake layers:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.7042)"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"You can play a little bit with the values of the scale and notice that even a slight variation from 0.1 will get you either to very small or very alrge numbers, so initializing the weights properly is extremely important. Let's go back to our neural net. Since we messed a bit with our inputs we need to redefine them:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"y = torch.randn(200)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"and for our weights, we use the right scale, which is known as *Xavier initialization* (or *Glorot initialization*)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from math import sqrt\n",
|
|
"w1 = torch.randn(100,50) / sqrt(100)\n",
|
|
"b1 = torch.zeros(50)\n",
|
|
"w2 = torch.randn(50,1) / sqrt(50)\n",
|
|
"b2 = torch.zeros(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now if compute the result of the first layer, we can check the mean and standard deviation are under control:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(-0.0050), tensor(1.0000))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"l1 = lin(x, w1, b1)\n",
|
|
"l1.mean(),l1.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Very good, now we need to go through a relu, so let's define one. A relu removes the negatives and replace them by 0, which is another way of saying it clamps our tensor at 0."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def relu(x): return x.clamp_min(0.)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now let's make our activations go through a relu"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(0.3961), tensor(0.5783))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"l2 = relu(l1)\n",
|
|
"l2.mean(),l2.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And we're back to square one: the mean of our activation has gone to 0.4 (which is understandable since we removed the negatives) and the std went down to 0.5-0.6. So like before, after a few layers we will probably get to 0:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0.0000e+00, 1.9689e-08, 4.2820e-08, 0.0000e+00, 0.0000e+00],\n",
|
|
" [0.0000e+00, 1.6701e-08, 4.3501e-08, 0.0000e+00, 0.0000e+00],\n",
|
|
" [0.0000e+00, 1.0976e-08, 3.0411e-08, 0.0000e+00, 0.0000e+00],\n",
|
|
" [0.0000e+00, 1.8457e-08, 4.9469e-08, 0.0000e+00, 0.0000e+00],\n",
|
|
" [0.0000e+00, 1.9949e-08, 4.1643e-08, 0.0000e+00, 0.0000e+00]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"for i in range(50): x = relu(x @ (torch.randn(100,100) * 0.1))\n",
|
|
"x[0:5,0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"So our initialization wasn't right. This is because at the same the previous article was written, the popular activation in a neural net was the hyperbolic tangent (which is the one they use) and that initialization doesn't account for our ReLU. Fortunately someone else has done the math for us and computed the right scale we should use. Kaiming He et al. in [Delving Deep into Rectifiers: Surpassing Human-Level Performance](https://arxiv.org/abs/1502.01852) (which we've seen before--it's the article that introduced the ResNet) show we should use the following scale instead: $\\sqrt{2 / n_{in}}$ where $n_{in}$ is the number of inputs of our model."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0.2871, 0.0000, 0.0000, 0.0000, 0.0026],\n",
|
|
" [0.4546, 0.0000, 0.0000, 0.0000, 0.0015],\n",
|
|
" [0.6178, 0.0000, 0.0000, 0.0180, 0.0079],\n",
|
|
" [0.3333, 0.0000, 0.0000, 0.0545, 0.0000],\n",
|
|
" [0.1940, 0.0000, 0.0000, 0.0000, 0.0096]])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"for i in range(50): x = relu(x @ (torch.randn(100,100) * sqrt(2/100)))\n",
|
|
"x[0:5,0:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And indeed if we use it we can check our numbers aren't all zeroed this time. So let's go back to the definition of our neural net and use this initialization (which is named *Kaiming initialization* or *He initialization*)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.randn(200, 100)\n",
|
|
"y = torch.randn(200)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"w1 = torch.randn(100,50) * sqrt(2 / 100)\n",
|
|
"b1 = torch.zeros(50)\n",
|
|
"w2 = torch.randn(50,1) * sqrt(2 / 50)\n",
|
|
"b2 = torch.zeros(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now after going through the first linear layer and relu, let's look at the scale of our activations:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(0.5661), tensor(0.8339))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"l1 = lin(x, w1, b1)\n",
|
|
"l2 = relu(l1)\n",
|
|
"l2.mean(), l2.std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now that our weights are properly initialized, we can define our whole model:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def model(x):\n",
|
|
" l1 = lin(x, w1, b1)\n",
|
|
" l2 = relu(l1)\n",
|
|
" l3 = lin(l2, w2, b2)\n",
|
|
" return l3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This is the forward pass, now all there is left to do is to compare our output to the labels we have (random numbers, in this example) with a loss function. In this case, we will use the mean squared error. (It's a toy problem in any case and this is the easiest loss function to use for what is next, computing the gradients.)\n",
|
|
"\n",
|
|
"The only subtlety is that our output and target don't have exactly the same shape: after going though the model, we get an output like this."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([200, 1])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"out = model(x)\n",
|
|
"out.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To get rid of this trailing 1 dimension, we use the `squeeze` function."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And now we are ready to compute our loss."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"loss = mse(out, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Gradients and backward pass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We've seen that PyTorch computes all the gradient we need with a magic call to `loss.backward()` but how is it done behind the scenes?\n",
|
|
"\n",
|
|
"Now comes the part where we need to compute the gradients of the loss with respect to all the weights of our model, so all the floats in `w1`, `b1`, `w2` and `b2`. For this, we will need a bit of math, specifically the chain rule. If you don't remember it from high school, this is the rule of calculus that guides how we can compute the derivative of a composed function:\n",
|
|
"\n",
|
|
"$$(g \\circ f)'(x) = g'(f(x)) f'(x)$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"> j: I find this notation very hard to wrap my head around, so instead I like to think of it as: if `y = g(u)` and `u=f(x)`; then `dy/dx = dy/du * du/dx`. The two notations mean the same thing, so use whatever works for you."
|
|
]
|
|
},
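{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the chain rule in action on a concrete (and entirely made-up) example, take `u = f(x) = 3*x` and `y = g(u) = u**2`, so that `y = 9*x**2` and `dy/dx = 18*x`. The chain rule gives the same answer, since `dy/du * du/dx = 2*u * 3 = 6*u = 18*x`. Here is a tiny sketch checking this at `x = 2`:\n",
"\n",
"```python\n",
"x0 = 2.0\n",
"u  = 3*x0        # u = f(x)\n",
"y  = u**2        # y = g(u)\n",
"\n",
"dy_du = 2*u            # derivative of g at u\n",
"du_dx = 3              # derivative of f at x\n",
"dy_dx = dy_du * du_dx  # chain rule: 36.0\n",
"\n",
"18*x0                  # direct formula: also 36.0\n",
"```"
]
},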
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Our loss is a big composition of different functions: mean-squared error (which is in turn the composition of a mean and a power of two), the second linear layer, a relu and the first linear layer. For instance, we want the gradients of the loss with respect to `b2` and our loss is defined by\n",
|
|
"\n",
|
|
"```\n",
|
|
"loss = mse(out,y) = mse(lin(l2, w2, b2), y)\n",
|
|
"```\n",
|
|
"\n",
|
|
"The chain rule tells us that we have\n",
|
|
"$$\\frac{\\text{d} loss}{\\text{d} b_{2}} = \\frac{\\text{d} loss}{\\text{d} out} \\times \\frac{\\text{d} out}{\\text{d} b_{2}} = \\frac{\\text{d}}{\\text{d} out} mse(out, y) \\times \\frac{\\text{d}}{\\text{d} b_{2}} lin(l_{2}, w_{2}, b_{2})$$\n",
|
|
"\n",
|
|
"To compute the gradients of the loss with respect to $b_{2}$, we first need the gradients of the loss with respect to our output $out$. It's the same if we want the gradients of the loss with respect to $w_{2}$. Then, to get the gradients of the loss with respect to $b_{1}$ or $w_{1}$, we will need the gradients of the loss with respect to $l_{1}$, which in turn requires the gradients of the loss with respect to $l_{2}$, which will need the gradients of the loss with respect to $out$.\n",
|
|
"\n",
|
|
"So to compute all the gradients we need for the update, we need to begin from the output of the model and work our way *backward*, one layer after the other, which is why this step is known as *backpropagation*. We can automate it by having each function we implemented (`relu`, `mse`, `lin`) provide its backward step, that is how to derive the gradients of the loss with respect to the input(s) from the gradient of the loss with respect to the output.\n",
|
|
"\n",
|
|
"Here we populate those gradients in an attribute of each tensor, a bit like PyTorch does with `.grad`. \n",
|
|
"\n",
|
|
"The first are the gradients of the loss with respect to the output of our model (which is the input of the loss function). We have to undo the squeeze we did in `mse` then we use the formula that gives us the derivative of $x^{2}$: $2x$. The derivative of the mean is just 1/n where n is the number of elements in our input."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mse_grad(inp, targ): \n",
|
|
" # grad of loss with respect to output of previous layer\n",
|
|
" inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For the gradients of the relu and our linear layer, we use the gradients of the loss with respect to the output (in `out.g`) and apply the chain rule to compute the gradients of the loss with respect to the output (in `inp.g`). The chain rule tells us that `inp.g = relu'(inp) * out.g`. The derivative of relu is either 0 (when inputs are negative) or 1 (when inputs are positive), so this gives us:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def relu_grad(inp, out):\n",
|
|
" # grad of relu with respect to input activations\n",
|
|
" inp.g = (inp>0).float() * out.g"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The scheme is the same to compute the gradients of the loss with respect to the inputs, weights and bias in the linear layer. We won't linger on the mathematical formulas that define them since they're not important for our purposes--but do check out Khan Academy's excellent calculus lessons if you're interested in this topic."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def lin_grad(inp, out, w, b):\n",
|
|
" # grad of matmul with respect to input\n",
|
|
" inp.g = out.g @ w.t()\n",
|
|
" w.g = inp.t() @ out.g\n",
|
|
" b.g = out.g.sum(0)"
|
|
]
|
|
},
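{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of the formulas those three lines implement, writing the layer as $out = inp \\cdot w + b$ with one example per row:\n",
"\n",
"$$\\frac{\\text{d} loss}{\\text{d} inp} = \\frac{\\text{d} loss}{\\text{d} out} \\cdot w^{T} \\qquad \\frac{\\text{d} loss}{\\text{d} w} = inp^{T} \\cdot \\frac{\\text{d} loss}{\\text{d} out} \\qquad \\frac{\\text{d} loss}{\\text{d} b} = \\sum_{i} \\left(\\frac{\\text{d} loss}{\\text{d} out}\\right)_{i}$$\n",
"\n",
"These correspond, line by line, to `out.g @ w.t()`, `inp.t() @ out.g`, and `out.g.sum(0)`."
]
},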
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Sidebar: SymPy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"An extremely useful library for working with calculus is *SymPy*. SymPy is a library for symbolic computation, which is defined in the SymPy documentation:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"> : Symbolic computation deals with the computation of mathematical objects symbolically. This means that the mathematical objects are represented exactly, not approximately, and mathematical expressions with unevaluated variables are left in symbolic form."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To do symbolic computation, first define a *symbol*, and then do a computation, like so:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 2 sx$"
|
|
],
|
|
"text/plain": [
|
|
"2*sx"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from sympy import symbols,diff\n",
|
|
"sx,sy = symbols('sx sy')\n",
|
|
"diff(sx**2, sx)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Here, SymPy has taken the derivative of `x**2` for us! SymPy can take the derivative of complicated compound expressions, and can also simplify and factor equations, and much more. There's really not much reason for anyone to do calculus manually nowadays--for calculating gradients, PyTorch does it for us, and for showing the equation, SymPy does it for us!"
|
|
]
|
|
},
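{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sketch of how this relates to what we just did, we can ask SymPy for the derivative of a single squared-error term with respect to the prediction, and recognize the factor we used in `mse_grad`:\n",
"\n",
"```python\n",
"from sympy import symbols,diff\n",
"pred,targ = symbols('pred targ')\n",
"diff((pred - targ)**2, pred)  # 2*pred - 2*targ, that is 2*(prediction - target)\n",
"```"
]
},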
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### End sidebar"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Once we have have defined those functions we can use them to write the backward pass. Since each gradient is automatically populated in the right tensor, we don't need to store the results of those `_grad` functions anywhere, we just need to execute them in the reverse order as the forward pass, to make sure that in each function, `out.g` exists."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def forward_and_backward(inp, targ):\n",
|
|
" # forward pass:\n",
|
|
" l1 = inp @ w1 + b1\n",
|
|
" l2 = relu(l1)\n",
|
|
" out = l2 @ w2 + b2\n",
|
|
" # we don't actually need the loss in backward!\n",
|
|
" loss = mse(out, targ)\n",
|
|
" \n",
|
|
" # backward pass:\n",
|
|
" mse_grad(out, targ)\n",
|
|
" lin_grad(l2, out, w2, b2)\n",
|
|
" relu_grad(l1, l2)\n",
|
|
" lin_grad(inp, l1, w1, b1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And now we can access to the gradients of our model parameters in `w1.g`, `b1.g`, `w2.g`, `b2.g`."
|
|
]
|
|
},
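{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, we can compare those gradients to the ones computed by PyTorch's autograd. This is only a sketch: it assumes `x`, `y`, `w1`, `b1`, `w2`, and `b2` are the tensors created earlier in the chapter, and it rebuilds the same forward pass on copies of the parameters that require gradients:\n",
"\n",
"```python\n",
"forward_and_backward(x, y)\n",
"manual = [t.g.clone() for t in (w1, b1, w2, b2)]\n",
"\n",
"params = [t.clone().requires_grad_(True) for t in (w1, b1, w2, b2)]\n",
"w1a, b1a, w2a, b2a = params\n",
"out = relu(x @ w1a + b1a) @ w2a + b2a\n",
"loss = (out.squeeze() - y).pow(2).mean()\n",
"loss.backward()\n",
"\n",
"for m,p in zip(manual, params): print(torch.allclose(m, p.grad, atol=1e-4))\n",
"```"
]
},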
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Refactor the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The three functions we used have two associated functions: a forward pass and a backward pass. Instead of writing them separately, we can create a class to wrap them together. That class can also store the inputs and outputs for the backward pass, this way we will just have to call `backward()`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Relu():\n",
|
|
" def __call__(self, inp):\n",
|
|
" self.inp = inp\n",
|
|
" self.out = inp.clamp_min(0.)\n",
|
|
" return self.out\n",
|
|
" \n",
|
|
" def backward(self): self.inp.g = (self.inp>0).float() * self.out.g"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The `__call__` name is a magic name in PyThon that will make our class callable. This what will be executed when we type `y = Relu()(x)`. We can do the same for our linear layer and the MSE loss."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Lin():\n",
|
|
" def __init__(self, w, b): self.w,self.b = w,b\n",
|
|
" \n",
|
|
" def __call__(self, inp):\n",
|
|
" self.inp = inp\n",
|
|
" self.out = inp@self.w + self.b\n",
|
|
" return self.out\n",
|
|
" \n",
|
|
" def backward(self):\n",
|
|
" self.inp.g = self.out.g @ self.w.t()\n",
|
|
" self.w.g = self.inp.t() @ self.out.g\n",
|
|
" self.b.g = self.out.g.sum(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Mse():\n",
|
|
" def __call__(self, inp, targ):\n",
|
|
" self.inp = inp\n",
|
|
" self.targ = targ\n",
|
|
" self.out = (inp.squeeze() - targ).pow(2).mean()\n",
|
|
" return self.out\n",
|
|
" \n",
|
|
" def backward(self):\n",
|
|
" self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.targ.shape[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then we can put everything in a model that we initiate with our tensors `w1`, `b1`, `w2`, `b2`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Model():\n",
|
|
" def __init__(self, w1, b1, w2, b2):\n",
|
|
" self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n",
|
|
" self.loss = Mse()\n",
|
|
" \n",
|
|
" def __call__(self, x, targ):\n",
|
|
" for l in self.layers: x = l(x)\n",
|
|
" return self.loss(x, targ)\n",
|
|
" \n",
|
|
" def backward(self):\n",
|
|
" self.loss.backward()\n",
|
|
" for l in reversed(self.layers): l.backward()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"What is really nice about this refactoring and registering things as layers of our model is that the forward and backward pass are now really easy to write. If we want to instantiate our model, we just need to write:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = Model(w1, b1, w2, b2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The forward pass would then be executed with:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"loss = model(x, y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"And the backward pass with:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.backward()"
|
|
]
|
|
},
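{
"cell_type": "markdown",
"metadata": {},
"source": [
"Building the real training loop is the subject of the next chapter, but here is a minimal sketch of how these gradients could be used for a single step of gradient descent (`lr` is a hypothetical learning rate, not something defined in this chapter):\n",
"\n",
"```python\n",
"lr = 0.01  # hypothetical learning rate\n",
"for p in (w1, b1, w2, b2):\n",
"    p -= lr * p.g  # in-place update using the gradients populated by model.backward()\n",
"```"
]
},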
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Going to PyTorch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The three classes we wrote for `Lin`, `Mse` and `Relu` have a lot in common, so we could make them all inherit from the same basic class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LayerFunction():\n",
|
|
" def __call__(self, *args):\n",
|
|
" self.args = args\n",
|
|
" self.out = self.forward(*args)\n",
|
|
" return self.out\n",
|
|
" \n",
|
|
" def forward(self): raise Exception('not implemented')\n",
|
|
" def bwd(self): raise Exception('not implemented')\n",
|
|
" def backward(self): self.bwd(self.out, *self.args)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then we just need to implement `forward` and `bwd` in each of our subclass."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Relu(LayerFunction):\n",
|
|
" def forward(self, inp): return inp.clamp_min(0.)\n",
|
|
" def bwd(self, out, inp): inp.g = (inp>0).float() * out.g"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Lin(LayerFunction):\n",
|
|
" def __init__(self, w, b): self.w,self.b = w,b\n",
|
|
" \n",
|
|
" def forward(self, inp): return inp@self.w + self.b\n",
|
|
" \n",
|
|
" def bwd(self, out, inp):\n",
|
|
" inp.g = out.g @ self.w.t()\n",
|
|
" self.w.g = self.inp.t() @ self.out.g\n",
|
|
" self.b.g = out.g.sum(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Mse(LayerFunction):\n",
|
|
" def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()\n",
|
|
" def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then our model can be the same as before. This is getting closer and closer to what PyTorch does. Each basic function we need to differentiate is written as a `torch.autograd.Function` object that has a `forward` and a `backward` method. PyTorch will then keep trace of any computation we do to be able to properly run the backward pass unless we set the `requires_grad` attribute of our tensors to `False`.\n",
|
|
"\n",
|
|
"Writing one is (almost) as easy as we had before. The difference is that we choose what to save and what to put in a context variable (so that we make sure we don't save anything we don't need) and that we return the gradients in the `backward` pass. It's very rare to have to write your own `Function` but if you ever need something exotic or want to mess with the gradients of a regular function, here is how we write one:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torch.autograd import Function\n",
|
|
"\n",
|
|
"class MyRelu(Function):\n",
|
|
" @staticmethod\n",
|
|
" def forward(ctx, i):\n",
|
|
" result = i.clamp_min(0.)\n",
|
|
" ctx.save_for_backward(i)\n",
|
|
" return result\n",
|
|
" \n",
|
|
" @staticmethod\n",
|
|
" def backward(ctx, grad_output):\n",
|
|
" i, = ctx.saved_tensors\n",
|
|
" return grad_output * (i>0).float()"
|
|
]
|
|
},
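{
"cell_type": "markdown",
"metadata": {},
"source": [
"A custom `Function` is not called directly; we use its `apply` method. Here is a minimal usage sketch:\n",
"\n",
"```python\n",
"t = torch.randn(5, requires_grad=True)\n",
"res = MyRelu.apply(t)\n",
"res.sum().backward()\n",
"t.grad  # 1. where t > 0, and 0. elsewhere\n",
"```"
]
},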
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then the structure used to build a more complex model that takes advantage of those functions is a `torch.nn.Module`. This is the base structure for all models and all the neural nets you have seen up until now where from that class. It mostly helps to register all the trainable parameters, which as we've seen can be used in the training loop.\n",
|
|
"\n",
|
|
"To implement a `nn.Module` you just need to\n",
|
|
"- Make sure the superclass `__init__` is called first when you initiliaze it,\n",
|
|
"- Define any parameter of the model as attributes with `nn.Parameter`,\n",
|
|
"- Define a `forward` function that returns the output of your model.\n",
|
|
"\n",
|
|
"As an example, here is the linear layer from scratch:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"\n",
|
|
"class LinearLayer(nn.Module):\n",
|
|
" def __init__(self, n_in, n_out):\n",
|
|
" super().__init__()\n",
|
|
" self.weight = nn.Parameter(torch.randn(n_out, n_in) * sqrt(2/n_in))\n",
|
|
" self.bias = nn.Parameter(torch.zeros(n_out))\n",
|
|
" \n",
|
|
" def forward(self, x): return x @ self.weight.t() + self.bias"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As you see, this class automatically keeps track of what parameters have been defined:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([2, 10]), torch.Size([2]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"lin = LinearLayer(10,2)\n",
|
|
"p1,p2 = lin.parameters()\n",
|
|
"p1.shape,p2.shape"
|
|
]
|
|
},
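{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch's own `nn.Linear` registers its parameters in the same way, and a quick sketch shows it stores its weight with the same `n_out` by `n_in` shape:\n",
"\n",
"```python\n",
"torch_lin = nn.Linear(10, 2)\n",
"[p.shape for p in torch_lin.parameters()]  # [torch.Size([2, 10]), torch.Size([2])]\n",
"```"
]
},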
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"It is thanks to this feature of `nn.Module` that we can just say `opt.step()` and have an optimizer loop through the parameters and update each one.\n",
|
|
"\n",
|
|
"Note that in PyTorch, the weights are stored as an `n_out x n_in` matrix, which is why we have the transpose in the forward pass.\n",
|
|
"\n",
|
|
"By using the linear layer from PyTorch (which uses the Kaiming initialization as well), the model we have seen during this chapter can be written like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Model(nn.Module):\n",
|
|
" def __init__(self, n_in, nh, n_out):\n",
|
|
" super().__init__()\n",
|
|
" self.layers = nn.Sequential(\n",
|
|
" nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
|
|
" self.loss = mse\n",
|
|
" \n",
|
|
" def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)"
|
|
]
|
|
},
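{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check that this behaves like our handmade version, here is a sketch that runs it on some hypothetical random data (it still relies on the `mse` function defined earlier in the chapter):\n",
"\n",
"```python\n",
"x_s = torch.randn(200, 10)  # hypothetical inputs, just for the sketch\n",
"y_s = torch.randn(200)      # hypothetical float targets\n",
"model_pt = Model(10, 50, 1)\n",
"loss = model_pt(x_s, y_s)\n",
"loss.backward()             # gradients now live in p.grad for each p in model_pt.parameters()\n",
"```"
]
},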
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"fastai provides its own variant of `Module` which is identical to `nn.Module`, but doesn't require you to call `super().__init__()` (it does that for you automatically):"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Model(Module):\n",
|
|
" def __init__(self, n_in, nh, n_out):\n",
|
|
" self.layers = nn.Sequential(\n",
|
|
" nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
|
|
" self.loss = mse\n",
|
|
" \n",
|
|
" def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"In the next chapter, we will start from such a model and see how we build a training loop from scratch and refactor it to what we've been using in previous chapters."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Things to remember"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"- A neural net is basically a bunch of matrix multiplications with non-linearities in-between.\n",
|
|
"- Python is slow so to write fast code we have to vectorize it and take advantage of element-wise arithmetic or broadcasting.\n",
|
|
"- Two tensors are broadcastable if the dimensions starting from the end and going backward match (they are the same or one of them is 1). To make tensors broadcastable, we may need to add dimensions of size 1 with `unsqueeze` or a `None` index.\n",
|
|
"- Properly initiliazing a neural net is crucial to get training started. Kaiming initialization should be used when we have ReLU non-linearities.\n",
|
|
"- The backward pass is the chain rule applied multiple times, computing the gradients from the output of our model and going back, one layer at a time.\n",
|
|
"- When subclassing `nn.Module` (if not using fastai's `Module`) we have to call the superclass `__init__` method in our `__init__` method and we have to define a `forward` function that takes an input and returns the desired result."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Questionnaire"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"1. Write the Python code to implement a single neuron.\n",
|
|
"1. Write the Python code to implement ReLU.\n",
|
|
"1. Write the Python code for a dense layer in terms of matrix multiplication.\n",
|
|
"1. Write the Python code for a dense layer in plain Python (that is with list comprehensions and functionality built into Python).\n",
|
|
"1. What is the hidden size of a layer?\n",
|
|
"1. What does the `t` method to in PyTorch?\n",
|
|
"1. Why is matrix multiplication written in plain Python very slow?\n",
|
|
"1. In matmul, why is `ac==br`?\n",
|
|
"1. In Jupyter notebook, how do you measure the time taken for a single cell to execute?\n",
|
|
"1. What is elementwise arithmetic?\n",
|
|
"1. Write the PyTorch code to test whether every element of `a` is greater than the corresponding element of `b`.\n",
|
|
"1. What is a rank-0 tensor? How do you convert it to a plain Python data type?\n",
|
|
"1. What does this return, and why?: `tensor([1,2]) + tensor([1])`\n",
|
|
"1. What does this return, and why?: `tensor([1,2]) + tensor([1,2,3])`\n",
|
|
"1. How does elementwise arithmetic help us speed up matmul?\n",
|
|
"1. What are the broadcasting rules?\n",
|
|
"1. What is `expand_as`? Show an example of how it can be used to match the results of broadcasting.\n",
|
|
"1. How does `unsqueeze` help us to solve certain broadcasting problems?\n",
|
|
"1. How can you use indexing to do the same operation as `unsqueeze`?\n",
|
|
"1. How do we show the actual contents of the memory used for a tensor?\n",
|
|
"1. When adding a vector of size 3 to a matrix of size 3 x 3, are the elements of the vector added to each row, or each column of the matrix? (Be sure to check your answer by running this code in a notebook.)\n",
|
|
"1. Do broadcasting and `expand_as` result in increased memory use? Why or why not?\n",
|
|
"1. Implement matmul using Einstein summation.\n",
|
|
"1. What does a repeated index letter represent on the left-hand side of einsum?\n",
|
|
"1. What are the three rules of Einstein summation notation? Why?\n",
|
|
"1. What is the forward pass, and the backward pass, of a neural network?\n",
|
|
"1. Why do we need to store some of the activations calculated for intermediate layers in the forward pass?\n",
|
|
"1. What is the downside of having activations with a standard deviation too far away from one?\n",
|
|
"1. How can weight initialisation help avoid this problem?\n",
|
|
"1. What is the formula to initialise weights such that we get a standard deviation of one, for a plain linear layer; for a linear layer followed by ReLU?\n",
|
|
"1. Why do we sometimes have to use the `squeeze` method in loss functions?\n",
|
|
"1. What does the argument to the squeeze method do? Why might it be important to include this argument, even though PyTorch does not require it?\n",
|
|
"1. What is the chain rule? Show the equation in either of the two forms shown in this chapter.\n",
|
|
"1. Show how to calculate the gradients of `mse(lin(l2, w2, b2), y)` using the chain rule.\n",
|
|
"1. What is the gradient of relu? Show in math or code. (You shouldn't need to commit this to memory—try to figure it using your knowledge of the shape of the function.)\n",
|
|
"1. In what order do we need to call the `*_grad` functions in the backward pass? Why?\n",
|
|
"1. What is `__call__`?\n",
|
|
"1. What methods do we need to implement when writing a `torch.autograd.Function`?\n",
|
|
"1. Write `nn.Linear` from scratch, and test it works.\n",
|
|
"1. What is the difference between `nn.Module` and fastai's `Module`?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Further research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"1. Implement relu as a `torch.autograd.Function` and train a model with it.\n",
|
|
"1. If you are mathematically inclined, find out what the gradients of a linear layer are in maths notation. Map that to the implementation we saw in this chapter.\n",
|
|
"1. Learn about the `unfold` method in PyTorch, and use it along with matrix multiplication to implement your own 2d convolution function, and train a CNN that uses it.\n",
|
|
"1. Implement all what is in this chapter using numpy instead of PyTorch. "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"split_at_heading": true
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}