{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "!pip install -Uqq fastbook\n",
    "import fastbook\n",
    "fastbook.setup_book()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.vision.all import *\n",
    "from fastbook import *\n",
    "\n",
    "matplotlib.rc('image', cmap='Greys')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Under the Hood: Training a Digit Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pixels: The Foundations of Computer Vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sidebar: Tenacity and Deep Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## End sidebar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.MNIST_SAMPLE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "Path.BASE_PATH = path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(path/'train').ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "threes = (path/'train'/'3').ls().sorted()\n",
    "sevens = (path/'train'/'7').ls().sorted()\n",
    "threes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "im3_path = threes[1]\n",
    "im3 = Image.open(im3_path)\n",
    "im3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "array(im3)[4:10,4:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor(im3)[4:10,4:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "im3_t = tensor(im3)\n",
    "df = pd.DataFrame(im3_t[4:15,4:22])\n",
    "df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First Try: Pixel Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seven_tensors = [tensor(Image.open(o)) for o in sevens]\n",
    "three_tensors = [tensor(Image.open(o)) for o in threes]\n",
    "len(three_tensors),len(seven_tensors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_image(three_tensors[1]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stacked_sevens = torch.stack(seven_tensors).float()/255\n",
    "stacked_threes = torch.stack(three_tensors).float()/255\n",
    "stacked_threes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(stacked_threes.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stacked_threes.ndim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean3 = stacked_threes.mean(0)\n",
    "show_image(mean3);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean7 = stacked_sevens.mean(0)\n",
    "show_image(mean7);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a_3 = stacked_threes[1]\n",
    "show_image(a_3);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_3_abs = (a_3 - mean3).abs().mean()\n",
    "dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n",
    "dist_3_abs,dist_3_sqr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_7_abs = (a_3 - mean7).abs().mean()\n",
    "dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n",
    "dist_7_abs,dist_7_sqr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.l1_loss(a_3.float(),mean7), F.mse_loss(a_3,mean7).sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NumPy Arrays and PyTorch Tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = [[1,2,3],[4,5,6]]\n",
    "arr = array (data)\n",
    "tns = tensor(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr  # numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns  # pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns[:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns[1,1:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tns*1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing Metrics Using Broadcasting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_3_tens = torch.stack([tensor(Image.open(o)) \n",
    "                            for o in (path/'valid'/'3').ls()])\n",
    "valid_3_tens = valid_3_tens.float()/255\n",
    "valid_7_tens = torch.stack([tensor(Image.open(o)) \n",
    "                            for o in (path/'valid'/'7').ls()])\n",
    "valid_7_tens = valid_7_tens.float()/255\n",
    "valid_3_tens.shape,valid_7_tens.shape"
   ]
  },
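  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute difference over the last two axes (height and width), so this works\n",
    "# for a single image or, via broadcasting, for a whole stack of images\n",
    "def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))\n",
    "mnist_distance(a_3, mean3)"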
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_3_dist = mnist_distance(valid_3_tens, mean3)\n",
    "valid_3_dist, valid_3_dist.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor([1,2,3]) + tensor([1,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(valid_3_tens-mean3).shape"
   ]
  },
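  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch of the broadcasting rule at work, using made-up sizes rather than the MNIST tensors: subtracting one 2×3 tensor from a stack of four 2×3 tensors yields a 4×2×3 result, and averaging over the last two axes leaves one distance per item. The names `imgs`, `ideal`, `diff` and `dist` are illustrative assumptions, not part of the chapter's code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "imgs  = torch.arange(24.).view(4, 2, 3)   # stand-in for a stack of 4 tiny images\n",
    "ideal = torch.ones(2, 3)                  # stand-in for a single 'ideal' image\n",
    "diff  = imgs - ideal                      # ideal is broadcast across the first axis\n",
    "dist  = diff.abs().mean((-1, -2))         # average over the last two axes: one value per image\n",
    "diff.shape, dist.shape"
   ]
  },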
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_3(x): return mnist_distance(x,mean3) < mnist_distance(x,mean7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_3(a_3), is_3(a_3).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_3(valid_3_tens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_3s =      is_3(valid_3_tens).float() .mean()\n",
    "accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()\n",
    "\n",
    "accuracy_3s,accuracy_7s,(accuracy_3s+accuracy_7s)/2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stochastic Gradient Descent (SGD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gv('''\n",
    "init->predict->loss->gradient->step->stop\n",
    "step->predict[label=repeat]\n",
    "''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x): return x**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_function(f, 'x', 'x**2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_function(f, 'x', 'x**2')\n",
    "plt.scatter(-1.5, f(-1.5), color='red');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculating Gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xt = tensor(3.).requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "yt = f(xt)\n",
    "yt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "yt.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xt.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xt = tensor([3.,4.,10.]).requires_grad_()\n",
    "xt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x): return (x**2).sum()\n",
    "\n",
    "yt = f(xt)\n",
    "yt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "yt.backward()\n",
    "xt.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stepping With a Learning Rate"
   ]
  },
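  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of a single step, assuming plain gradient descent on the quadratic `x**2` used earlier: the parameter moves against its gradient, scaled by the learning rate. The helper `step_once`, its `start` argument and the sample rates are hypothetical names and values chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def step_once(start, lr):\n",
    "    # Hypothetical helper: take one gradient descent step on x**2 from `start`\n",
    "    w = torch.tensor(start, requires_grad=True)\n",
    "    loss = w**2\n",
    "    loss.backward()              # d(loss)/dw = 2*w\n",
    "    with torch.no_grad():\n",
    "        w -= w.grad * lr         # the update rule: w -= gradient * lr\n",
    "    return w.item()\n",
    "\n",
    "# A small rate barely moves; a large one jumps past the minimum at 0\n",
    "for rate in (0.01, 0.1, 1.5):\n",
    "    print(rate, step_once(3.0, rate))"
   ]
  },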
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An End-to-End SGD Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time = torch.arange(0,20).float(); time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1\n",
    "plt.scatter(time,speed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(t, params):\n",
    "    a,b,c = params\n",
    "    return a*(t**2) + (b*t) + c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse(preds, targets): return ((preds-targets)**2).mean().sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 1: Initialize the parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = torch.randn(3).requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "orig_params = params.clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 2: Calculate the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = f(time, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_preds(preds, ax=None):\n",
    "    if ax is None: ax=plt.subplots()[1]\n",
    "    ax.scatter(time, speed)\n",
    "    ax.scatter(time, to_np(preds), color='red')\n",
    "    ax.set_ylim(-300,100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_preds(preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3: Calculate the loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = mse(preds, speed)\n",
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 4: Calculate the gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss.backward()\n",
    "params.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params.grad * 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 5: Step the weights"
   ]
  },
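  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-5\n",
    "# Update through .data so the step itself is not tracked by autograd\n",
    "params.data -= lr * params.grad.data\n",
    "# Reset the gradient so the next backward() does not accumulate onto it\n",
    "params.grad = None"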
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = f(time,params)\n",
    "mse(preds, speed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_preds(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_step(params, prn=True):\n",
    "    preds = f(time, params)\n",
    "    loss = mse(preds, speed)\n",
    "    loss.backward()\n",
    "    params.data -= lr * params.grad.data\n",
    "    params.grad = None\n",
    "    if prn: print(loss.item())\n",
    "    return preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 6: Repeat the process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10): apply_step(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "params = orig_params.detach().requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_,axs = plt.subplots(1,4,figsize=(12,3))\n",
    "for ax in axs: show_preds(apply_step(params, False), ax)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 7: Stop"
   ]
  },
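  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to decide when to stop, sketched under the assumption that training halts once the loss stops improving by more than a small tolerance or a step budget runs out. The cells above simply run a fixed number of steps; `fit_until_flat`, `tol`, `max_steps` and the toy problem are hypothetical, introduced only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def fit_until_flat(params, loss_fn, lr=0.1, tol=1e-6, max_steps=1000):\n",
    "    # Hypothetical helper: repeat loss -> gradient -> step until the loss plateaus\n",
    "    prev = float('inf')\n",
    "    for step in range(max_steps):\n",
    "        loss = loss_fn(params)\n",
    "        loss.backward()\n",
    "        with torch.no_grad():\n",
    "            params -= lr * params.grad\n",
    "        params.grad = None\n",
    "        if prev - loss.item() < tol: break   # improvement too small: stop\n",
    "        prev = loss.item()\n",
    "    return step + 1, loss.item()\n",
    "\n",
    "# Toy problem (an assumption for illustration): minimise the squared distance to 2.0\n",
    "p = torch.tensor([0.0], requires_grad=True)\n",
    "steps, final_loss = fit_until_flat(p, lambda w: ((w - 2.0)**2).sum())\n",
    "steps, final_loss"
   ]
  },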
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Summarizing Gradient Descent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gv('''\n",
    "init->predict->loss->gradient->step->stop\n",
    "step->predict[label=repeat]\n",
    "''')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The MNIST Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n",
    "train_x.shape,train_y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dset = list(zip(train_x,train_y))\n",
    "x,y = dset[0]\n",
    "x.shape,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n",
    "valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n",
    "valid_dset = list(zip(valid_x,valid_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = init_params((28*28,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bias = init_params(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_x[0]*weights.T).sum() + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear1(xb): return xb@weights + bias\n",
    "preds = linear1(train_x)\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corrects = (preds>0.0).float() == train_y\n",
    "corrects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corrects.float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights[0] *= 1.0001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = linear1(train_x)\n",
    "((preds>0.0).float() == train_y).float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trgts  = tensor([1,0,1])\n",
    "prds   = tensor([0.9, 0.4, 0.2])"
   ]
  },
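  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes predictions already lie between 0 and 1: the loss is each prediction's distance from its target\n",
    "def mnist_loss(predictions, targets):\n",
    "    return torch.where(targets==1, 1-predictions, predictions).mean()"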
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.where(trgts==1, 1-prds, prds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_loss(prds,trgts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_loss(tensor([0.9, 0.4, 0.8]),trgts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x): return 1/(1+torch.exp(-x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mnist_loss(predictions, targets):\n",
    "    predictions = predictions.sigmoid()\n",
    "    return torch.where(targets==1, 1-predictions, predictions).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SGD and Mini-Batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = range(15)\n",
    "dl = DataLoader(coll, batch_size=5, shuffle=True)\n",
    "list(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = L(enumerate(string.ascii_lowercase))\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = DataLoader(ds, batch_size=6, shuffle=True)\n",
    "list(dl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting It All Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = init_params((28*28,1))\n",
    "bias = init_params(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = DataLoader(dset, batch_size=256)\n",
    "xb,yb = first(dl)\n",
    "xb.shape,yb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_dl = DataLoader(valid_dset, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = train_x[:4]\n",
    "batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = linear1(batch)\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = mnist_loss(preds, train_y[:4])\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss.backward()\n",
    "weights.grad.shape,weights.grad.mean(),bias.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_grad(xb, yb, model):\n",
    "    preds = model(xb)\n",
    "    loss = mnist_loss(preds, yb)\n",
    "    loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calc_grad(batch, train_y[:4], linear1)\n",
    "weights.grad.mean(),bias.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calc_grad(batch, train_y[:4], linear1)\n",
    "weights.grad.mean(),bias.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights.grad.zero_()\n",
    "bias.grad.zero_();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, lr, params):\n",
    "    for xb,yb in dl:\n",
    "        calc_grad(xb, yb, model)\n",
    "        for p in params:\n",
    "            # SGD step, then clear the gradient so it does not accumulate into the next batch\n",
    "            p.data -= p.grad*lr\n",
    "            p.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(preds>0.0).float() == train_y[:4]"
   ]
  },
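  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_accuracy(xb, yb):\n",
    "    # xb holds raw model outputs (activations), so squash them into (0,1) first\n",
    "    preds = xb.sigmoid()\n",
    "    # A prediction above 0.5 counts as a 3 (label 1); compare with the targets\n",
    "    correct = (preds>0.5) == yb\n",
    "    return correct.float().mean()"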
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_accuracy(linear1(batch), train_y[:4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate_epoch(model):\n",
    "    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n",
    "    return round(torch.stack(accs).mean().item(), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "validate_epoch(linear1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1.\n",
    "params = weights,bias\n",
    "train_epoch(linear1, lr, params)\n",
    "validate_epoch(linear1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(20):\n",
    "    train_epoch(linear1, lr, params)\n",
    "    print(validate_epoch(linear1), end=' ')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating an Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_model = nn.Linear(28*28,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w,b = linear_model.parameters()\n",
    "w.shape,b.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicOptim:\n",
    "    def __init__(self,params,lr): self.params,self.lr = list(params),lr\n",
    "\n",
    "    def step(self, *args, **kwargs):\n",
    "        for p in self.params: p.data -= p.grad.data * self.lr\n",
    "\n",
    "    def zero_grad(self, *args, **kwargs):\n",
    "        for p in self.params: p.grad = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = BasicOptim(linear_model.parameters(), lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model):\n",
    "    for xb,yb in dl:\n",
    "        calc_grad(xb, yb, model)\n",
    "        opt.step()\n",
    "        opt.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "validate_epoch(linear_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, epochs):\n",
    "    for i in range(epochs):\n",
    "        train_epoch(model)\n",
    "        print(validate_epoch(model), end=' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model(linear_model, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_model = nn.Linear(28*28,1)\n",
    "opt = SGD(linear_model.parameters(), lr)\n",
    "train_model(linear_model, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = DataLoaders(dl, valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n",
    "                loss_func=mnist_loss, metrics=batch_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(10, lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding a Nonlinearity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_net(xb):\n",
    "    res = xb@w1 + b1            # first linear layer\n",
    "    res = res.max(tensor(0.0))  # ReLU: replace negative activations with zero\n",
    "    res = res@w2 + b2           # second linear layer\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w1 = init_params((28*28,30))\n",
    "b1 = init_params(30)\n",
    "w2 = init_params((30,1))\n",
    "b2 = init_params(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_function(F.relu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "simple_net = nn.Sequential(\n",
    "    nn.Linear(28*28,30),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(30,1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, simple_net, opt_func=SGD,\n",
    "                loss_func=mnist_loss, metrics=batch_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(40, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row of recorder.values is [train_loss, valid_loss, metric]; column 2 is batch_accuracy\n",
    "plt.plot(L(learn.recorder.values).itemgot(2));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.values[-1][2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Going Deeper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = ImageDataLoaders.from_folder(path)\n",
    "learn = cnn_learner(dls, resnet18, pretrained=False,\n",
    "                    loss_func=F.cross_entropy, metrics=accuracy)\n",
    "learn.fit_one_cycle(1, 0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Jargon Recap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. How is a grayscale image represented on a computer? How about a color image?\n",
    "1. How are the files and folders in the `MNIST_SAMPLE` dataset structured? Why?\n",
    "1. Explain how the \"pixel similarity\" approach to classifying digits works.\n",
    "1. What is a list comprehension? Create one now that selects odd numbers from a list and doubles them.\n",
    "1. What is a \"rank-3 tensor\"?\n",
    "1. What is the difference between tensor rank and shape? How do you get the rank from the shape?\n",
    "1. What are RMSE and L1 norm?\n",
    "1. How can you apply a calculation on thousands of numbers at once, many thousands of times faster than a Python loop?\n",
    "1. Create a 3×3 tensor or array containing the numbers from 1 to 9. Double it. Select the bottom-right four numbers.\n",
    "1. What is broadcasting?\n",
    "1. Are metrics generally calculated using the training set, or the validation set? Why?\n",
    "1. What is SGD?\n",
    "1. Why does SGD use mini-batches?\n",
    "1. What are the seven steps in SGD for machine learning?\n",
    "1. How do we initialize the weights in a model?\n",
    "1. What is \"loss\"?\n",
    "1. Why can't we always use a high learning rate?\n",
    "1. What is a \"gradient\"?\n",
    "1. Do you need to know how to calculate gradients yourself?\n",
    "1. Why can't we use accuracy as a loss function?\n",
    "1. Draw the sigmoid function. What is special about its shape?\n",
    "1. What is the difference between a loss function and a metric?\n",
    "1. What is the function to calculate new weights using a learning rate?\n",
    "1. What does the `DataLoader` class do?\n",
    "1. Write pseudocode showing the basic steps taken in each epoch for SGD.\n",
    "1. Create a function that, if passed two arguments `[1,2,3,4]` and `'abcd'`, returns `[(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]`. What is special about that output data structure?\n",
    "1. What does `view` do in PyTorch?\n",
    "1. What are the \"bias\" parameters in a neural network? Why do we need them?\n",
    "1. What does the `@` operator do in Python?\n",
    "1. What does the `backward` method do?\n",
    "1. Why do we have to zero the gradients?\n",
    "1. What information do we have to pass to `Learner`?\n",
    "1. Show Python or pseudocode for the basic steps of a training loop.\n",
    "1. What is \"ReLU\"? Draw a plot of it for values from `-2` to `+2`.\n",
    "1. What is an \"activation function\"?\n",
    "1. What's the difference between `F.relu` and `nn.ReLU`?\n",
    "1. The universal approximation theorem shows that any function can be approximated as closely as needed using just one nonlinearity. So why do we normally use more?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further Research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Create your own implementation of `Learner` from scratch, based on the training loop shown in this chapter.\n",
    "1. Complete all the steps in this chapter using the full MNIST datasets (that is, for all digits, not just 3s and 7s). This is a significant project and will take you quite a bit of time to complete! You'll need to do some of your own research to figure out how to overcome some obstacles you'll meet on the way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}