{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [], "source": [ "#hide\n", "from fastai.gen_doc.nbdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A neural net from the foundations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A neural net layer from scratch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modeling a neuron" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Matrix multiplication from scratch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import tensor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def matmul(a,b):\n", " ar,ac = a.shape # n_rows * n_cols\n", " br,bc = b.shape\n", " assert ac==br\n", " c = torch.zeros(ar, bc)\n", " for i in range(ar):\n", " for j in range(bc):\n", " for k in range(ac): c[i,j] += a[i,k] * b[k,j]\n", " return c" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m1 = torch.randn(5,28*28)\n", "m2 = torch.randn(784,10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.15 s, sys: 4.09 ms, total: 1.15 s\n", "Wall time: 1.15 s\n" ] } ], "source": [ "%time t1=matmul(m1, m2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14 µs ± 8.95 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n" ] } ], "source": [ "%timeit -n 20 t2=m1@m2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([12., 14., 3.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor([10., 6, -4])\n", "b = tensor([2., 8, 7])\n", "a + b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([False, True, True])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a < b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(False), tensor(False))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(a < b).all(), (a==b).all()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9.666666984558105" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(a + b).mean().item()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 1., 4., 9.],\n", " [16., 25., 36.],\n", " [49., 64., 81.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n", "m*m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0", "output_type": "error", "traceback": [ "\u001b[0;31m------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0" ] } ], "source": [ "n = tensor([[1., 2, 3], [4,5,6]])\n", "m*n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def matmul(a,b):\n", " ar,ac = a.shape\n", " br,bc = b.shape\n", " assert ac==br\n", " c = torch.zeros(ar, bc)\n", " for i in range(ar):\n", " for j in range(bc): c[i,j] = (a[i] * b[:,j]).sum()\n", " return c" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.7 ms ± 88.1 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n" ] } ], "source": [ "%timeit -n 20 t3 = matmul(m1,m2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Broadcasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting with a scalar" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ True, True, False])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor([10., 6, -4])\n", "a > 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.4652, -1.0989, -0.7326],\n", " [-0.3663, 0.0000, 0.3663],\n", " [ 0.7326, 1.0989, 1.4652]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n", "(m - 5) / 2.73" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting a vector to a matrix" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3, 3]), torch.Size([3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = tensor([10.,20,30])\n", "m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n", "m.shape,c.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[11., 22., 33.],\n", " [14., 25., 36.],\n", " [17., 28., 39.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + c" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[10., 20., 30.],\n", " [10., 20., 30.],\n", " [10., 20., 30.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.expand_as(m)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " 10.0\n", " 20.0\n", " 30.0\n", "[torch.FloatStorage of size 3]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = c.expand_as(m)\n", "t.storage()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((0, 1), torch.Size([3, 3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.stride(), t.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[11., 22., 33.],\n", " [14., 25., 36.],\n", " [17., 28., 39.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c + m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[11., 22., 33.],\n", " [14., 25., 36.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = tensor([10.,20,30])\n", "m = tensor([[1., 2, 3], [4,5,6]])\n", "c+m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1", "output_type": "error", "traceback": [ "\u001b[0;31m------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m10.\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mc\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" ] } ], "source": [ "c = tensor([10.,20])\n", "m = tensor([[1., 2, 3], [4,5,6]])\n", "c+m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3, 3]), torch.Size([3, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = tensor([10.,20,30])\n", "m = tensor([[1., 2, 3], [4,5,6], [7,8,9]])\n", "c = c.unsqueeze(1)\n", "m.shape,c.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[11., 12., 13.],\n", " [24., 25., 26.],\n", " [37., 38., 39.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c+m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " 10.0\n", " 20.0\n", " 30.0\n", "[torch.FloatStorage of size 3]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = c.expand_as(m)\n", "t.storage()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1, 0), torch.Size([3, 3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.stride(), t.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = tensor([10.,20,30])\n", "c.shape, c.unsqueeze(0).shape,c.unsqueeze(1).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.shape, c[None,:].shape,c[:,None].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3]), torch.Size([3, 1]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c[None].shape,c[...,None].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def matmul(a,b):\n", " ar,ac = a.shape\n", " br,bc = b.shape\n", " assert ac==br\n", " c = torch.zeros(ar, bc)\n", " for i in range(ar):\n", "# c[i,j] = (a[i,:] * b[:,j]).sum() # previous\n", " c[i] = (a[i ].unsqueeze(-1) * b).sum(dim=0)\n", " return c" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "357 µs ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n" ] } ], "source": [ "%timeit -n 20 t4 = matmul(m1,m2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting Rules" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Einstein summation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "68.7 µs ± 4.06 µs per loop (mean ± std. dev. of 7 runs, 20 loops each)\n" ] } ], "source": [ "%timeit -n 20 t5 = matmul(m1,m2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The forward and backward passes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining and initializing a layer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def lin(x, w, b): return x @ w + b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(200, 100)\n", "y = torch.randn(200)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1 = torch.randn(100,50)\n", "b1 = torch.zeros(50)\n", "w2 = torch.randn(50,1)\n", "b2 = torch.zeros(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([200, 50])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l1 = lin(x, w1, b1)\n", "l1.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0019), tensor(10.1058))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l1.mean(), l1.std()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[nan, nan, nan, nan, nan],\n", " [nan, nan, nan, nan, nan],\n", " [nan, nan, nan, nan, nan],\n", " [nan, nan, nan, nan, nan],\n", " [nan, nan, nan, nan, nan]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(200, 100)\n", "for i in range(50): x = x @ torch.randn(100,100)\n", "x[0:5,0:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(200, 100)\n", "for i in range(50): x = x @ (torch.randn(100,100) * 0.01)\n", "x[0:5,0:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.7554, 0.6167, -0.1757, -1.5662, 0.5644],\n", " [-0.1987, 0.6292, 0.3283, -1.1538, 0.5416],\n", " [ 0.6106, 0.2556, -0.0618, -0.9463, 0.4445],\n", " [ 0.4484, 0.7144, 0.1164, -0.8626, 0.4413],\n", " [ 0.3463, 0.5930, 0.3375, -0.9486, 0.5643]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(200, 100)\n", "for i in range(50): x = x @ (torch.randn(100,100) * 0.1)\n", "x[0:5,0:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.7042)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.std()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(200, 100)\n", "y = torch.randn(200)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from math import sqrt\n", "w1 = torch.randn(100,50) / sqrt(100)\n", "b1 = torch.zeros(50)\n", "w2 = torch.randn(50,1) / sqrt(50)\n", "b2 = torch.zeros(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0050), tensor(1.0000))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l1 = lin(x, w1, b1)\n", "l1.mean(),l1.std()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def relu(x): return x.clamp_min(0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.3961), tensor(0.5783))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l2 = relu(l1)\n", "l2.mean(),l2.std()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0000e+00, 1.9689e-08, 4.2820e-08, 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 1.6701e-08, 4.3501e-08, 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 1.0976e-08, 3.0411e-08, 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 1.8457e-08, 4.9469e-08, 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 1.9949e-08, 4.1643e-08, 0.0000e+00, 0.0000e+00]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(200, 100)\n", "for i in range(50): x = relu(x @ (torch.randn(100,100) * 0.1))\n", "x[0:5,0:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.2871, 0.0000, 0.0000, 0.0000, 0.0026],\n", " [0.4546, 0.0000, 0.0000, 0.0000, 0.0015],\n", " [0.6178, 0.0000, 0.0000, 0.0180, 0.0079],\n", " [0.3333, 0.0000, 0.0000, 0.0545, 0.0000],\n", " [0.1940, 0.0000, 0.0000, 0.0000, 0.0096]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(200, 100)\n", "for i in range(50): x = relu(x @ (torch.randn(100,100) * sqrt(2/100)))\n", "x[0:5,0:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(200, 100)\n", "y = torch.randn(200)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w1 = torch.randn(100,50) * sqrt(2 / 100)\n", "b1 = torch.zeros(50)\n", "w2 = torch.randn(50,1) * sqrt(2 / 50)\n", "b2 = torch.zeros(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.5661), tensor(0.8339))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l1 = lin(x, w1, b1)\n", "l2 = relu(l1)\n", "l2.mean(), l2.std()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def model(x):\n", " l1 = lin(x, w1, b1)\n", " l2 = relu(l1)\n", " l3 = lin(l2, w2, b2)\n", " return l3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([200, 1])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = model(x)\n", "out.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss = mse(out, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradients and backward pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse_grad(inp, targ): \n", " # grad of loss with respect to output of previous layer\n", " inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def relu_grad(inp, out):\n", " # grad of relu with respect to input activations\n", " inp.g = (inp>0).float() * out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def lin_grad(inp, out, w, b):\n", " # grad of matmul with respect to input\n", " inp.g = out.g @ w.t()\n", " w.g = inp.t() @ out.g\n", " b.g = out.g.sum(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sidebar: SymPy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle 2 sx$" ], "text/plain": [ "2*sx" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sympy import symbols,diff\n", "sx,sy = symbols('sx sy')\n", "diff(sx**2, sx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### End sidebar" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def forward_and_backward(inp, targ):\n", " # forward pass:\n", " l1 = inp @ w1 + b1\n", " l2 = relu(l1)\n", " out = l2 @ w2 + b2\n", " # we don't actually need the loss in backward!\n", " loss = mse(out, targ)\n", " \n", " # backward pass:\n", " mse_grad(out, targ)\n", " lin_grad(l2, out, w2, b2)\n", " relu_grad(l1, l2)\n", " lin_grad(inp, l1, w1, b1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactor the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Relu():\n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp.clamp_min(0.)\n", " return self.out\n", " \n", " def backward(self): self.inp.g = (self.inp>0).float() * self.out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lin():\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def __call__(self, inp):\n", " self.inp = inp\n", " self.out = inp@self.w + self.b\n", " return self.out\n", " \n", " def backward(self):\n", " self.inp.g = self.out.g @ self.w.t()\n", " self.w.g = self.inp.t() @ self.out.g\n", " self.b.g = self.out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Mse():\n", " def __call__(self, inp, targ):\n", " self.inp = inp\n", " self.targ = targ\n", " self.out = (inp.squeeze() - targ).pow(2).mean()\n", " return self.out\n", " \n", " def backward(self):\n", " x = (self.inp.squeeze()-self.targ).unsqueeze(-1)\n", " self.inp.g = 2.*x/self.targ.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model():\n", " def __init__(self, w1, b1, w2, b2):\n", " self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]\n", " self.loss = Mse()\n", " \n", " def __call__(self, x, targ):\n", " for l in self.layers: x = l(x)\n", " return self.loss(x, targ)\n", " \n", " def backward(self):\n", " self.loss.backward()\n", " for l in reversed(self.layers): l.backward()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Model(w1, b1, w2, b2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss = model(x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Going to PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LayerFunction():\n", " def __call__(self, *args):\n", " self.args = args\n", " self.out = self.forward(*args)\n", " return self.out\n", " \n", " def forward(self): raise Exception('not implemented')\n", " def bwd(self): raise Exception('not implemented')\n", " def backward(self): self.bwd(self.out, *self.args)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Relu(LayerFunction):\n", " def forward(self, inp): return inp.clamp_min(0.)\n", " def bwd(self, out, inp): inp.g = (inp>0).float() * out.g" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lin(LayerFunction):\n", " def __init__(self, w, b): self.w,self.b = w,b\n", " \n", " def forward(self, inp): return inp@self.w + self.b\n", " \n", " def bwd(self, out, inp):\n", " inp.g = out.g @ self.w.t()\n", " self.w.g = self.inp.t() @ self.out.g\n", " self.b.g = out.g.sum(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Mse(LayerFunction):\n", " def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()\n", " def bwd(self, out, inp, targ): \n", " inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.autograd import Function\n", "\n", "class MyRelu(Function):\n", " @staticmethod\n", " def forward(ctx, i):\n", " result = i.clamp_min(0.)\n", " ctx.save_for_backward(i)\n", " return result\n", " \n", " @staticmethod\n", " def backward(ctx, grad_output):\n", " i, = ctx.saved_tensors\n", " return grad_output * (i>0).float()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class LinearLayer(nn.Module):\n", " def __init__(self, n_in, n_out):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.randn(n_out, n_in) * sqrt(2/n_in))\n", " self.bias = nn.Parameter(torch.zeros(n_out))\n", " \n", " def forward(self, x): return x @ self.weight.t() + self.bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2, 10]), torch.Size([2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lin = LinearLayer(10,2)\n", "p1,p2 = lin.parameters()\n", "p1.shape,p2.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self, n_in, nh, n_out):\n", " super().__init__()\n", " self.layers = nn.Sequential(\n", " nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n", " self.loss = mse\n", " \n", " def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model(Module):\n", " def __init__(self, n_in, nh, n_out):\n", " self.layers = nn.Sequential(\n", " nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out))\n", " self.loss = mse\n", " \n", " def forward(self, x, targ): return self.loss(self.layers(x).squeeze(), targ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Things to remember" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Questionnaire" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Further research" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }