{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Neural Net from the Foundations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a Neural Net Layer from Scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Modeling a Neuron"
]
},
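{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of a single neuron in plain Python (added here; the names `neuron`, `inputs`, `weights`, and `bias` are illustrative): multiply each input by a weight, sum the products, add a bias, and apply a nonlinearity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def neuron(inputs, weights, bias):\n",
"    # weighted sum of the inputs, plus the bias term\n",
"    out = sum(x*w for x,w in zip(inputs, weights)) + bias\n",
"    # ReLU nonlinearity: negative outputs become zero\n",
"    return max(out, 0.)\n",
"\n",
"neuron([1., 2., 3.], [0.5, -0.5, 0.25], 0.1)"
]
},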
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matrix Multiplication from Scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import tensor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
" ar,ac = a.shape # n_rows * n_cols\n",
" br,bc = b.shape\n",
" assert ac==br\n",
" c = torch.zeros(ar, bc)\n",
" for i in range(ar):\n",
" for j in range(bc):\n",
" for k in range(ac): c[i,j] += a[i,k] * b[k,j]\n",
" return c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m1 = torch.randn(5,28*28)\n",
"m2 = torch.randn(784,10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%time t1=matmul(m1, m2)"
]
},
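{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added, not in the original): the slow Python loop should agree with PyTorch's built-in matrix product up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.allclose(t1, m1@m2, atol=1e-4)"
]
},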
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit -n 20 t2=m1@m2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Elementwise Arithmetic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = tensor([10., 6, -4])\n",
"b = tensor([2., 8, 7])\n",
"a + b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a < b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(a < b).all(), (a==b).all()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(a + b).mean().item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"m*m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n = tensor([[1., 2, 3], [4,5,6]])\n",
"m*n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
" ar,ac = a.shape\n",
" br,bc = b.shape\n",
" assert ac==br\n",
" c = torch.zeros(ar, bc)\n",
" for i in range(ar):\n",
" for j in range(bc): c[i,j] = (a[i] * b[:,j]).sum()\n",
" return c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit -n 20 t3 = matmul(m1,m2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Broadcasting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting with a scalar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = tensor([10., 6, -4])\n",
"a > 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"(m - 5) / 2.73"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting a vector to a matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"m.shape,c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m + c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c.expand_as(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t = c.expand_as(m)\n",
"t.storage()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t.stride(), t.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c + m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6]])\n",
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = tensor([10.,20])\n",
"m = tensor([[1., 2, 3], [4,5,6]])\n",
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"c = c.unsqueeze(1)\n",
"m.shape,c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t = c.expand_as(m)\n",
"t.storage()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t.stride(), t.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = tensor([10.,20,30])\n",
"c.shape, c.unsqueeze(0).shape,c.unsqueeze(1).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c.shape, c[None,:].shape,c[:,None].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c[None].shape,c[...,None].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
" ar,ac = a.shape\n",
" br,bc = b.shape\n",
" assert ac==br\n",
" c = torch.zeros(ar, bc)\n",
" for i in range(ar):\n",
"# c[i,j] = (a[i,:] * b[:,j]).sum() # previous\n",
" c[i] = (a[i ].unsqueeze(-1) * b).sum(dim=0)\n",
" return c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit -n 20 t4 = matmul(m1,m2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting rules"
]
},
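{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration (added): shapes are compared from the trailing dimension leftward, and two dimensions are compatible when they are equal or one of them is 1, in which case that dimension is stretched to match."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = torch.ones(3,1)\n",
"b = torch.ones(1,4)\n",
"(a + b).shape   # (3,1) and (1,4) broadcast to (3,4)"
]
},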
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Einstein Summation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit -n 20 t5 = matmul(m1,m2)"
]
},
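{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a check (added), the einsum version should match PyTorch's `@` operator up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.allclose(matmul(m1,m2), m1@m2, atol=1e-4)"
]
},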
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Forward and Backward Passes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining and Initializing a Layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lin(x, w, b): return x @ w + b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1 = torch.randn(100,50)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"l1 = lin(x, w1, b1)\n",
"l1.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"l1.mean(), l1.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ torch.randn(100,100)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ (torch.randn(100,100) * 0.01)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ (torch.randn(100,100) * 0.1)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import sqrt\n",
"w1 = torch.randn(100,50) / sqrt(100)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1) / sqrt(50)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"l1 = lin(x, w1, b1)\n",
"l1.mean(),l1.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def relu(x): return x.clamp_min(0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"l2 = relu(l1)\n",
"l2.mean(),l2.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = relu(x @ (torch.randn(100,100) * 0.1))\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = relu(x @ (torch.randn(100,100) * sqrt(2/100)))\n",
"x[0:5,0:5]"
]
},
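{
"cell_type": "markdown",
"metadata": {},
"source": [
"An added check: with the `sqrt(2/100)` (Kaiming) scaling and ReLU, the activations should keep a standard deviation on the order of 1 after 50 layers, rather than exploding or vanishing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x.mean(), x.std()"
]
},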
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1 = torch.randn(100,50) * sqrt(2 / 100)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1) * sqrt(2 / 50)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"l1 = lin(x, w1, b1)\n",
"l2 = relu(l1)\n",
"l2.mean(), l2.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model(x):\n",
" l1 = lin(x, w1, b1)\n",
" l2 = relu(l1)\n",
" l3 = lin(l2, w2, b2)\n",
" return l3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out = model(x)\n",
"out.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = mse(out, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradients and the Backward Pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mse_grad(inp, targ): \n",
" # grad of loss with respect to output of previous layer\n",
" inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def relu_grad(inp, out):\n",
" # grad of relu with respect to input activations\n",
" inp.g = (inp>0).float() * out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lin_grad(inp, out, w, b):\n",
" # grad of matmul with respect to input\n",
" inp.g = out.g @ w.t()\n",
" w.g = inp.t() @ out.g\n",
" b.g = out.g.sum(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: SymPy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sympy import symbols,diff\n",
"sx,sy = symbols('sx sy')\n",
"diff(sx**2, sx)"
]
},
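{
"cell_type": "markdown",
"metadata": {},
"source": [
"SymPy can also differentiate composed functions, which is the chain rule in action (an added example; `sin` here is SymPy's symbolic sine)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sympy import sin\n",
"diff(sin(sx**2), sx)"
]
},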
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def forward_and_backward(inp, targ):\n",
" # forward pass:\n",
" l1 = inp @ w1 + b1\n",
" l2 = relu(l1)\n",
" out = l2 @ w2 + b2\n",
" # we don't actually need the loss in backward!\n",
" loss = mse(out, targ)\n",
" \n",
" # backward pass:\n",
" mse_grad(out, targ)\n",
" lin_grad(l2, out, w2, b2)\n",
" relu_grad(l1, l2)\n",
" lin_grad(inp, l1, w1, b1)"
]
},
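{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sanity check (added, a sketch): run the manual backward pass, then recompute the same loss on cloned parameters with `requires_grad_` and compare our stored `.g` attributes against the gradients PyTorch's autograd produces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forward_and_backward(x, y)\n",
"\n",
"# reference gradients from autograd on cloned parameters\n",
"w1a,b1a = w1.clone().requires_grad_(True), b1.clone().requires_grad_(True)\n",
"w2a,b2a = w2.clone().requires_grad_(True), b2.clone().requires_grad_(True)\n",
"loss = mse(relu(x @ w1a + b1a) @ w2a + b2a, y)\n",
"loss.backward()\n",
"torch.allclose(w1.g, w1a.grad, atol=1e-4), torch.allclose(b1.g, b1a.grad, atol=1e-4)"
]
},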
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Refactoring the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Relu():\n",
" def __call__(self, inp):\n",
" self.inp = inp\n",
" self.out = inp.clamp_min(0.)\n",
" return self.out\n",
" \n",
" def backward(self): self.inp.g = (self.inp>0).float() * self.out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Lin():\n",
" def __init__(self, w, b): self.w,self.b = w,b\n",
" \n",
" def __call__(self, inp):\n",
" self.inp = inp\n",
" self.out = inp@self.w + self.b\n",
" return self.out\n",
" \n",
" def backward(self):\n",
" self.inp.g = self.out.g @ self.w.t()\n",
" self.w.g = self.inp.t() @ self.out.g\n",
" self.b.g = self.out.g.sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Mse():\n",
" def __call__(self, inp, targ):\n",
" self.inp = inp\n",
" self.targ = targ\n",
" self.out = (inp.squeeze() - targ).pow(2).mean()\n",
" return self.out\n",
" \n",
" def backward(self):\n",
" x = (self.inp.squeeze()-self.targ).unsqueeze(-1)\n",
" self.inp.g = 2.*x/self.targ.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model():\n",
" def __init__(self, w1, b1, w2, b2):\n",
" self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n",
" self.loss = Mse()\n",
" \n",
" def __call__(self, x, targ):\n",
" for l in self.layers: x = l(x)\n",
" return self.loss(x, targ)\n",
" \n",
" def backward(self):\n",
" self.loss.backward()\n",
" for l in reversed(self.layers): l.backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Model(w1, b1, w2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = model(x, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.backward()"
]
},
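{
"cell_type": "markdown",
"metadata": {},
"source": [
"After calling `backward`, each parameter carries its gradient in its `.g` attribute (an added check of the shapes)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1.g.shape, b1.g.shape, w2.g.shape, b2.g.shape"
]
},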
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Going to PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LayerFunction():\n",
" def __call__(self, *args):\n",
" self.args = args\n",
" self.out = self.forward(*args)\n",
" return self.out\n",
" \n",
" def forward(self): raise Exception('not implemented')\n",
" def bwd(self): raise Exception('not implemented')\n",
" def backward(self): self.bwd(self.out, *self.args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Relu(LayerFunction):\n",
" def forward(self, inp): return inp.clamp_min(0.)\n",
" def bwd(self, out, inp): inp.g = (inp>0).float() * out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Lin(LayerFunction):\n",
" def __init__(self, w, b): self.w,self.b = w,b\n",
" \n",
" def forward(self, inp): return inp@self.w + self.b\n",
" \n",
" def bwd(self, out, inp):\n",
" inp.g = out.g @ self.w.t()\n",
" self.w.g = inp.t() @ self.out.g\n",
" self.b.g = out.g.sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Mse(LayerFunction):\n",
" def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()\n",
" def bwd(self, out, inp, targ): \n",
" inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.autograd import Function\n",
"\n",
"class MyRelu(Function):\n",
" @staticmethod\n",
" def forward(ctx, i):\n",
" result = i.clamp_min(0.)\n",
" ctx.save_for_backward(i)\n",
" return result\n",
" \n",
" @staticmethod\n",
" def backward(ctx, grad_output):\n",
" i, = ctx.saved_tensors\n",
" return grad_output * (i>0).float()"
]
},
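{
"cell_type": "markdown",
"metadata": {},
"source": [
"A usage sketch (added): custom `Function` subclasses are invoked through their `apply` method, and autograd then routes gradients through our `backward`. The tensor `x_test` is a hypothetical small example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x_test = torch.randn(5, requires_grad=True)  # hypothetical small example\n",
"y_test = MyRelu.apply(x_test)\n",
"y_test.sum().backward()\n",
"x_test.grad  # 1. where x_test > 0, else 0."
]
},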
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class LinearLayer(nn.Module):\n",
" def __init__(self, n_in, n_out):\n",
" super().__init__()\n",
" self.weight = nn.Parameter(torch.randn(n_out, n_in) * sqrt(2/n_in))\n",
" self.bias = nn.Parameter(torch.zeros(n_out))\n",
" \n",
" def forward(self, x): return x @ self.weight.t() + self.bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lin = LinearLayer(10,2)\n",
"p1,p2 = lin.parameters()\n",
"p1.shape,p2.shape"
]
},
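{
"cell_type": "markdown",
"metadata": {},
"source": [
"Applying the layer (an added usage example): a batch of 64 size-10 inputs maps to 64 size-2 activations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out = lin(torch.randn(64, 10))\n",
"out.shape"
]
},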
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, n_in, nh, n_out):\n",
" super().__init__()\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
" self.loss = mse\n",
" \n",
" def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model(Module):\n",
" def __init__(self, n_in, nh, n_out):\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
" self.loss = mse\n",
" \n",
" def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Write the Python code to implement a single neuron.\n",
"1. Write the Python code to implement ReLU.\n",
"1. Write the Python code for a dense layer in terms of matrix multiplication.\n",
"1. Write the Python code for a dense layer in plain Python (that is, with list comprehensions and functionality built into Python).\n",
"1. What is the \"hidden size\" of a layer?\n",
"1. What does the `t` method do in PyTorch?\n",
"1. Why is matrix multiplication written in plain Python very slow?\n",
"1. In `matmul`, why is `ac==br`?\n",
"1. In Jupyter Notebook, how do you measure the time taken for a single cell to execute?\n",
"1. What is \"elementwise arithmetic\"?\n",
"1. Write the PyTorch code to test whether every element of `a` is greater than the corresponding element of `b`.\n",
"1. What is a rank-0 tensor? How do you convert it to a plain Python data type?\n",
"1. What does this return, and why? `tensor([1,2]) + tensor([1])`\n",
"1. What does this return, and why? `tensor([1,2]) + tensor([1,2,3])`\n",
"1. How does elementwise arithmetic help us speed up `matmul`?\n",
"1. What are the broadcasting rules?\n",
"1. What is `expand_as`? Show an example of how it can be used to match the results of broadcasting.\n",
"1. How does `unsqueeze` help us to solve certain broadcasting problems?\n",
"1. How can we use indexing to do the same operation as `unsqueeze`?\n",
"1. How do we show the actual contents of the memory used for a tensor?\n",
"1. When adding a vector of size 3 to a matrix of size 3×3, are the elements of the vector added to each row or each column of the matrix? (Be sure to check your answer by running this code in a notebook.)\n",
"1. Do broadcasting and `expand_as` result in increased memory use? Why or why not?\n",
"1. Implement `matmul` using Einstein summation.\n",
"1. What does a repeated index letter represent on the left-hand side of einsum?\n",
"1. What are the three rules of Einstein summation notation? Why?\n",
"1. What are the forward pass and backward pass of a neural network?\n",
"1. Why do we need to store some of the activations calculated for intermediate layers in the forward pass?\n",
"1. What is the downside of having activations with a standard deviation too far away from 1?\n",
"1. How can weight initialization help avoid this problem?\n",
"1. What is the formula to initialize weights such that we get a standard deviation of 1 for a plain linear layer, and for a linear layer followed by ReLU?\n",
"1. Why do we sometimes have to use the `squeeze` method in loss functions?\n",
"1. What does the argument to the `squeeze` method do? Why might it be important to include this argument, even though PyTorch does not require it?\n",
"1. What is the \"chain rule\"? Show the equation in either of the two forms presented in this chapter.\n",
"1. Show how to calculate the gradients of `mse(lin(l2, w2, b2), y)` using the chain rule.\n",
"1. What is the gradient of ReLU? Show it in math or code. (You shouldn't need to commit this to memory—try to figure it using your knowledge of the shape of the function.)\n",
"1. In what order do we need to call the `*_grad` functions in the backward pass? Why?\n",
"1. What is `__call__`?\n",
"1. What methods must we implement when writing a `torch.autograd.Function`?\n",
"1. Write `nn.Linear` from scratch, and test it works.\n",
"1. What is the difference between `nn.Module` and fastai's `Module`?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Implement ReLU as a `torch.autograd.Function` and train a model with it.\n",
"1. If you are mathematically inclined, find out what the gradients of a linear layer are in mathematical notation. Map that to the implementation we saw in this chapter.\n",
"1. Learn about the `unfold` method in PyTorch, and use it along with matrix multiplication to implement your own 2D convolution function. Then train a CNN that uses it.\n",
"1. Implement everything in this chapter using NumPy instead of PyTorch. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}