fastbook/clean/04_mnist_basics.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from fastai.vision.all import *\n",
"from fastbook import *\n",
"\n",
"matplotlib.rc('image', cmap='Greys')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Under the Hood: Training a Digit Classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pixels: The Foundations of Computer Vision"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sidebar: Tenacity and Deep Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## End sidebar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.MNIST_SAMPLE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(path/'train').ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"threes = (path/'train'/'3').ls().sorted()\n",
"sevens = (path/'train'/'7').ls().sorted()\n",
"threes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"im3_path = threes[1]\n",
"im3 = Image.open(im3_path)\n",
"im3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"array(im3)[4:10,4:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tensor(im3)[4:10,4:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"im3_t = tensor(im3)\n",
"df = pd.DataFrame(im3_t[4:15,4:22])\n",
"df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## First Try: Pixel Similarity"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seven_tensors = [tensor(Image.open(o)) for o in sevens]\n",
"three_tensors = [tensor(Image.open(o)) for o in threes]\n",
"len(three_tensors),len(seven_tensors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_image(three_tensors[1]);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stacked_sevens = torch.stack(seven_tensors).float()/255\n",
"stacked_threes = torch.stack(three_tensors).float()/255\n",
"stacked_threes.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(stacked_threes.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stacked_threes.ndim"
]
},
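{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the difference between rank and shape, here is a minimal sketch on a small random tensor. `t_demo` and its sizes (6 images of 28×28 pixels) are hypothetical stand-ins for `stacked_threes`. The shape lists the size of each axis. The rank is the number of axes, which is the length of the shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical stand-in for stacked_threes: 6 images of 28x28 pixels\n",
"t_demo = torch.rand(6, 28, 28)\n",
"\n",
"# Shape: the size of each axis\n",
"print(t_demo.shape)                      # torch.Size([6, 28, 28])\n",
"# Rank: the number of axes, which is the length of the shape\n",
"print(len(t_demo.shape), t_demo.ndim)    # 3 3"
]
},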
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean3 = stacked_threes.mean(0)\n",
"show_image(mean3);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
2020-09-03 22:58:27 +00:00
"outputs": [],
2020-03-06 18:19:03 +00:00
"source": [
"mean7 = stacked_sevens.mean(0)\n",
"show_image(mean7);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a_3 = stacked_threes[1]\n",
"show_image(a_3);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dist_3_abs = (a_3 - mean3).abs().mean()\n",
"dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n",
"dist_3_abs,dist_3_sqr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dist_7_abs = (a_3 - mean7).abs().mean()\n",
"dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n",
"dist_7_abs,dist_7_sqr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"F.l1_loss(a_3.float(),mean7), F.mse_loss(a_3,mean7).sqrt()"
]
},
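{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check both distances by hand on a tiny example. This sketch uses made-up three-pixel 'images' `a_demo` and `b_demo`, with hypothetical values chosen to keep the arithmetic easy. It confirms that `F.l1_loss` gives the mean absolute difference and that `F.mse_loss(...).sqrt()` gives the root mean squared error (RMSE). The RMSE comes out larger because squaring gives the single large difference of 3 more weight than the L1 norm does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Hypothetical three-pixel 'images'; their differences are 0, 1 and 3\n",
"a_demo = torch.tensor([1., 2., 3.])\n",
"b_demo = torch.tensor([1., 1., 0.])\n",
"\n",
"l1_demo = (a_demo - b_demo).abs().mean()            # (0 + 1 + 3) / 3 = 1.3333\n",
"rmse_demo = ((a_demo - b_demo)**2).mean().sqrt()    # sqrt((0 + 1 + 9) / 3) = 1.8257\n",
"print(l1_demo, rmse_demo)\n",
"\n",
"# The same two values from PyTorch's built-in loss functions\n",
"print(F.l1_loss(a_demo, b_demo), F.mse_loss(a_demo, b_demo).sqrt())"
]
},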
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### NumPy Arrays and PyTorch Tensors"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = [[1,2,3],[4,5,6]]\n",
"arr = array(data)\n",
"tns = tensor(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"arr # numpy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns # pytorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns[:,1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns[1,1:3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns+1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns.type()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tns*1.5"
]
},
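{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a side-by-side sketch of the operations above, run on a NumPy array and a PyTorch tensor built from the same data. Indexing and slicing behave identically in both. Multiplying an integer tensor by a float produces a float tensor, and `torch.from_numpy` converts an array to a tensor that shares its memory. The `_demo` names are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"data_demo = [[1, 2, 3], [4, 5, 6]]\n",
"arr_demo = np.array(data_demo)\n",
"tns_demo = torch.tensor(data_demo)\n",
"\n",
"# Indexing and slicing work the same way in both libraries\n",
"print(arr_demo[1], tns_demo[1])              # second row\n",
"print(arr_demo[:, 1], tns_demo[:, 1])        # second column\n",
"print(arr_demo[1, 1:3], tns_demo[1, 1:3])    # part of the second row\n",
"\n",
"# An integer tensor times a Python float is promoted to a float tensor\n",
"print(tns_demo.dtype, (tns_demo * 1.5).dtype)   # torch.int64 torch.float32\n",
"\n",
"# Convert the NumPy array to a tensor; the two share the same memory\n",
"from_np = torch.from_numpy(arr_demo)\n",
"print(from_np.tolist() == tns_demo.tolist())    # True"
]
},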
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Computing Metrics Using Broadcasting"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_3_tens = torch.stack([tensor(Image.open(o))\n",
"                            for o in (path/'valid'/'3').ls()])\n",
"valid_3_tens = valid_3_tens.float()/255\n",
"valid_7_tens = torch.stack([tensor(Image.open(o))\n",
"                            for o in (path/'valid'/'7').ls()])\n",
"valid_7_tens = valid_7_tens.float()/255\n",
"valid_3_tens.shape,valid_7_tens.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))\n",
"mnist_distance(a_3, mean3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_3_dist = mnist_distance(valid_3_tens, mean3)\n",
"valid_3_dist, valid_3_dist.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tensor([1,2,3]) + tensor([1,1,1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(valid_3_tens-mean3).shape"
]
},
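{
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal sketch shows what broadcasting does inside `mnist_distance`. It uses random stand-ins: `valid_demo` for `valid_3_tens` and `mean_demo` for `mean3` (both hypothetical). The `(28, 28)` mean image is broadcast against each of the 5 images. Averaging over the last two axes then leaves one distance per image, and the result matches an explicit Python loop. `torch.broadcast_shapes`, available in recent PyTorch versions, reports the broadcast shape without doing any arithmetic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical stand-ins: 5 validation images and one mean image\n",
"valid_demo = torch.rand(5, 28, 28)\n",
"mean_demo = torch.rand(28, 28)\n",
"\n",
"# (5, 28, 28) minus (28, 28): the mean image is broadcast across all 5 images\n",
"diff_demo = valid_demo - mean_demo\n",
"print(diff_demo.shape)                         # torch.Size([5, 28, 28])\n",
"\n",
"# Averaging over the last two axes leaves one distance per image\n",
"dist_demo = diff_demo.abs().mean((-1, -2))\n",
"print(dist_demo.shape)                         # torch.Size([5])\n",
"\n",
"# Same numbers as a Python loop over the images, without writing the loop\n",
"loop_demo = torch.stack([(v - mean_demo).abs().mean() for v in valid_demo])\n",
"print(torch.allclose(dist_demo, loop_demo))    # True\n",
"print(torch.broadcast_shapes((5, 28, 28), (28, 28)))   # torch.Size([5, 28, 28])"
]
},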
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_3(x): return mnist_distance(x,mean3) < mnist_distance(x,mean7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"is_3(a_3), is_3(a_3).float()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"is_3(valid_3_tens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"accuracy_3s = is_3(valid_3_tens).float().mean()\n",
"accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()\n",
"\n",
"accuracy_3s,accuracy_7s,(accuracy_3s+accuracy_7s)/2"
]
},
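{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy calculation relies on booleans becoming 1.0 and 0.0 when converted to floats, so their mean is the fraction that is correct. Here is a minimal sketch with hypothetical, hand-made predictions, where True means 'classified as a 3'. For 3s a True is correct; for 7s a False is. Averaging the two per-class accuracies equals the overall accuracy exactly only when both classes have the same number of images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical results of is_3 on four 3s and four 7s\n",
"preds_on_3s = torch.tensor([True, True, False, True])     # 3 of 4 correct\n",
"preds_on_7s = torch.tensor([False, True, False, False])   # 3 of 4 correct (False is right here)\n",
"\n",
"acc_3s_demo = preds_on_3s.float().mean()\n",
"acc_7s_demo = (1 - preds_on_7s.float()).mean()\n",
"print(acc_3s_demo, acc_7s_demo, (acc_3s_demo + acc_7s_demo) / 2)   # all three are 0.75"
]
},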
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stochastic Gradient Descent (SGD)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gv('''\n",
"init->predict->loss->gradient->step->stop\n",
"step->predict[label=repeat]\n",
"''')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def f(x): return x**2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_function(f, 'x', 'x**2')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_function(f, 'x', 'x**2')\n",
"plt.scatter(-1.5, f(-1.5), color='red');"
]
},
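{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this function, you can write the loop from the diagram in plain Python. Assuming we already know that the derivative of `x**2` is `2*x`, each step moves x against the gradient, starting from the red point at -1.5. The function name `grad_descent_1d` and the learning rate of 0.1 are illustrative choices, not part of fastai. With this learning rate, every step multiplies x by `1 - 2*0.1 = 0.8`, so x shrinks steadily towards the minimum at 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def grad_descent_1d(x, lr, steps):\n",
"    # Minimise f(x) = x**2 using its derivative 2*x, worked out by hand\n",
"    for _ in range(steps):\n",
"        x = x - lr * (2 * x)   # step against the gradient\n",
"    return x\n",
"\n",
"x_end = grad_descent_1d(-1.5, lr=0.1, steps=20)   # start from the red point\n",
"print(x_end, x_end**2)                             # both close to 0"
]
},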
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating Gradients"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xt = tensor(3.).requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yt = f(xt)\n",
"yt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yt.backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xt.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xt = tensor([3.,4.,10.]).requires_grad_()\n",
"xt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def f(x): return (x**2).sum()\n",
"\n",
"yt = f(xt)\n",
"yt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yt.backward()\n",
"xt.grad"
]
},
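{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to confirm that `backward` returns the right numbers is to compare it with a finite-difference estimate. This sketch repeats the calculation above (the names `sum_sq`, `xg` and `approx` are hypothetical). It then nudges each input by a small `eps` in both directions and measures the change in the output. The estimate uses double precision to keep rounding error small. Both approaches should give `2*x`, i.e. 6, 8 and 20."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def sum_sq(x): return (x**2).sum()\n",
"\n",
"xg = torch.tensor([3., 4., 10.]).requires_grad_()\n",
"sum_sq(xg).backward()\n",
"print(xg.grad)                     # 2*x, i.e. 6, 8 and 20\n",
"\n",
"# Central finite difference for each coordinate, in float64\n",
"base = xg.detach().double()\n",
"eps = 1e-3\n",
"approx = []\n",
"for i in range(len(base)):\n",
"    nudge = torch.zeros_like(base)\n",
"    nudge[i] = eps\n",
"    approx.append((sum_sq(base + nudge) - sum_sq(base - nudge)) / (2 * eps))\n",
"approx = torch.stack(approx)\n",
"print(approx)"
]
},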
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Stepping With a Learning Rate"
]
},
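{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning rate controls how big each step is. The sketch below runs 10 steps of gradient descent on `x**2` from -1.5 with three learning rates; the helper `descend_square` is hypothetical. For this function, each step multiplies x by `1 - 2*lr`. A tiny rate barely moves x, and a moderate one converges. Any rate above 1 makes `|1 - 2*lr| > 1`, so x bounces from side to side and grows instead of shrinking."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def descend_square(lr, steps=10, x=-1.5):\n",
"    # Gradient descent on f(x) = x**2, whose gradient is 2*x\n",
"    for _ in range(steps):\n",
"        x -= lr * 2 * x\n",
"    return x\n",
"\n",
"for lr_try in (0.01, 0.1, 1.1):\n",
"    print(lr_try, descend_square(lr_try))\n",
"# 0.01 -> about -1.23 (too slow), 0.1 -> about -0.16, 1.1 -> about -9.29 (diverging)"
]
},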
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### An End-to-End SGD Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time = torch.arange(0,20).float(); time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1\n",
"plt.scatter(time,speed);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def f(t, params):\n",
"    a,b,c = params\n",
"    return a*(t**2) + (b*t) + c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: despite its name, this returns the *root* mean squared error (RMSE)\n",
"def mse(preds, targets): return ((preds-targets)**2).mean().sqrt()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 1: Initialize the parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params = torch.randn(3).requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"orig_params = params.clone()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 2: Calculate the predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = f(time, params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_preds(preds, ax=None):\n",
"    if ax is None: ax=plt.subplots()[1]\n",
"    ax.scatter(time, speed)\n",
"    ax.scatter(time, to_np(preds), color='red')\n",
"    ax.set_ylim(-300,100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_preds(preds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 3: Calculate the loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = mse(preds, speed)\n",
"loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 4: Calculate the gradients"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss.backward()\n",
"params.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params.grad * 1e-5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 5: Step the weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lr = 1e-5\n",
"params.data -= lr * params.grad.data\n",
"params.grad = None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = f(time,params)\n",
"mse(preds, speed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_preds(preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def apply_step(params, prn=True):\n",
"    preds = f(time, params)\n",
"    loss = mse(preds, speed)\n",
"    loss.backward()\n",
"    params.data -= lr * params.grad.data\n",
"    params.grad = None\n",
"    if prn: print(loss.item())\n",
"    return preds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 6: Repeat the process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(10): apply_step(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"params = orig_params.detach().requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_,axs = plt.subplots(1,4,figsize=(12,3))\n",
"for ax in axs: show_preds(apply_step(params, False), ax)\n",
"plt.tight_layout()"
]
},
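{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also write the same steps as one self-contained loop. This sketch makes a few assumptions. The target `speed_demo` is a noise-free version of the quadratic used above, so only the shape is being fit. The names are hypothetical, and the parameters come from a seeded generator so the run is repeatable. The loop performs the update inside `torch.no_grad()` rather than through `.data`. Both work here, but `torch.no_grad()` is the form more commonly seen in current PyTorch code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"time_demo = torch.arange(0, 20).float()\n",
"speed_demo = 0.75*(time_demo - 9.5)**2 + 1     # noise-free version of the data above\n",
"\n",
"def quad(t, p):\n",
"    a, b, c = p\n",
"    return a*(t**2) + (b*t) + c\n",
"\n",
"def rmse(preds, targets): return ((preds - targets)**2).mean().sqrt()\n",
"\n",
"# Step 1: initialize (seeded generator, so the run is repeatable)\n",
"p = torch.randn(3, generator=torch.Generator().manual_seed(0)).requires_grad_()\n",
"start_loss = rmse(quad(time_demo, p), speed_demo).item()\n",
"\n",
"for _ in range(100):\n",
"    loss_demo = rmse(quad(time_demo, p), speed_demo)   # steps 2-3: predict, compute loss\n",
"    loss_demo.backward()                               # step 4: gradients\n",
"    with torch.no_grad():                              # step 5: step the weights\n",
"        p -= 1e-5 * p.grad\n",
"    p.grad = None                                      # reset before the next iteration\n",
"\n",
"end_loss = rmse(quad(time_demo, p), speed_demo).item()\n",
"print(start_loss, end_loss)   # the loss should have gone down"
]
},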
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Step 7: Stop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Summarizing Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gv('''\n",
"init->predict->loss->gradient->step->stop\n",
"step->predict[label=repeat]\n",
"''')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The MNIST Loss Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n",
"train_x.shape,train_y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dset = list(zip(train_x,train_y))\n",
"x,y = dset[0]\n",
"x.shape,y"
]
},
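{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reshaping above is easier to follow on a small example. This sketch uses 4 'threes' and 2 'sevens' of random pixels; all `_demo` names are hypothetical. `torch.cat` joins the two stacks along the first axis. `view(-1, 28*28)` flattens every 28×28 image into a row of 784 numbers; the -1 tells PyTorch to infer that size. `unsqueeze(1)` turns the labels into a column so they line up with the one-output-per-image predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical stand-ins: 4 'threes' and 2 'sevens' of 28x28 pixels\n",
"threes_demo = torch.rand(4, 28, 28)\n",
"sevens_demo = torch.rand(2, 28, 28)\n",
"\n",
"# Join along the first axis, then flatten each image into a row of 784 values\n",
"x_demo = torch.cat([threes_demo, sevens_demo]).view(-1, 28*28)\n",
"# Label 3s as 1 and 7s as 0; unsqueeze(1) turns shape (6,) into (6, 1)\n",
"y_demo = torch.tensor([1]*len(threes_demo) + [0]*len(sevens_demo)).unsqueeze(1)\n",
"print(x_demo.shape, y_demo.shape)                   # torch.Size([6, 784]) torch.Size([6, 1])\n",
"\n",
"# Each dataset item is an (image, label) tuple, as in dset\n",
"pairs_demo = list(zip(x_demo, y_demo))\n",
"print(pairs_demo[0][0].shape, pairs_demo[0][1])     # torch.Size([784]) tensor([1])"
]
},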
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n",
"valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n",
"valid_dset = list(zip(valid_x,valid_y))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weights = init_params((28*28,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bias = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(train_x[0]*weights.T).sum() + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def linear1(xb): return xb@weights + bias\n",
"preds = linear1(train_x)\n",
"preds"
]
},
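{
"cell_type": "markdown",
"metadata": {},
"source": [
"`linear1` can replace the one-image calculation `(train_x[0]*weights.T).sum() + bias` with a single `@` because matrix multiplication computes that same weighted sum for every row at once. This sketch checks the equivalence on random, hypothetical stand-ins for `train_x`, `weights` and `bias`. The small tolerance allows for the two approaches adding the 784 products in a different order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gen = torch.Generator().manual_seed(0)\n",
"xs_demo = torch.rand(3, 28*28, generator=gen)    # 3 hypothetical flattened images\n",
"w_demo = torch.randn(28*28, 1, generator=gen)    # shaped like `weights`\n",
"b_demo = torch.randn(1, generator=gen)           # shaped like `bias`\n",
"\n",
"# One image at a time: multiply each pixel by its weight, sum, add the bias\n",
"by_row = torch.stack([(x*w_demo.T).sum() + b_demo for x in xs_demo])\n",
"# Every image at once with matrix multiplication\n",
"by_matmul = xs_demo@w_demo + b_demo\n",
"print(by_row.shape, by_matmul.shape)                  # torch.Size([3, 1]) torch.Size([3, 1])\n",
"print(torch.allclose(by_row, by_matmul, atol=1e-4))   # True"
]
},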
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"corrects = (preds>0.0).float() == train_y\n",
"corrects"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"corrects.float().mean().item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): weights[0] *= 1.0001"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = linear1(train_x)\n",
"((preds>0.0).float() == train_y).float().mean().item()"
]
},
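{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells above show why accuracy is a poor loss function. A tiny change to a weight almost never pushes any prediction across the threshold, so the accuracy does not move. In gradient terms, the comparison `>` produces booleans, which carry no gradient at all. Here is a minimal sketch with three hypothetical activations `z_demo`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"z_demo = torch.tensor([-0.3, 0.2, 1.5], requires_grad=True)   # hypothetical activations\n",
"labels_demo = torch.tensor([1., 1., 0.])\n",
"\n",
"# Threshold, compare, average: the result is cut off from z_demo's gradient\n",
"acc_demo = ((z_demo > 0.0).float() == labels_demo).float().mean()\n",
"print(acc_demo, acc_demo.requires_grad)    # tensor(0.3333) False\n",
"\n",
"# Scaling the activations slightly (like weights[0] *= 1.0001) leaves accuracy unchanged\n",
"with torch.no_grad():\n",
"    acc_nudged = ((z_demo * 1.0001 > 0.0).float() == labels_demo).float().mean()\n",
"print(acc_nudged == acc_demo)              # tensor(True)"
]
},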
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trgts = tensor([1,0,1])\n",
"prds = tensor([0.9, 0.4, 0.2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mnist_loss(predictions, targets):\n",
"    return torch.where(targets==1, 1-predictions, predictions).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.where(trgts==1, 1-prds, prds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mnist_loss(prds,trgts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mnist_loss(tensor([0.9, 0.4, 0.8]),trgts)"
]
},
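{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is `mnist_loss` worked through by hand on the same three examples. For a target of 1, the per-item loss is 1 minus the prediction; for a target of 0, it is the prediction itself. Each value therefore measures how far the prediction is from the correct answer. The helper `where_loss` below is a hypothetical copy of the first `mnist_loss`, kept separate so this sketch runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def where_loss(predictions, targets):\n",
"    # Distance from the correct answer, averaged over the batch\n",
"    return torch.where(targets==1, 1-predictions, predictions).mean()\n",
"\n",
"trgts_demo = torch.tensor([1, 0, 1])\n",
"prds_a = torch.tensor([0.9, 0.4, 0.2])   # per-item losses 0.1, 0.4, 0.8\n",
"prds_b = torch.tensor([0.9, 0.4, 0.8])   # third prediction now closer to its target of 1\n",
"\n",
"print(where_loss(prds_a, trgts_demo))    # (0.1 + 0.4 + 0.8) / 3 = 0.4333\n",
"print(where_loss(prds_b, trgts_demo))    # (0.1 + 0.4 + 0.2) / 3 = 0.2333"
]
},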
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sigmoid"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(x): return 1/(1+torch.exp(-x))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mnist_loss(predictions, targets):\n",
"    predictions = predictions.sigmoid()\n",
"    return torch.where(targets==1, 1-predictions, predictions).mean()"
]
},
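{
"cell_type": "markdown",
"metadata": {},
"source": [
"This short sketch checks the properties that make sigmoid useful here. `sigmoid_demo` is a hypothetical name that repeats the definition above. It matches `torch.sigmoid`, maps 0 to 0.5, and squashes every input into the range 0 to 1, so `mnist_loss` always sees something that looks like a probability. Its derivative, `sigmoid(x) * (1 - sigmoid(x))`, is positive everywhere, which autograd confirms below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def sigmoid_demo(x): return 1/(1+torch.exp(-x))\n",
"\n",
"xs_sig = torch.linspace(-4, 4, 9)\n",
"s_vals = sigmoid_demo(xs_sig)\n",
"print(torch.allclose(s_vals, torch.sigmoid(xs_sig)))   # True\n",
"print(sigmoid_demo(torch.tensor(0.)))                  # tensor(0.5000)\n",
"print(bool(((s_vals > 0) & (s_vals < 1)).all()))       # True\n",
"\n",
"# Autograd's gradient matches the closed form sigmoid(x) * (1 - sigmoid(x))\n",
"z_sig = torch.tensor([-4., 0., 4.], requires_grad=True)\n",
"torch.sigmoid(z_sig).sum().backward()\n",
"s = torch.sigmoid(z_sig.detach())\n",
"print(torch.allclose(z_sig.grad, s*(1-s)))             # True"
]
},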
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SGD and Mini-Batches"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"coll = range(15)\n",
"dl = DataLoader(coll, batch_size=5, shuffle=True)\n",
"list(dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = L(enumerate(string.ascii_lowercase))\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dl = DataLoader(ds, batch_size=6, shuffle=True)\n",
"list(dl)"
]
},
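{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, a shuffling `DataLoader` draws a fresh random order of the item indices each epoch and slices that order into consecutive batches. When the sizes do not divide evenly, as with 26 letters in batches of 6, the final batch is shorter. Here is a minimal sketch of that idea, using a hypothetical helper `minibatch_indices` rather than fastai's actual implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def minibatch_indices(n_items, batch_size, shuffle=True, generator=None):\n",
"    # One epoch: a (possibly shuffled) order of 0..n_items-1, cut into batches\n",
"    order = torch.randperm(n_items, generator=generator) if shuffle else torch.arange(n_items)\n",
"    for start in range(0, n_items, batch_size):\n",
"        yield order[start:start+batch_size]\n",
"\n",
"gen_mb = torch.Generator().manual_seed(0)\n",
"batches_15 = [b.tolist() for b in minibatch_indices(15, 5, generator=gen_mb)]\n",
"batches_26 = [b.tolist() for b in minibatch_indices(26, 6, generator=gen_mb)]\n",
"print(batches_15)\n",
"print([len(b) for b in batches_26])    # [6, 6, 6, 6, 2]"
]
},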
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Putting It All Together"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weights = init_params((28*28,1))\n",
"bias = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dl = DataLoader(dset, batch_size=256)\n",
"xb,yb = first(dl)\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_dl = DataLoader(valid_dset, batch_size=256)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = train_x[:4]\n",
"batch.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = linear1(batch)\n",
"preds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = mnist_loss(preds, train_y[:4])\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss.backward()\n",
"weights.grad.shape,weights.grad.mean(),bias.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def calc_grad(xb, yb, model):\n",
"    preds = model(xb)\n",
"    loss = mnist_loss(preds, yb)\n",
"    loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"calc_grad(batch, train_y[:4], linear1)\n",
"weights.grad.mean(),bias.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"calc_grad(batch, train_y[:4], linear1)\n",
"weights.grad.mean(),bias.grad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weights.grad.zero_()\n",
"bias.grad.zero_();"
]
},
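{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling `calc_grad` twice changed the gradients because PyTorch adds each new gradient to whatever is already stored in `.grad`. That is why the cell above resets them with `zero_`. Here is a minimal sketch on a single hypothetical weight `w_acc`, where the gradient of `w_acc * 3` is 3:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"w_acc = torch.tensor([2.0], requires_grad=True)\n",
"\n",
"(w_acc * 3).sum().backward()\n",
"after_one = w_acc.grad.clone()     # tensor([3.])\n",
"(w_acc * 3).sum().backward()\n",
"after_two = w_acc.grad.clone()     # tensor([6.]): the second call added to the first\n",
"w_acc.grad.zero_()                 # reset in place before the next batch\n",
"print(after_one, after_two, w_acc.grad)"
]
},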
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model, lr, params):\n",
"    for xb,yb in dl:\n",
"        calc_grad(xb, yb, model)\n",
"        for p in params:\n",
"            p.data -= p.grad*lr\n",
"            p.grad.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(preds>0.0).float() == train_y[:4]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def batch_accuracy(xb, yb):\n",
"    preds = xb.sigmoid()\n",
"    correct = (preds>0.5) == yb\n",
"    return correct.float().mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_accuracy(linear1(batch), train_y[:4])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def validate_epoch(model):\n",
"    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n",
"    return round(torch.stack(accs).mean().item(), 4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"validate_epoch(linear1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lr = 1.\n",
"params = weights,bias\n",
"train_epoch(linear1, lr, params)\n",
"validate_epoch(linear1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(20):\n",
"    train_epoch(linear1, lr, params)\n",
"    print(validate_epoch(linear1), end=' ')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating an Optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"linear_model = nn.Linear(28*28,1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w,b = linear_model.parameters()\n",
"w.shape,b.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class BasicOptim:\n",
"    def __init__(self,params,lr): self.params,self.lr = list(params),lr\n",
"\n",
"    def step(self, *args, **kwargs):\n",
"        for p in self.params: p.data -= p.grad.data * self.lr\n",
"\n",
"    def zero_grad(self, *args, **kwargs):\n",
"        for p in self.params: p.grad = None"
]
},
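{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BasicOptim` implements plain SGD: it subtracts the learning rate times the gradient from each parameter. PyTorch's own `torch.optim.SGD` applies the same update with its default settings (no momentum, no weight decay). This sketch checks that on two identical copies of a small `nn.Linear`. `TinySGD` is a hypothetical restatement of `BasicOptim` that updates inside `torch.no_grad()` instead of through `.data`, and the inputs are random stand-ins."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class TinySGD:\n",
"    # Hypothetical restatement of BasicOptim: p <- p - lr * grad\n",
"    def __init__(self, params, lr): self.params, self.lr = list(params), lr\n",
"\n",
"    def step(self):\n",
"        with torch.no_grad():\n",
"            for p in self.params: p -= self.lr * p.grad\n",
"\n",
"    def zero_grad(self):\n",
"        for p in self.params: p.grad = None\n",
"\n",
"x_in = torch.rand(8, 4, generator=torch.Generator().manual_seed(0))\n",
"m_ours, m_torch = nn.Linear(4, 1), nn.Linear(4, 1)\n",
"m_torch.load_state_dict(m_ours.state_dict())      # identical starting weights\n",
"init_w = m_ours.weight.detach().clone()\n",
"\n",
"for model, opt_demo in ((m_ours, TinySGD(m_ours.parameters(), 0.1)),\n",
"                        (m_torch, torch.optim.SGD(m_torch.parameters(), lr=0.1))):\n",
"    model(x_in).mean().backward()   # any scalar works as a stand-in loss here\n",
"    opt_demo.step()\n",
"    opt_demo.zero_grad()\n",
"\n",
"print(torch.allclose(m_ours.weight, m_torch.weight),\n",
"      torch.allclose(m_ours.bias, m_torch.bias))   # True True"
]
},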
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = BasicOptim(linear_model.parameters(), lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(model):\n",
"    for xb,yb in dl:\n",
"        calc_grad(xb, yb, model)\n",
"        opt.step()\n",
"        opt.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"validate_epoch(linear_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, epochs):\n",
"    for i in range(epochs):\n",
"        train_epoch(model)\n",
"        print(validate_epoch(model), end=' ')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_model(linear_model, 20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"linear_model = nn.Linear(28*28,1)\n",
"opt = SGD(linear_model.parameters(), lr)\n",
"train_model(linear_model, 20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = DataLoaders(dl, valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n",
"                loss_func=mnist_loss, metrics=batch_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.fit(10, lr=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding a Nonlinearity"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def simple_net(xb):\n",
"    res = xb@w1 + b1\n",
"    res = res.max(tensor(0.0))\n",
"    res = res@w2 + b2\n",
"    return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w1 = init_params((28*28,30))\n",
"b1 = init_params(30)\n",
"w2 = init_params((30,1))\n",
"b2 = init_params(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_function(F.relu)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"simple_net = nn.Sequential(\n",
"    nn.Linear(28*28,30),\n",
"    nn.ReLU(),\n",
"    nn.Linear(30,1)\n",
")"
]
},
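{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `nn.ReLU()` between the two `nn.Linear` layers matters. Without it, two linear layers compose into a single linear layer, so the extra layer adds no expressive power. This sketch builds that single equivalent layer by hand from random, hypothetical inputs `xb_demo` and layers `lin1` and `lin2`, using the fact that `nn.Linear` computes `x @ W.T + b`. It also confirms that `res.max(tensor(0.0))` in the hand-written `simple_net` is the same operation as `torch.relu`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"xb_demo = torch.randn(5, 28*28, generator=torch.Generator().manual_seed(0))\n",
"lin1, lin2 = nn.Linear(28*28, 30), nn.Linear(30, 1)\n",
"\n",
"# Two stacked linear layers with nothing in between...\n",
"two_layers = lin2(lin1(xb_demo))\n",
"# ...equal one linear layer with weight W2 @ W1 and bias b1 @ W2.T + b2\n",
"W_comb = lin2.weight @ lin1.weight                # shape (1, 784)\n",
"c_comb = lin1.bias @ lin2.weight.T + lin2.bias    # shape (1,)\n",
"one_layer = xb_demo @ W_comb.T + c_comb\n",
"print(torch.allclose(two_layers, one_layer, atol=1e-5))                          # True\n",
"\n",
"# Taking the elementwise max with 0 is exactly what ReLU does\n",
"print(torch.equal(two_layers.max(torch.tensor(0.0)), torch.relu(two_layers)))   # True"
]
},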
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, simple_net, opt_func=SGD,\n",
"                loss_func=mnist_loss, metrics=batch_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.fit(40, 0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(L(learn.recorder.values).itemgot(2));"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.recorder.values[-1][2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Going Deeper"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = ImageDataLoaders.from_folder(path)\n",
"learn = cnn_learner(dls, resnet18, pretrained=False,\n",
"                    loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(1, 0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Jargon Recap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. How is a grayscale image represented on a computer? How about a color image?\n",
"1. How are the files and folders in the `MNIST_SAMPLE` dataset structured? Why?\n",
"1. Explain how the \"pixel similarity\" approach to classifying digits works.\n",
"1. What is a list comprehension? Create one now that selects odd numbers from a list and doubles them.\n",
"1. What is a \"rank-3 tensor\"?\n",
"1. What is the difference between tensor rank and shape? How do you get the rank from the shape?\n",
"1. What are RMSE and the L1 norm?\n",
"1. How can you apply a calculation to thousands of numbers at once, many thousands of times faster than a Python loop?\n",
"1. Create a 3×3 tensor or array containing the numbers from 1 to 9. Double it. Select the bottom-right four numbers.\n",
"1. What is broadcasting?\n",
"1. Are metrics generally calculated using the training set, or the validation set? Why?\n",
"1. What is SGD?\n",
"1. Why does SGD use mini-batches?\n",
"1. What are the seven steps in SGD for machine learning?\n",
"1. How do we initialize the weights in a model?\n",
"1. What is \"loss\"?\n",
"1. Why can't we always use a high learning rate?\n",
"1. What is a \"gradient\"?\n",
"1. Do you need to know how to calculate gradients yourself?\n",
"1. Why can't we use accuracy as a loss function?\n",
"1. Draw the sigmoid function. What is special about its shape?\n",
"1. What is the difference between a loss function and a metric?\n",
"1. What is the function to calculate new weights using a learning rate?\n",
"1. What does the `DataLoader` class do?\n",
"1. Write pseudocode showing the basic steps taken in each epoch for SGD.\n",
"1. Create a function that, if passed two arguments `[1,2,3,4]` and `'abcd'`, returns `[(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]`. What is special about that output data structure?\n",
"1. What does `view` do in PyTorch?\n",
"1. What are the \"bias\" parameters in a neural network? Why do we need them?\n",
"1. What does the `@` operator do in Python?\n",
"1. What does the `backward` method do?\n",
"1. Why do we have to zero the gradients?\n",
"1. What information do we have to pass to `Learner`?\n",
"1. Show Python or pseudocode for the basic steps of a training loop.\n",
"1. What is \"ReLU\"? Draw a plot of it for values from `-2` to `+2`.\n",
"1. What is an \"activation function\"?\n",
"1. What's the difference between `F.relu` and `nn.ReLU`?\n",
"1. The universal approximation theorem shows that any continuous function can be approximated as closely as needed using just one nonlinearity. So why do we normally use more?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Create your own implementation of `Learner` from scratch, based on the training loop shown in this chapter.\n",
"1. Complete all the steps in this chapter using the full MNIST datasets (that is, for all digits, not just 3s and 7s). This is a significant project and will take you quite a bit of time to complete! You'll need to do some of your own research to figure out how to overcome some obstacles you'll meet on the way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}