fastbook/clean/17_foundations.ipynb
2020-03-06 10:19:03 -08:00

1566 lines
34 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from fastai.gen_doc.nbdoc import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A neural net from the foundations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A neural net layer from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Modeling a neuron"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matrix multiplication from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import tensor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
"    # Naive matrix product: out[i,j] = sum over k of a[i,k] * b[k,j]\n",
"    n_rows, n_inner = a.shape\n",
"    n_inner_b, n_cols = b.shape\n",
"    assert n_inner==n_inner_b  # inner dimensions must agree\n",
"    out = torch.zeros(n_rows, n_cols)\n",
"    for i in range(n_rows):\n",
"        for j in range(n_cols):\n",
"            for k in range(n_inner):\n",
"                out[i,j] += a[i,k] * b[k,j]\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m1 = torch.randn(5,28*28)\n",
"m2 = torch.randn(784,10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.15 s, sys: 4.09 ms, total: 1.15 s\n",
"Wall time: 1.15 s\n"
]
}
],
"source": [
"%time t1=matmul(m1, m2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14 µs ± 8.95 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
]
}
],
"source": [
"%timeit -n 20 t2=m1@m2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([12., 14., 3.])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = tensor([10., 6, -4])\n",
"b = tensor([2., 8, 7])\n",
"a + b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, True, True])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a < b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(False), tensor(False))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(a < b).all(), (a==b).all()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.666666984558105"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(a + b).mean().item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1., 4., 9.],\n",
" [16., 25., 36.],\n",
" [49., 64., 81.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"m*m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0",
"output_type": "error",
"traceback": [
"\u001b[0;31m------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-add73c4f74e0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0"
]
}
],
"source": [
"n = tensor([[1., 2, 3], [4,5,6]])\n",
"m*n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
"    # The innermost loop becomes an elementwise product plus a sum:\n",
"    # row i of a times column j of b, summed, gives out[i,j].\n",
"    n_rows, n_inner = a.shape\n",
"    n_inner_b, n_cols = b.shape\n",
"    assert n_inner==n_inner_b\n",
"    out = torch.zeros(n_rows, n_cols)\n",
"    for i in range(n_rows):\n",
"        for j in range(n_cols):\n",
"            out[i,j] = (a[i] * b[:,j]).sum()\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.7 ms ± 88.1 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
]
}
],
"source": [
"%timeit -n 20 t3 = matmul(m1,m2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Broadcasting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting with a scalar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ True, True, False])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = tensor([10., 6, -4])\n",
"a > 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.4652, -1.0989, -0.7326],\n",
" [-0.3663, 0.0000, 0.3663],\n",
" [ 0.7326, 1.0989, 1.4652]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"(m - 5) / 2.73"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting a vector to a matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([3, 3]), torch.Size([3]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"m.shape,c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[11., 22., 33.],\n",
" [14., 25., 36.],\n",
" [17., 28., 39.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m + c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[10., 20., 30.],\n",
" [10., 20., 30.],\n",
" [10., 20., 30.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.expand_as(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" 10.0\n",
" 20.0\n",
" 30.0\n",
"[torch.FloatStorage of size 3]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = c.expand_as(m)\n",
"t.storage()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((0, 1), torch.Size([3, 3]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.stride(), t.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[11., 22., 33.],\n",
" [14., 25., 36.],\n",
" [17., 28., 39.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c + m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[11., 22., 33.],\n",
" [14., 25., 36.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6]])\n",
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1",
"output_type": "error",
"traceback": [
"\u001b[0;31m------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-25-64bbbad4d99c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m10.\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mc\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"
]
}
],
"source": [
"c = tensor([10.,20])\n",
"m = tensor([[1., 2, 3], [4,5,6]])\n",
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([3, 3]), torch.Size([3, 1]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = tensor([10.,20,30])\n",
"m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n",
"c = c.unsqueeze(1)\n",
"m.shape,c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[11., 12., 13.],\n",
" [24., 25., 26.],\n",
" [37., 38., 39.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c+m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" 10.0\n",
" 20.0\n",
" 30.0\n",
"[torch.FloatStorage of size 3]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = c.expand_as(m)\n",
"t.storage()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((1, 0), torch.Size([3, 3]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.stride(), t.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = tensor([10.,20,30])\n",
"c.shape, c.unsqueeze(0).shape,c.unsqueeze(1).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.shape, c[None,:].shape,c[:,None].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 3]), torch.Size([3, 1]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c[None].shape,c[...,None].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
"    # Broadcasting removes the loop over j: a[i][:,None] has shape (n_inner, 1),\n",
"    # so multiplying by b (n_inner, n_cols) scales row k of b by a[i,k];\n",
"    # summing over dim 0 then yields the whole output row i at once.\n",
"    n_rows, n_inner = a.shape\n",
"    n_inner_b, n_cols = b.shape\n",
"    assert n_inner==n_inner_b\n",
"    out = torch.zeros(n_rows, n_cols)\n",
"    for i in range(n_rows):\n",
"        out[i] = (a[i][:,None] * b).sum(dim=0)\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"357 µs ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
]
}
],
"source": [
"%timeit -n 20 t4 = matmul(m1,m2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting rules"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Einstein summation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a,b):\n",
"    # Einstein summation: the index k appears in both inputs but not in the\n",
"    # output 'ij', so it is summed over -- exactly a matrix product.\n",
"    return torch.einsum('ik,kj->ij', a, b)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"68.7 µs ± 4.06 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n"
]
}
],
"source": [
"%timeit -n 20 t5 = matmul(m1,m2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The forward and backward passes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining and initializing a layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lin(x, w, b):\n",
"    # Linear (affine) layer: matrix product with w, then the bias b is\n",
"    # added with broadcasting to every row of the result.\n",
"    return x @ w + b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1 = torch.randn(100,50)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([200, 50])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l1 = lin(x, w1, b1)\n",
"l1.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.0019), tensor(10.1058))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l1.mean(), l1.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[nan, nan, nan, nan, nan],\n",
" [nan, nan, nan, nan, nan],\n",
" [nan, nan, nan, nan, nan],\n",
" [nan, nan, nan, nan, nan],\n",
" [nan, nan, nan, nan, nan]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ torch.randn(100,100)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ (torch.randn(100,100) * 0.01)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.7554, 0.6167, -0.1757, -1.5662, 0.5644],\n",
" [-0.1987, 0.6292, 0.3283, -1.1538, 0.5416],\n",
" [ 0.6106, 0.2556, -0.0618, -0.9463, 0.4445],\n",
" [ 0.4484, 0.7144, 0.1164, -0.8626, 0.4413],\n",
" [ 0.3463, 0.5930, 0.3375, -0.9486, 0.5643]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = x @ (torch.randn(100,100) * 0.1)\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.7042)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import sqrt\n",
"w1 = torch.randn(100,50) / sqrt(100)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1) / sqrt(50)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(-0.0050), tensor(1.0000))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l1 = lin(x, w1, b1)\n",
"l1.mean(),l1.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def relu(x):\n",
"    # Rectified linear unit: negative entries become 0, the rest pass through.\n",
"    return x.clamp(min=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.3961), tensor(0.5783))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l2 = relu(l1)\n",
"l2.mean(),l2.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000e+00, 1.9689e-08, 4.2820e-08, 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 1.6701e-08, 4.3501e-08, 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 1.0976e-08, 3.0411e-08, 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 1.8457e-08, 4.9469e-08, 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 1.9949e-08, 4.1643e-08, 0.0000e+00, 0.0000e+00]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = relu(x @ (torch.randn(100,100) * 0.1))\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.2871, 0.0000, 0.0000, 0.0000, 0.0026],\n",
" [0.4546, 0.0000, 0.0000, 0.0000, 0.0015],\n",
" [0.6178, 0.0000, 0.0000, 0.0180, 0.0079],\n",
" [0.3333, 0.0000, 0.0000, 0.0545, 0.0000],\n",
" [0.1940, 0.0000, 0.0000, 0.0000, 0.0096]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(200, 100)\n",
"for i in range(50): x = relu(x @ (torch.randn(100,100) * sqrt(2/100)))\n",
"x[0:5,0:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(200, 100)\n",
"y = torch.randn(200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1 = torch.randn(100,50) * sqrt(2 / 100)\n",
"b1 = torch.zeros(50)\n",
"w2 = torch.randn(50,1) * sqrt(2 / 50)\n",
"b2 = torch.zeros(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.5661), tensor(0.8339))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l1 = lin(x, w1, b1)\n",
"l2 = relu(l1)\n",
"l2.mean(), l2.std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model(x):\n",
"    # Two-layer network using the notebook-level parameters w1, b1, w2, b2:\n",
"    # linear -> ReLU -> linear.\n",
"    return lin(relu(lin(x, w1, b1)), w2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([200, 1])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out = model(x)\n",
"out.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mse(output, targ):\n",
"    # Mean squared error. squeeze(-1) drops the trailing size-1 column of the\n",
"    # predictions so they line up with a 1-D target vector.\n",
"    err = output.squeeze(-1) - targ\n",
"    return err.pow(2).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = mse(out, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradients and backward pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mse_grad(inp, targ):\n",
"    # Gradient of the MSE loss w.r.t. the previous layer's output, stored in inp.g:\n",
"    # d/d(inp) mean((inp - targ)**2) = 2 * (inp - targ) / N, with N = inp.shape[0].\n",
"    # unsqueeze(-1) restores the trailing column removed by squeeze().\n",
"    n = inp.shape[0]\n",
"    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def relu_grad(inp, out):\n",
"    # Backprop through ReLU: out.g flows back only where the input was positive.\n",
"    mask = (inp>0).float()\n",
"    inp.g = mask * out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lin_grad(inp, out, w, b):\n",
"    # Backprop through out = inp @ w + b, given out.g (gradient w.r.t. out).\n",
"    # Fills in the gradients for the input, the weights and the bias.\n",
"    grad = out.g\n",
"    inp.g = grad @ w.t()  # w.r.t. the layer input\n",
"    w.g = inp.t() @ grad  # w.r.t. the weights\n",
"    b.g = grad.sum(0)     # w.r.t. the bias: summed over dim 0 (the rows)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: SymPy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 2 sx$"
],
"text/plain": [
"2*sx"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sympy import symbols,diff\n",
"sx,sy = symbols('sx sy')\n",
"diff(sx**2, sx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def forward_and_backward(inp, targ):\n",
" # Manual forward + backward pass for the two-layer model built from the\n",
" # notebook-level parameters w1, b1, w2, b2. Gradients are written as .g\n",
" # attributes (on out, l2, l1, inp and the parameters); nothing is returned.\n",
" # forward pass:\n",
" l1 = inp @ w1 + b1\n",
" l2 = relu(l1)\n",
" out = l2 @ w2 + b2\n",
" # we don't actually need the loss in backward!\n",
" # (mse_grad only needs out and targ; the loss is computed for illustration)\n",
" loss = mse(out, targ)\n",
" \n",
" # backward pass: run the gradient functions in reverse order of the forward\n",
" # pass, because each one reads the .g written by the step before it here.\n",
" mse_grad(out, targ)\n",
" lin_grad(l2, out, w2, b2)\n",
" relu_grad(l1, l2)\n",
" lin_grad(inp, l1, w1, b1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Refactor the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Relu():\n",
"    # ReLU layer that remembers its input and output so backward() can write\n",
"    # the gradient w.r.t. its input into self.inp.g (reading self.out.g).\n",
"    def __call__(self, inp):\n",
"        self.inp, self.out = inp, inp.clamp_min(0.)\n",
"        return self.out\n",
"\n",
"    def backward(self):\n",
"        positive = (self.inp>0).float()\n",
"        self.inp.g = positive * self.out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Lin():\n",
"    # Linear layer that caches its input/output; backward() expects self.out.g\n",
"    # to be set and fills in the gradients for the input, w and b.\n",
"    def __init__(self, w, b):\n",
"        self.w = w\n",
"        self.b = b\n",
"\n",
"    def __call__(self, inp):\n",
"        self.inp = inp\n",
"        self.out = inp @ self.w + self.b\n",
"        return self.out\n",
"\n",
"    def backward(self):\n",
"        grad = self.out.g\n",
"        self.inp.g = grad @ self.w.t()\n",
"        self.w.g = self.inp.t() @ grad\n",
"        self.b.g = grad.sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Mse():\n",
"    # Mean-squared-error loss that caches its arguments for backward().\n",
"    def __call__(self, inp, targ):\n",
"        self.inp, self.targ = inp, targ\n",
"        self.out = (inp.squeeze() - targ).pow(2).mean()\n",
"        return self.out\n",
"\n",
"    def backward(self):\n",
"        # d/d(inp) mean((inp - targ)**2) = 2 * (inp - targ) / N\n",
"        diff = (self.inp.squeeze()-self.targ).unsqueeze(-1)\n",
"        self.inp.g = 2.*diff/self.targ.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model():\n",
"    # Chains Lin -> Relu -> Lin followed by an Mse loss. backward() runs the\n",
"    # loss's backward first, then every layer's backward in reverse order.\n",
"    def __init__(self, w1, b1, w2, b2):\n",
"        self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n",
"        self.loss = Mse()\n",
"\n",
"    def __call__(self, x, targ):\n",
"        out = x\n",
"        for layer in self.layers:\n",
"            out = layer(out)\n",
"        return self.loss(out, targ)\n",
"\n",
"    def backward(self):\n",
"        self.loss.backward()\n",
"        for layer in self.layers[::-1]:\n",
"            layer.backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Model(w1, b1, w2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = model(x, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Going to PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LayerFunction():\n",
"    # Base class for hand-written layers: __call__ caches the arguments and the\n",
"    # output, so subclasses only implement forward(*args) and bwd(out, *args);\n",
"    # backward() then calls bwd with the cached output and arguments.\n",
"    def __call__(self, *args):\n",
"        self.args = args\n",
"        self.out = self.forward(*args)\n",
"        return self.out\n",
"\n",
"    # The stubs accept the same arguments the real implementations receive, so a\n",
"    # missing override raises 'not implemented' rather than an arity TypeError.\n",
"    # NotImplementedError subclasses Exception, so existing handlers still work.\n",
"    def forward(self, *args): raise NotImplementedError('not implemented')\n",
"    def bwd(self, *args): raise NotImplementedError('not implemented')\n",
"    def backward(self): self.bwd(self.out, *self.args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Relu(LayerFunction):\n",
"    # ReLU on top of LayerFunction: bwd receives the cached output and input.\n",
"    def forward(self, inp):\n",
"        return inp.clamp_min(0.)\n",
"\n",
"    def bwd(self, out, inp):\n",
"        inp.g = (inp>0).float() * out.g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Lin(LayerFunction):\n",
"    # Linear layer on top of LayerFunction: forward computes inp @ w + b, and\n",
"    # bwd receives the cached output and input and fills inp.g, w.g and b.g.\n",
"    def __init__(self, w, b): self.w,self.b = w,b\n",
"\n",
"    def forward(self, inp): return inp@self.w + self.b\n",
"\n",
"    def bwd(self, out, inp):\n",
"        inp.g = out.g @ self.w.t()\n",
"        # Use the arguments bwd receives: LayerFunction keeps the input in\n",
"        # self.args, so there is no self.inp attribute to read here.\n",
"        self.w.g = inp.t() @ out.g\n",
"        self.b.g = out.g.sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Mse(LayerFunction):\n",
"    # MSE loss on top of LayerFunction; bwd writes 2 * (inp - targ) / N into inp.g.\n",
"    def forward(self, inp, targ):\n",
"        return (inp.squeeze() - targ).pow(2).mean()\n",
"\n",
"    def bwd(self, out, inp, targ):\n",
"        diff = (inp.squeeze()-targ).unsqueeze(-1)\n",
"        inp.g = 2*diff / targ.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.autograd import Function\n",
"\n",
"class MyRelu(Function):\n",
"    # Custom autograd op: forward saves its input, and backward uses it to\n",
"    # mask the incoming gradient so it only flows where the input was positive.\n",
"    @staticmethod\n",
"    def forward(ctx, i):\n",
"        ctx.save_for_backward(i)\n",
"        return i.clamp_min(0.)\n",
"\n",
"    @staticmethod\n",
"    def backward(ctx, grad_output):\n",
"        (i,) = ctx.saved_tensors\n",
"        return grad_output * (i>0).float()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class LinearLayer(nn.Module):\n",
"    # Minimal version of nn.Linear: the weight is stored as (n_out, n_in) and\n",
"    # transposed in forward; random weights are scaled by sqrt(2/n_in).\n",
"    def __init__(self, n_in, n_out):\n",
"        super().__init__()\n",
"        scale = sqrt(2/n_in)\n",
"        self.weight = nn.Parameter(torch.randn(n_out, n_in) * scale)\n",
"        self.bias = nn.Parameter(torch.zeros(n_out))\n",
"\n",
"    def forward(self, x):\n",
"        return x @ self.weight.t() + self.bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([2, 10]), torch.Size([2]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin = LinearLayer(10,2)\n",
"p1,p2 = lin.parameters()\n",
"p1.shape,p2.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    # The same linear -> ReLU -> linear network and MSE loss, now built\n",
"    # from PyTorch's own modules; forward returns the loss directly.\n",
"    def __init__(self, n_in, nh, n_out):\n",
"        super().__init__()\n",
"        self.layers = nn.Sequential(nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
"        self.loss = mse\n",
"\n",
"    def forward(self, x, targ):\n",
"        preds = self.layers(x).squeeze()\n",
"        return self.loss(preds, targ)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model(Module):\n",
"    # Same model on fastai's Module base class. Note there is no\n",
"    # super().__init__() call here -- presumably Module takes care of it\n",
"    # (TODO confirm against the fastai Module documentation).\n",
"    def __init__(self, n_in, nh, n_out):\n",
"        self.layers = nn.Sequential(nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n",
"        self.loss = mse\n",
"\n",
"    def forward(self, x, targ):\n",
"        preds = self.layers(x).squeeze()\n",
"        return self.loss(preds, targ)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Things to remember"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}