mirror of
https://github.com/fastai/fastbook.git
synced 2025-04-04 18:00:48 +00:00
1767 lines
172 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#hide\n",
|
||
|
"from utils import *"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# fastai Learner from scratch"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"This final chapter (other than the conclusion and the online chapters) is going to look a bit different. We will have far more code and far less prose than in previous chapters. We will introduce new Python keywords and libraries without discussing them. This chapter is meant to be the start of a significant research project for you. You see, we are going to implement all of the key pieces of the fastai and PyTorch APIs from scratch, building on nothing other than the components that we developed in <<chapter_foundations>>! The key goal here is to end up with our own `Learner` class and some callbacks, enough to be able to train a model on Imagenette, including examples of each of the key techniques we've studied. On the way to building `Learner`, we will create `Module`, `Parameter`, and even our own parallel `DataLoader`… and much more.\n",
|
||
|
"\n",
|
||
|
"The end-of-chapter questionnaire is particularly important for this chapter. This is where we will get you started on the many interesting directions you could take, using this chapter as your starting point. What we're really saying is: follow along with this chapter on your computer, not on paper, and do lots of experiments, web searches, and whatever else you need to understand what's going on. You've built up the skills and expertise to do this in the rest of this book, so we think you are going to do great!"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Data"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"path = untar_data(URLs.IMAGENETTE_160)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Path('/home/jhoward/.fastai/data/imagenette2-160/val/n03417042/n03417042_3752.JPEG')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"t = get_image_files(path)\n",
|
||
|
"t[0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Path('/home/jhoward/.fastai/data/imagenette2-160/val/n03417042/n03417042_3752.JPEG')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from glob import glob\n",
|
||
|
"files = L(glob(f'{path}/**/*.JPEG', recursive=True)).map(Path)\n",
|
||
|
"files[0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
|
||
|
"text/plain": [
|
||
|
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=213x160 at 0x7F471FDA55D0>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"im = Image.open(files[0])\n",
|
||
|
"im"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(#10) ['n03417042','n03445777','n03888257','n03394916','n02979186','n03000684','n03425413','n01440764','n03028079','n02102040']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"lbls = files.map(Self.parent.name()).unique(); lbls"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'n03417042': 0,\n",
|
||
|
" 'n03445777': 1,\n",
|
||
|
" 'n03888257': 2,\n",
|
||
|
" 'n03394916': 3,\n",
|
||
|
" 'n02979186': 4,\n",
|
||
|
" 'n03000684': 5,\n",
|
||
|
" 'n03425413': 6,\n",
|
||
|
" 'n01440764': 7,\n",
|
||
|
" 'n03028079': 8,\n",
|
||
|
" 'n02102040': 9}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"v2i = lbls.val2idx(); v2i"
|
||
|
]
|
||
|
},
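In plain Python, taking each file's parent folder name as its label and building the label-to-index mapping amounts to roughly the following sketch. The file paths are made up for illustration, and `dict.fromkeys` stands in for fastcore's `unique`, which we assume keeps labels in first-seen order (as the output above suggests).

```python
from pathlib import PurePosixPath

# Hypothetical paths in the imagenette layout: <split>/<label>/<file>
files = [PurePosixPath(p) for p in [
    "train/n01440764/a.JPEG",
    "train/n02102040/b.JPEG",
    "val/n01440764/c.JPEG",
    "val/n03000684/d.JPEG",
]]

# The label is the name of the parent folder; dict.fromkeys drops duplicates, keeping first-seen order
lbls = list(dict.fromkeys(f.parent.name for f in files))

# Map each label to its position in lbls (what val2idx produces above)
v2i = {v: i for i, v in enumerate(lbls)}
```

The mapping is what turns a folder name into the integer class index that the loss function later needs.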
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([160, 213, 3])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"im_t = tensor(im)\n",
|
||
|
"im_t.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Dataset:\n",
|
||
|
" def __init__(self, fns): self.fns=fns\n",
|
||
|
" def __len__(self): return len(self.fns)\n",
|
||
|
" def __getitem__(self, i):\n",
|
||
|
" im = Image.open(self.fns[i]).resize((64,64)).convert('RGB')\n",
|
||
|
" y = v2i[self.fns[i].parent.name]\n",
|
||
|
" return tensor(im).float()/255, tensor(y)"
|
||
|
]
|
||
|
},
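All a dataset needs is `__len__` and `__getitem__` returning an `(x, y)` pair. Here is a minimal sketch of that protocol with a hypothetical `ToyDataset` that stores raw byte values in memory instead of image files; it only rescales them to `[0, 1]`, mirroring the `/255` in `Dataset` above.

```python
class ToyDataset:
    "Hypothetical stand-in: anything with __len__ and __getitem__ returning (x, y) can act as a dataset."
    def __init__(self, items, labels): self.items, self.labels = items, labels
    def __len__(self): return len(self.items)
    def __getitem__(self, i):
        # The real Dataset opens and resizes an image here; we only rescale byte values to [0, 1]
        x = [v / 255 for v in self.items[i]]
        return x, self.labels[i]

ds = ToyDataset([[0, 255, 51], [102, 204, 0]], [3, 7])
x, y = ds[1]
```

Because the work happens inside `__getitem__`, items are only processed when they are indexed, which is why the real `Dataset` can keep just a list of file names in memory.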
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(9469, 3925)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_filt = L(o.parent.parent.name=='train' for o in files)\n",
|
||
|
"train,valid = files[train_filt],files[~train_filt]\n",
|
||
|
"len(train),len(valid)"
|
||
|
]
|
||
|
},
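The split above is driven by a boolean mask over the files, based on the grandparent folder name, and then selects items with the mask and with its negation. A plain-Python sketch of the same idea, with made-up paths and list comprehensions in place of fastcore's mask indexing:

```python
from pathlib import PurePosixPath

# Hypothetical paths in the <split>/<label>/<file> layout
files = [PurePosixPath(p) for p in [
    "train/n01440764/a.JPEG", "val/n01440764/b.JPEG", "train/n03000684/c.JPEG"]]

# Boolean mask: True when the grandparent folder is 'train'
train_filt = [f.parent.parent.name == "train" for f in files]

# Keep the items where the mask is True, then the items where it is False
train = [f for f, keep in zip(files, train_filt) if keep]
valid = [f for f, keep in zip(files, train_filt) if not keep]
```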
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(torch.Size([64, 64, 3]), tensor(0))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_ds,valid_ds = Dataset(train),Dataset(valid)\n",
|
||
|
"x,y = train_ds[0]\n",
|
||
|
"x.shape,y"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
|
||
|
"text/plain": [
|
||
|
"<Figure size 144x144 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"show_image(x, title=lbls[y]);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def collate(idxs, ds): \n",
|
||
|
" xb,yb = zip(*[ds[i] for i in idxs])\n",
|
||
|
" return torch.stack(xb),torch.stack(yb)"
|
||
|
]
|
||
|
},
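The key move in `collate` is `zip(*...)`, which transposes a list of `(x, y)` pairs into one tuple of all the `x`s and one of all the `y`s. The sketch below shows that transposition on a hypothetical list-based dataset, returning plain lists where the real version calls `torch.stack`.

```python
def collate(idxs, ds):
    # Fetch the (x, y) pairs, then transpose them into one tuple of xs and one of ys
    xb, yb = zip(*[ds[i] for i in idxs])
    # torch.stack would turn each tuple into a single tensor; plain lists stand in here
    return list(xb), list(yb)

ds = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 2)]  # hypothetical (x, y) pairs
xb, yb = collate([1, 2], ds)
```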
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(torch.Size([2, 64, 64, 3]), tensor([0, 0]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x,y = collate([1,2], train_ds)\n",
|
||
|
"x.shape,y"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class DataLoader:\n",
|
||
|
" def __init__(self, ds, bs=128, shuffle=False, n_workers=1):\n",
|
||
|
" self.ds,self.bs,self.shuffle,self.n_workers = ds,bs,shuffle,n_workers\n",
|
||
|
"\n",
|
||
|
" def __len__(self): return (len(self.ds)-1)//self.bs+1\n",
|
||
|
"\n",
|
||
|
" def __iter__(self):\n",
|
||
|
" idxs = L.range(self.ds)\n",
|
||
|
" if self.shuffle: idxs = idxs.shuffle()\n",
|
||
|
" chunks = [idxs[n:n+self.bs] for n in range(0, len(self.ds), self.bs)]\n",
|
||
|
" with ProcessPoolExecutor(self.n_workers) as ex:\n",
|
||
|
" yield from ex.map(collate, chunks, ds=self.ds)"
|
||
|
]
|
||
|
},
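The `DataLoader` above does three things: it computes the number of batches with ceiling division, optionally shuffles the indices and splits them into chunks of `bs`, and maps a collate function over the chunks in a worker pool. Below is a stdlib sketch of that logic. It assumes a `ThreadPoolExecutor` in place of the process pool (simply so it runs anywhere), and a hypothetical `fake_collate` that returns the indices instead of loading images.

```python
import random
from concurrent.futures import ThreadPoolExecutor

def n_batches(n, bs):
    # (n-1)//bs + 1 is ceiling division: a final, partial batch still counts
    return (n - 1) // bs + 1

def fake_collate(chunk):
    # Hypothetical stand-in for collate: just returns the indices it was given
    return list(chunk)

def batches(n, bs, shuffle=False, n_workers=2, seed=None):
    idxs = list(range(n))
    if shuffle: random.Random(seed).shuffle(idxs)
    chunks = [idxs[i:i + bs] for i in range(0, n, bs)]
    # Executor.map yields results in the order of its inputs, even though work runs concurrently
    with ThreadPoolExecutor(n_workers) as ex:
        yield from ex.map(fake_collate, chunks)
```

For example, `n_batches(9469, 128)` gives 74, matching `len(train_dl)` above, and `list(batches(10, 4))` gives `[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]`. The order guarantee of `map` is what keeps batches in the (possibly shuffled) index order.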
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(torch.Size([128, 64, 64, 3]), torch.Size([128]), 74)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"n_workers = min(16, defaults.cpus)\n",
|
||
|
"train_dl = DataLoader(train_ds, bs=128, shuffle=True, n_workers=n_workers)\n",
|
||
|
"valid_dl = DataLoader(valid_ds, bs=256, shuffle=False, n_workers=n_workers)\n",
|
||
|
"xb,yb = first(train_dl)\n",
|
||
|
"xb.shape,yb.shape,len(train_dl)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[tensor([0.4544, 0.4453, 0.4141]), tensor([0.2812, 0.2766, 0.2981])]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"stats = [xb.mean((0,1,2)),xb.std((0,1,2))]\n",
|
||
|
"stats"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Normalize:\n",
|
||
|
" def __init__(self, stats): self.stats=stats\n",
|
||
|
" def __call__(self, x):\n",
|
||
|
" if x.device != self.stats[0].device:\n",
|
||
|
" self.stats = to_device(self.stats, x.device)\n",
|
||
|
" return (x-self.stats[0])/self.stats[1]"
|
||
|
]
|
||
|
},
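Normalization works per channel: compute each channel's mean and standard deviation, then subtract and divide. Here is a plain-Python sketch on a few hypothetical channels-last pixels. It uses the population standard deviation for simplicity, whereas `torch.std` defaults to the unbiased estimate, so the numbers would differ slightly from PyTorch's.

```python
from statistics import fmean, pstdev

# Hypothetical pixels, channels-last like the batches above: (R, G, B) triples
pixels = [(0.2, 0.5, 0.9), (0.4, 0.1, 0.7), (0.6, 0.3, 0.5)]

# Per-channel statistics, analogous to xb.mean((0,1,2)) and xb.std((0,1,2))
channels = list(zip(*pixels))
means = [fmean(c) for c in channels]
stds = [pstdev(c) for c in channels]  # population std, unlike torch.std's default

# Subtract each channel's mean and divide by its std
normed = [tuple((v - m) / s for v, m, s in zip(px, means, stds)) for px in pixels]
```

After this, every channel of `normed` has mean 0 and standard deviation 1 (up to rounding), which is what the check on `tfm_x` is looking for.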
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"norm = Normalize(stats)\n",
|
||
|
"def tfm_x(x): return norm(x).permute((0,3,1,2))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(tensor([0.3732, 0.4907, 0.5633]), tensor([1.0212, 1.0311, 1.0131]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"t = tfm_x(x)\n",
|
||
|
"t.mean((0,2,3)),t.std((0,2,3))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Module and Parameter"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Parameter_(Tensor):\n",
|
||
|
" def __init__(self, *args, **kwargs): self.requires_grad_()\n",
|
||
|
"def Parameter(x): return Tensor._make_subclass(Parameter_, x, True)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Module:\n",
|
||
|
" def __init__(self):\n",
|
||
|
" self.hook,self.params,self.children,self._training = None,[],[],False\n",
|
||
|
" \n",
|
||
|
" def register_parameters(self, *ps): self.params += ps\n",
|
||
|
" def register_modules (self, *ms): self.children += ms\n",
|
||
|
" \n",
|
||
|
" @property\n",
|
||
|
" def training(self): return self._training\n",
|
||
|
" @training.setter\n",
|
||
|
" def training(self,v):\n",
|
||
|
" self._training = v\n",
|
||
|
" for m in self.children: m.training=v\n",
|
||
|
" \n",
|
||
|
" def parameters(self):\n",
|
||
|
"        # Build a new list so that repeated calls don't keep extending self.params\n",
"        return self.params + sum([m.parameters() for m in self.children], [])\n",
|
||
|
"\n",
|
||
|
" def __setattr__(self,k,v):\n",
|
||
|
" super().__setattr__(k,v)\n",
|
||
|
" if isinstance(v,Parameter_): self.register_parameters(v)\n",
|
||
|
" if isinstance(v,Module): self.register_modules(v)\n",
|
||
|
" \n",
|
||
|
" def __call__(self, *args, **kwargs):\n",
|
||
|
" res = self.forward(*args, **kwargs)\n",
|
||
|
" if self.hook is not None: self.hook(res, args)\n",
|
||
|
" return res\n",
|
||
|
" \n",
|
||
|
" def cuda(self):\n",
|
||
|
" for p in self.parameters(): p.data = p.data.cuda()"
|
||
|
]
|
||
|
},
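The registration trick in `Module` is the `__setattr__` override: assigning a parameter or a submodule to an attribute records it automatically, and `parameters` then walks the tree of children. Below is a stripped-down sketch of that mechanism, with a hypothetical `Param` class standing in for `Parameter_`. Note that `parameters` builds a fresh list each time, so calling it repeatedly never duplicates entries.

```python
class Param:
    "Hypothetical stand-in for Parameter_: just marks a value as trainable."
    def __init__(self, value): self.value = value

class Module:
    def __init__(self):
        self.params, self.children = [], []

    def __setattr__(self, k, v):
        super().__setattr__(k, v)
        # Assigning a Param or a Module to an attribute registers it automatically
        if isinstance(v, Param): self.params.append(v)
        if isinstance(v, Module): self.children.append(v)

    def parameters(self):
        # Own params plus, recursively, those of every child module
        return self.params + sum([m.parameters() for m in self.children], [])

class Lin(Module):
    def __init__(self):
        super().__init__()
        self.w, self.b = Param(1.0), Param(0.0)

class Net(Module):
    def __init__(self):
        super().__init__()
        self.a, self.b = Lin(), Lin()

n = Net()
```

`n.parameters()` returns the four `Param`s of the two `Lin` children, while `n.params` itself stays empty.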
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"act_func = F.relu"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class ConvLayer(Module):\n",
|
||
|
" def __init__(self, ni, nf, stride=1, bias=True, act=True):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.w = Parameter(torch.zeros(nf,ni,3,3))\n",
|
||
|
" self.b = Parameter(torch.zeros(nf)) if bias else None\n",
|
||
|
" self.act,self.stride = act,stride\n",
|
||
|
" init = nn.init.kaiming_normal_ if act else nn.init.xavier_normal_\n",
|
||
|
" init(self.w)\n",
|
||
|
" \n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = F.conv2d(x, self.w, self.b, stride=self.stride, padding=1)\n",
|
||
|
" if self.act: x = act_func(x)\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
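`ConvLayer` picks Kaiming initialization when an activation follows and Xavier initialization otherwise. As we understand PyTorch's documented defaults, `kaiming_normal_` (mode `fan_in`, nonlinearity `leaky_relu`, `a=0`) uses standard deviation $\sqrt{2}/\sqrt{\text{fan\_in}}$ and `xavier_normal_` (gain 1) uses $\sqrt{2/(\text{fan\_in}+\text{fan\_out})}$. The sketch below only computes those numbers in plain Python; it is not PyTorch's implementation. For a conv weight of shape `(nf, ni, 3, 3)`, fan-in is `ni*3*3` and fan-out is `nf*3*3`.

```python
import math, random

def kaiming_std(fan_in, a=0.0):
    # gain = sqrt(2 / (1 + a^2)) for leaky_relu (a=0 gives the ReLU gain sqrt(2)); std = gain / sqrt(fan_in)
    return math.sqrt(2.0 / (1 + a * a)) / math.sqrt(fan_in)

def xavier_std(fan_in, fan_out):
    # gain=1: std = sqrt(2 / (fan_in + fan_out))
    return math.sqrt(2.0 / (fan_in + fan_out))

# First layer of our CNN: ni=3 input channels, nf=16 filters, 3x3 kernels
ni, nf = 3, 16
std = kaiming_std(ni * 9)
w = [random.gauss(0.0, std) for _ in range(nf * ni * 9)]  # a flat stand-in for the weight tensor
```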
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"2"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"l = ConvLayer(3, 4)\n",
|
||
|
"len(l.parameters())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([128, 4, 64, 64])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 27,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"xbt = tfm_x(xb)\n",
|
||
|
"r = l(xbt)\n",
|
||
|
"r.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Linear(Module):\n",
|
||
|
" def __init__(self, ni, nf):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.w = Parameter(torch.zeros(nf,ni))\n",
|
||
|
" self.b = Parameter(torch.zeros(nf))\n",
|
||
|
" nn.init.xavier_normal_(self.w)\n",
|
||
|
" \n",
|
||
|
" def forward(self, x): return x@self.w.t() + self.b"
|
||
|
]
|
||
|
},
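`Linear` stores its weight as `(nf, ni)`, the same layout PyTorch uses, which is why `forward` multiplies by `self.w.t()`. Each output is the dot product of an input row with one row of the weight, plus a bias. A plain-Python sketch with hypothetical weights:

```python
def linear(x, w, b):
    # x: n rows of ni values; w: nf rows of ni values (the (out, in) layout); b: nf values
    # Computes x @ w.T + b: output j of each row is the dot product with row j of w, plus b[j]
    return [[sum(xi * wi for xi, wi in zip(row, wr)) + bj for wr, bj in zip(w, b)] for row in x]

x = [[1.0] * 4 for _ in range(3)]                   # like torch.ones(3,4)
w = [[0.1, 0.2, 0.3, 0.4], [1.0, 0.0, -1.0, 0.0]]   # nf=2, ni=4 (hypothetical values)
b = [0.5, -0.5]
out = linear(x, w, b)                               # 3 rows of 2 values, like the (3, 2) shape above
```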
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([3, 2])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 29,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"l = Linear(4,2)\n",
|
||
|
"r = l(torch.ones(3,4))\n",
|
||
|
"r.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class T(Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.c,self.l = ConvLayer(3,4),Linear(4,2)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 31,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"4"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 31,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"t = T()\n",
|
||
|
"len(t.parameters())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 32,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"device(type='cuda', index=5)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 32,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"t.cuda()\n",
|
||
|
"t.l.w.device"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Simple CNN"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 33,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Sequential(Module):\n",
|
||
|
" def __init__(self, *layers):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.layers = layers\n",
|
||
|
" self.register_modules(*layers)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" for l in self.layers: x = l(x)\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class AdaptivePool(Module):\n",
|
||
|
" def forward(self, x): return x.mean((2,3))"
|
||
|
]
|
||
|
},
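`AdaptivePool` averages over the two spatial dimensions, leaving one number per channel whatever the height and width of its input. A plain-Python sketch for a single hypothetical image stored as `[channel][row][col]`:

```python
def global_avg_pool(x):
    # x: one image as [channel][row][col]; returns one mean per channel, like x.mean((2,3))
    return [sum(sum(r) for r in ch) / sum(len(r) for r in ch) for ch in x]

img = [[[1, 2], [3, 4]], [[0, 0], [0, 8]]]  # 2 channels of 2x2 (hypothetical values)
pooled = global_avg_pool(img)
```

Here `pooled` is `[2.5, 2.0]`. Because the output length depends only on the number of channels, the `Linear(128, 10)` that follows it in `simple_cnn` receives 128 features regardless of the image size.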
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 35,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def simple_cnn():\n",
|
||
|
" return Sequential(\n",
|
||
|
" ConvLayer(3 ,16 ,stride=2), #32\n",
|
||
|
" ConvLayer(16,32 ,stride=2), #16\n",
|
||
|
" ConvLayer(32,64 ,stride=2), # 8\n",
|
||
|
" ConvLayer(64,128,stride=2), # 4\n",
|
||
|
" AdaptivePool(),\n",
|
||
|
" Linear(128, 10)\n",
|
||
|
" )"
|
||
|
]
|
||
|
},
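The size comments in `simple_cnn` follow from the usual convolution output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ (no dilation). With $k=3$ and $p=1$, stride 1 keeps the size and stride 2 halves it. This sketch checks the comments for 64×64 inputs and also counts the registered parameters (a weight and a bias per layer).

```python
def conv_out(n, k=3, stride=1, padding=1):
    # Output size of a convolution along one spatial dimension (no dilation)
    return (n + 2 * padding - k) // stride + 1

sizes = [64]
for _ in range(4):  # the four stride-2 ConvLayers in simple_cnn
    sizes.append(conv_out(sizes[-1], stride=2))

# Each of the 4 ConvLayers and the final Linear registers a weight and a bias
n_params = 4 * 2 + 2
```

`sizes` comes out as `[64, 32, 16, 8, 4]`, `conv_out(64)` is 64 (matching the `[128, 4, 64, 64]` shape from the stride-1 `ConvLayer` earlier), and `n_params` is 10, matching `len(m.parameters())` below.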
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 36,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"10"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 36,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"m = simple_cnn()\n",
|
||
|
"len(m.parameters())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def print_stats(outp, inp): print(outp.mean().item(),outp.std().item())\n",
|
||
|
"for i in range(4): m.layers[i].hook = print_stats"
|
||
|
]
|
||
|
},
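A hook, as `Module.__call__` implements it, is just an optional function called with `(output, inputs)` after `forward` runs. The sketch below uses a hypothetical minimal `Layer` class wrapping a ReLU over a Python list, and records statistics in a list rather than printing them, which makes them easy to inspect later.

```python
class Layer:
    "Hypothetical minimal module: calls an optional hook with (output, inputs) after forward."
    def __init__(self, f): self.f, self.hook = f, None
    def __call__(self, *args):
        res = self.f(*args)
        if self.hook is not None: self.hook(res, args)
        return res

records = []
def record_stats(outp, inp): records.append((min(outp), max(outp)))

layer = Layer(lambda xs: [max(v, 0.0) for v in xs])  # a ReLU over a list
layer.hook = record_stats
out = layer([-1.0, 2.0, 0.5])
```

After the call, `out` is `[0.0, 2.0, 0.5]` and `records` holds `(0.0, 2.0)`.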
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0.41993722319602966 0.6984175443649292\n",
|
||
|
"0.31132981181144714 0.5268176794052124\n",
|
||
|
"0.30335918068885803 0.547913670539856\n",
|
||
|
"0.32568952441215515 0.5254442095756531\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([128, 10])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"r = m(xbt)\n",
|
||
|
"r.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Loss"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 39,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def nll(input, target): return -input[range(target.shape[0]), target].mean()"
|
||
|
]
|
||
|
},
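`input[range(n), target]` pairs row `i` with column `target[i]`, so it picks out the log-probability each row assigns to its correct class; `nll` then negates the mean. The same computation in plain Python, on hypothetical log-softmax outputs:

```python
def nll(log_probs, targets):
    # For each row, pick the log-probability of its target class, then negate the mean
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)

log_probs = [[-0.1, -2.5, -3.0], [-1.2, -0.4, -2.2]]  # hypothetical log-softmax outputs
loss = nll(log_probs, [0, 1])
```

Here the picked values are `-0.1` and `-0.4`, so `loss` is `0.25`.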
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
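"Before we can use `nll`, we need the log of the softmax of our model's activations. Here is a direct translation of the formula (our model hasn't been trained yet, so we shouldn't expect the loss it gives to be low):"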
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 40,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(-2.7753, grad_fn=<SelectBackward>)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sm = log_softmax(r); sm[0][0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 42,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(2.5293, grad_fn=<NegBackward>)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 42,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"loss = nll(sm, yb)\n",
|
||
|
"loss"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Note that the formula \n",
|
||
|
"\n",
|
||
|
"$$\\log \\left ( \\frac{a}{b} \\right ) = \\log(a) - \\log(b)$$ \n",
|
||
|
"\n",
|
||
|
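"gives a simplification when we compute the log softmax, which was previously defined as `(x.exp()/(x.exp().sum(-1,keepdim=True))).log()`. Taking $a = e^{x_i}$ and $b = \\sum_{j} e^{x_{j}}$, and using $\\log(e^{x_i}) = x_i$, the log softmax becomes $x_i - \\log \\left ( \\sum_{j} e^{x_{j}} \\right )$:"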
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 43,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(-2.7753, grad_fn=<SelectBackward>)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 43,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()\n",
|
||
|
"sm = log_softmax(r); sm[0][0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Next, there is a more numerically stable way to compute the log of a sum of exponentials, called the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp). Taking the exponential of a large activation can overflow, so the idea is to use the following formula:\n",
|
||
|
"\n",
|
||
|
"$$\\log \\left ( \\sum_{j=1}^{n} e^{x_{j}} \\right ) = \\log \\left ( e^{a} \\sum_{j=1}^{n} e^{x_{j}-a} \\right ) = a + \\log \\left ( \\sum_{j=1}^{n} e^{x_{j}-a} \\right )$$\n",
|
||
|
"\n",
|
||
|
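"where $a$ is the maximum of the $x_{j}$. Every $x_{j}-a$ is then at most 0, so the largest exponential we compute is $e^{0} = 1$ and nothing can overflow.\n",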
|
||
|
"\n",
|
||
|
"\n",
|
||
|
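"Here's the same check in code. Because of floating-point rounding, the two sides can differ in their last few bits, so an exact `==` comparison may return `False` even though the formula is mathematically exact:"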
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(False)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.rand(5)\n",
|
||
|
"a = x.max()\n",
|
||
|
"x.exp().sum().log() == a + (x-a).exp().sum().log()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 45,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def logsumexp(x):\n",
|
||
|
" m = x.max(-1)[0]\n",
|
||
|
" return m + (x-m[:,None]).exp().sum(-1).log()"
|
||
|
]
|
||
|
},
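To see why the max shift matters, here is a plain-Python comparison of a naive log-sum-exp and the shifted version. Python's `math.exp` raises `OverflowError` for arguments around 1000 (PyTorch would instead return `inf`), while the shifted version returns $1000 + \log 2$ for two activations of 1000.

```python
import math

def naive_logsumexp(xs):
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp(xs):
    # Shift by the max so the largest exponent is exp(0) = 1 and nothing overflows
    a = max(xs)
    return a + math.log(sum(math.exp(x - a) for x in xs))

small = [0.5, 1.0, 2.0]   # both versions agree here
big = [1000.0, 1000.0]    # naive_logsumexp(big) raises OverflowError
```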
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 46,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(2.3158, grad_fn=<SelectBackward>)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 46,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"logsumexp(r)[0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
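"PyTorch already provides this as `Tensor.logsumexp` (equivalent to `torch.logsumexp`), so we can use the built-in version in our `log_softmax` function."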
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 47,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 48,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor(-2.7753, grad_fn=<SelectBackward>)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 48,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sm = log_softmax(r); sm[0][0]"
|
||
|
]
|
||
|
},
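{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a quick sanity check. It assumes only PyTorch and uses a random tensor as a stand-in for the activations `r` above. The `logsumexp`-based `log_softmax` should match `torch.nn.functional.log_softmax`, and exponentiating its output should give rows that sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def log_softmax(x): return x - x.logsumexp(-1, keepdim=True)\n",
"\n",
"acts = torch.randn(4, 10)  # stand-in for a batch of activations\n",
"ours = log_softmax(acts)\n",
"print(torch.allclose(ours, F.log_softmax(acts, dim=-1), atol=1e-6))  # True\n",
"print(ours.exp().sum(-1))  # each row sums to (approximately) 1"
]
},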
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 49,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def cross_entropy(preds, yb): return nll(log_softmax(preds), yb).mean()"
|
||
|
]
|
||
|
},
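{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same kind of check works for `cross_entropy`. The `nll` below is a hypothetical stand-in for the one defined earlier in the chapter: it returns the negative log probability of each target class. If the chapter's `nll` already averages, the extra `.mean()` in `cross_entropy` leaves the scalar unchanged. Either version should agree with `torch.nn.functional.cross_entropy`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def log_softmax(x): return x - x.logsumexp(-1, keepdim=True)\n",
"\n",
"# Hypothetical stand-in for the chapter's nll: pick each row's target log probability and negate it\n",
"def nll(log_probs, target): return -log_probs[torch.arange(len(target)), target]\n",
"\n",
"def cross_entropy(preds, yb): return nll(log_softmax(preds), yb).mean()\n",
"\n",
"preds = torch.randn(8, 10)\n",
"yb = torch.randint(0, 10, (8,))\n",
"print(torch.allclose(cross_entropy(preds, yb), F.cross_entropy(preds, yb), atol=1e-6))  # True"
]
},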
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Learner"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 50,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class SGD:\n",
|
||
|
" def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
|
||
|
" def step(self):\n",
|
||
|
" for p in self.params:\n",
|
||
|
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
||
|
" p.grad.data.zero_()"
|
||
|
]
|
||
|
},
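{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `step` computes `p = p - lr * (grad + wd * p)`. Weight decay therefore enters as an extra `wd * p` term added to the gradient (L2 regularization), and the gradient is zeroed right after the update. The minimal sketch below replays that rule on one scalar parameter. `TinySGD` is an illustrative name. It sets its attributes directly instead of using fastcore's `store_attr`, and it wraps `params` in `list` so that a generator of parameters would not be exhausted after the first step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class TinySGD:\n",
"    # Same update rule as SGD.step above, without fastcore's store_attr\n",
"    def __init__(self, params, lr, wd=0.):\n",
"        self.params, self.lr, self.wd = list(params), lr, wd\n",
"    def step(self):\n",
"        for p in self.params:\n",
"            p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
"            p.grad.data.zero_()  # reset the gradient for the next batch\n",
"\n",
"w = torch.tensor(1.0, requires_grad=True)\n",
"opt = TinySGD([w], lr=0.1, wd=0.1)\n",
"(0.5 * w).backward()  # gradient of 0.5*w is 0.5\n",
"opt.step()\n",
"print(w.item())       # about 0.94, i.e. 1 - 0.1*(0.5 + 0.1*1)\n",
"print(w.grad.item())  # 0.0, zeroed by step"
]
},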
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 51,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class DataLoaders:\n",
|
||
|
" def __init__(self, *dls): self.train,self.valid = dls"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 52,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"dls = DataLoaders(train_dl,valid_dl)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 53,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Learner:\n",
|
||
|
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
||
|
" store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
|
||
|
" for cb in cbs: cb.learner = self\n",
|
||
|
"\n",
|
||
|
" def one_batch(self):\n",
|
||
|
" self('before_batch')\n",
|
||
|
" xb,yb = self.batch\n",
|
||
|
" self.preds = self.model(xb)\n",
|
||
|
" self.loss = self.loss_func(self.preds, yb)\n",
|
||
|
" if self.model.training:\n",
|
||
|
" self.loss.backward()\n",
|
||
|
" self.opt.step()\n",
|
||
|
" self('after_batch')\n",
|
||
|
"\n",
|
||
|
" def one_epoch(self, train):\n",
|
||
|
" self.model.training = train\n",
|
||
|
" self('before_epoch')\n",
|
||
|
" dl = self.dls.train if train else self.dls.valid\n",
|
||
|
" for self.num,self.batch in enumerate(progress_bar(dl, leave=False)):\n",
|
||
|
" self.one_batch()\n",
|
||
|
" self('after_epoch')\n",
|
||
|
" \n",
|
||
|
" def fit(self, n_epochs):\n",
|
||
|
" self('before_fit')\n",
|
||
|
" self.opt = self.opt_func(self.model.parameters(), self.lr)\n",
|
||
|
" self.n_epochs = n_epochs\n",
|
||
|
" try:\n",
|
||
|
" for self.epoch in range(n_epochs):\n",
|
||
|
" self.one_epoch(True)\n",
|
||
|
" self.one_epoch(False)\n",
|
||
|
" except CancelFitException: pass\n",
|
||
|
" self('after_fit')\n",
|
||
|
" \n",
|
||
|
" def __call__(self,name):\n",
|
||
|
" for cb in self.cbs: getattr(cb,name,noop)()"
|
||
|
]
|
||
|
},
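{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Learner` announces every stage of training through `__call__`. For each event name, it looks the name up on every callback with `getattr` and falls back to `noop` when a callback doesn't define it. Callbacks therefore only need to implement the events they care about. The stripped-down sketch below keeps just that dispatch mechanism. `MiniLearner`, `EventLogger`, and the inline `noop` are illustrative, not fastai's classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def noop(*args, **kwargs): return None\n",
"\n",
"class MiniLearner:\n",
"    def __init__(self, cbs):\n",
"        self.cbs = cbs\n",
"        for cb in cbs: cb.learner = self\n",
"    def __call__(self, name):\n",
"        # Call cb.<name>() on every callback that defines it; the others get noop\n",
"        for cb in self.cbs: getattr(cb, name, noop)()\n",
"    def fit(self, n_epochs):\n",
"        self('before_fit')\n",
"        for self.epoch in range(n_epochs):\n",
"            self('before_epoch')\n",
"            self('after_epoch')\n",
"        self('after_fit')\n",
"\n",
"class EventLogger:\n",
"    # Defines only two events; the others are silently skipped\n",
"    def __init__(self): self.events = []\n",
"    def before_fit(self): self.events.append('before_fit')\n",
"    def after_epoch(self): self.events.append(f'after_epoch {self.learner.epoch}')\n",
"\n",
"log = EventLogger()\n",
"MiniLearner([log]).fit(2)\n",
"print(log.events)  # ['before_fit', 'after_epoch 0', 'after_epoch 1']"
]
},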
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 54,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Callback(GetAttr): _default='learner'"
|
||
|
]
|
||
|
},
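{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inheriting from fastcore's `GetAttr` with `_default='learner'` means that any attribute a callback doesn't have is looked up on `self.learner` instead. This is why `TrackResults` can write `self.epoch` and `self.model`. Only reads are forwarded: assigning to `self.batch` would create an attribute on the callback itself, which is why `SetupLearnerCB` writes `self.learner.batch` explicitly. The sketch below imitates this with a plain `__getattr__`. It is a simplified illustration, not fastcore's implementation, and `DelegatingCallback` and `FakeLearner` are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DelegatingCallback:\n",
"    _default = 'learner'\n",
"    def __getattr__(self, name):\n",
"        # Only called when normal lookup fails. Refuse private names and the\n",
"        # learner itself, so a missing learner can't cause infinite recursion.\n",
"        if name.startswith('_') or name == self._default: raise AttributeError(name)\n",
"        return getattr(getattr(self, self._default), name)\n",
"\n",
"class FakeLearner:\n",
"    epoch, lr = 3, 0.1\n",
"\n",
"cb = DelegatingCallback()\n",
"cb.learner = FakeLearner()\n",
"print(cb.epoch, cb.lr)  # 3 0.1, read through to the learner\n",
"cb.lr = 0.5             # assignment is NOT forwarded...\n",
"print(cb.learner.lr)    # 0.1 ...so the learner is unchanged"
]
},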
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 55,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class TrackResults(Callback):\n",
|
||
|
" def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]\n",
|
||
|
" \n",
|
||
|
" def after_epoch(self):\n",
|
||
|
" n = sum(self.ns)\n",
|
||
|
" print(self.epoch, self.model.training,\n",
|
||
|
" sum(self.losses).item()/n, sum(self.accs).item()/n)\n",
|
||
|
" \n",
|
||
|
" def after_batch(self):\n",
|
||
|
" xb,yb = self.batch\n",
|
||
|
" acc = (self.preds.argmax(dim=1)==yb).float().sum()\n",
|
||
|
" self.accs.append(acc)\n",
|
||
|
" n = len(xb)\n",
|
||
|
" self.losses.append(self.loss*n)\n",
|
||
|
" self.ns.append(n)"
|
||
|
]
|
||
|
},
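{
"cell_type": "markdown",
"metadata": {},
"source": [
"`TrackResults` multiplies each batch's mean loss by the batch size before summing, then divides by the total number of items. Accuracy is already accumulated as a count of correct predictions. This way every item counts equally, even when the last batch is smaller. The worked example below uses made-up numbers for two batches of 64 and 16 items and shows how the result differs from a plain average of batch means."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"losses = [2.0, 1.0]  # mean loss of each batch (illustrative numbers)\n",
"ns = [64, 16]        # batch sizes; the last batch is smaller\n",
"\n",
"naive = sum(losses) / len(losses)                                  # 1.5\n",
"weighted = sum(loss*n for loss, n in zip(losses, ns)) / sum(ns)    # (128 + 16) / 80 = 1.8\n",
"print(naive, weighted)"
]
},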
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 56,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class SetupLearnerCB(Callback):\n",
|
||
|
" def before_batch(self):\n",
|
||
|
" xb,yb = to_device(self.batch)\n",
|
||
|
" self.learner.batch = tfm_x(xb),yb\n",
|
||
|
"\n",
|
||
|
" def before_fit(self): self.model.cuda()"
|
||
|
]
|
||
|
},
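{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SetupLearnerCB.before_fit` calls `model.cuda()`, so this code assumes a CUDA GPU is available. To experiment on a CPU-only machine, one option is to choose the device at runtime. The sketch below uses standard PyTorch calls. `pick_device` and `move_batch` are illustrative names and are not the chapter's `to_device`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def pick_device():\n",
"    # Use the GPU when one is available, otherwise fall back to the CPU\n",
"    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"def move_batch(batch, device):\n",
"    # Move every tensor in an (xb, yb) tuple to the given device\n",
"    return tuple(t.to(device) for t in batch)\n",
"\n",
"device = pick_device()\n",
"xb, yb = move_batch((torch.randn(2, 3), torch.tensor([0, 1])), device)\n",
"print(xb.device.type, yb.device.type)  # cuda on a GPU machine, cpu otherwise"
]
},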
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 57,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"cbs = [SetupLearnerCB(),TrackResults()]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 58,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 75,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 True 2.1446147873983525 0.21987538282817615\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 False 2.0508458150875795 0.26343949044585985\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"learn.fit(1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Annealing"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 59,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class LRFinder(Callback):\n",
|
||
|
" def before_fit(self):\n",
|
||
|
" self.losses,self.lrs = [],[]\n",
|
||
|
" self.learner.lr = 1e-6\n",
|
||
|
" \n",
|
||
|
" def before_batch(self):\n",
|
||
|
" if not self.model.training: return\n",
|
||
|
" self.opt.lr *= 1.2\n",
|
||
|
"\n",
|
||
|
" def after_batch(self):\n",
|
||
|
" if not self.model.training: return\n",
|
||
|
" if self.opt.lr>10 or torch.isnan(self.loss): raise CancelFitException\n",
|
||
|
" self.losses.append(self.loss.item())\n",
|
||
|
" self.lrs.append(self.opt.lr)"
|
||
|
]
|
||
|
},
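{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LRFinder` starts at `1e-6` and multiplies the learning rate by 1.2 before every training batch. It stops once the rate exceeds 10, or earlier if the loss becomes NaN. The short calculation below uses only the standard library. It shows that reaching the limit takes 89 training batches, since `1.2**n > 1e7` first holds at n = 89. That appears consistent with the run below being cancelled after 14 completed batches of its second training epoch (74 + 14 + 1 = 89, counting the batch that raised `CancelFitException`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"start, factor, limit = 1e-6, 1.2, 10\n",
"lr, n_batches = start, 0\n",
"while True:\n",
"    lr *= factor          # before_batch: grow the learning rate\n",
"    n_batches += 1\n",
"    if lr > limit: break  # after_batch: LRFinder raises CancelFitException here\n",
"print(n_batches, lr)      # 89, and lr is about 11.15\n",
"# closed form: smallest n with start * factor**n > limit\n",
"print(math.ceil(math.log(limit / start) / math.log(factor)))  # 89"
]
},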
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 60,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"lrfind = LRFinder()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 61,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[lrfind])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 62,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 True 2.490109584565424 0.10507973386841271"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 False 2.332222830414013 0.09859872611464968"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"\n",
|
||
|
" <div>\n",
|
||
|
" <style>\n",
|
||
|
" /* Turns off some styling */\n",
|
||
|
" progress {\n",
|
||
|
" /* gets rid of default border in Firefox and Opera. */\n",
|
||
|
" border: none;\n",
|
||
|
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
||
|
" background-size: auto;\n",
|
||
|
" }\n",
|
||
|
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
||
|
" background: #F44336;\n",
|
||
|
" }\n",
|
||
|
" </style>\n",
|
||
|
" <progress value='14' class='' max='74', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||
|
" 18.92% [14/74 00:02<00:12]\n",
|
||
|
" </div>\n",
|
||
|
" "
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"learn.fit(2)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 113,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.plot(lrfind.lrs[:-2],lrfind.losses[:-2])\n",
|
||
|
"plt.xscale('log')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 136,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class OneCycle(Callback):\n",
|
||
|
" def __init__(self, base_lr): self.base_lr = base_lr\n",
|
||
|
" def before_fit(self): self.lrs = []\n",
|
||
|
"\n",
|
||
|
" def before_batch(self):\n",
|
||
|
" if not self.model.training: return\n",
|
||
|
" n = len(self.dls.train)\n",
|
||
|
" bn = self.epoch*n + self.num\n",
|
||
|
" mn = self.n_epochs*n\n",
|
||
|
" pct = bn/mn\n",
|
||
|
" pct_start,div_start = 0.25,10\n",
|
||
|
" if pct<pct_start:\n",
|
||
|
" pct /= pct_start\n",
|
||
|
" lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr\n",
|
||
|
" else:\n",
|
||
|
" pct = (pct-pct_start)/(1-pct_start)\n",
|
||
|
" lr = (1-pct)*self.base_lr\n",
|
||
|
" self.opt.lr = lr\n",
|
||
|
" self.lrs.append(lr)"
|
||
|
]
|
||
|
},
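{
"cell_type": "markdown",
"metadata": {},
"source": [
"`OneCycle.before_batch` turns training progress `pct` (batches seen divided by total batches) into a learning rate. Over the first 25% of training it warms up linearly from `base_lr/div_start` to `base_lr`, then it decreases linearly towards 0. Pulling the same arithmetic into a pure function makes the key points easy to check; `one_cycle_lr` is an illustrative name. In the callback, `pct` never quite reaches 1, because the last batch index is `n_epochs*n - 1`. The final learning rate is therefore small but not exactly 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def one_cycle_lr(pct, base_lr, pct_start=0.25, div_start=10):\n",
"    # Same piecewise-linear rule as OneCycle.before_batch, as a function of progress\n",
"    if pct < pct_start:\n",
"        pct /= pct_start\n",
"        return (1-pct)*base_lr/div_start + pct*base_lr  # warm up to base_lr\n",
"    pct = (pct-pct_start)/(1-pct_start)\n",
"    return (1-pct)*base_lr                               # anneal towards 0\n",
"\n",
"print([round(one_cycle_lr(p, 0.1), 4) for p in [0, 0.125, 0.25, 0.625, 1.0]])\n",
"# [0.01, 0.055, 0.1, 0.05, 0.0]"
]
},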
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 137,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"onecyc = OneCycle(0.1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 138,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[onecyc])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 139,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 True 2.2302985812255782 0.17985003696272045\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0 False 2.197772939888535 0.1819108280254777\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"1 True 2.0975320783609672 0.24215862287464357\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"1 False 1.9934868879378982 0.3026751592356688\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"2 True 1.9510114006890906 0.3118597528778118\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"2 False 1.8667540804140128 0.35337579617834397\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"3 True 1.8408451674543247 0.3651916781075087\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"3 False 1.7804740993232484 0.38471337579617837\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"4 True 1.7385770342697222 0.40162635970007393\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"4 False 1.7134963425557326 0.4206369426751592\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"5 True 1.6633978104538494 0.4345759847924807\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"5 False 1.7293920431926753 0.39923566878980893\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"6 True 1.5818110619191572 0.46477980779385364\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"6 False 1.6024532245222929 0.4565605095541401\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"7 True 1.5281916030269829 0.4891752032949625\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"7 False 1.5787282294984077 0.47312101910828025\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"learn.fit(8)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 140,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.plot(onecyc.lrs);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Questionnaire"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"For the questions here that ask you to explain what some function or class is, you should also complete your own code experiments."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"1. What is `glob`?\n",
|
||
|
"1. How do you open an image with the Python Imaging Library (PIL)?\n",
|
||
|
"1. What does `L.map` do?\n",
|
||
|
"1. What does `Self` do?\n",
|
||
|
"1. What is `L.val2idx`?\n",
|
||
|
"1. What methods do you need to implement to create your own Dataset?\n",
|
||
|
"1. Why do we call `convert` when we open an image from Imagenette?\n",
|
||
|
"1. What does `~` do? How is it useful for splitting training and validation sets?\n",
|
||
|
"1. Which of these classes does `~` work with: `L`, `Tensor`, numpy array, Python `list`, pandas `DataFrame`?\n",
|
||
|
"1. What is `ProcessPoolExecutor`?\n",
|
||
|
"1. How does `L.range(self.ds)` work?\n",
|
||
|
"1. What is `__iter__`?\n",
|
||
|
"1. What is `first`?\n",
|
||
|
"1. What is `permute`? Why is it needed?\n",
|
||
|
"1. What is a recursive function? How does it help us define the `parameters` method?\n",
|
||
|
"1. Write a recursive function that returns the first 20 items of the Fibonacci sequence.\n",
|
||
|
"1. What is `super`?\n",
|
||
|
"1. Why do subclasses of Module need to override `forward` instead of defining `__call__`?\n",
|
||
|
"1. In `ConvLayer`, why does `init` depend on `act`?\n",
|
||
|
"1. Why does `Sequential` need to call `register_modules`?\n",
|
||
|
"1. Write a hook that prints the shape of every layer's activations.\n",
|
||
|
"1. What is LogSumExp?\n",
|
||
|
"1. Why is `log_softmax` useful?\n",
|
||
|
"1. What is `GetAttr`? How is it helpful for callbacks?\n",
|
||
|
"1. Reimplement one of the callbacks in this chapter without inheriting from `Callback` or `GetAttr`.\n",
|
||
|
"1. What does `Learner.__call__` do?\n",
|
||
|
"1. What is `getattr`? (Note the difference in case from `GetAttr`!)\n",
|
||
|
"1. Why is there a `try` block in `fit`?\n",
|
||
|
"1. Why do we check for `model.training` in `one_batch`?\n",
|
||
|
"1. What is `store_attr`?\n",
|
||
|
"1. What is the purpose of `TrackResults.before_epoch`?\n",
|
||
|
"1. What does `model.cuda()` do? How does it work?\n",
|
||
|
"1. Why do we need to check `model.training` in `LRFinder` and `OneCycle`?"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Further research"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"1. Write `resnet18` from scratch (refer to <<chapter_resnet>> as needed), and train it with the Learner in this chapter.\n",
|
||
|
"1. Implement a batchnorm layer from scratch and use it in your resnet18.\n",
|
||
|
"1. Write a mixup callback for use in this chapter.\n",
|
||
|
"1. Add momentum to `SGD`.\n",
|
||
|
"1. Pick a few features that you're interested in from fastai (or any other library) and implement them in this chapter.\n",
|
||
|
"1. Pick a research paper that's not yet implemented in fastai or PyTorch and implement it in this chapter.\n",
|
||
|
" - Port it over to fastai.\n",
|
||
|
" - Submit a PR to fastai, or create your own extension module and release it.\n",
|
||
|
" - Hint: you may find it helpful to use [nbdev](https://nbdev.fast.ai/) to create and deploy your package."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.5"
|
||
|
},
|
||
|
"toc": {
|
||
|
"base_numbering": 1,
|
||
|
"nav_menu": {
|
||
|
"height": "140px",
|
||
|
"width": "202px"
|
||
|
},
|
||
|
"number_sections": false,
|
||
|
"sideBar": true,
|
||
|
"skip_h1_title": true,
|
||
|
"title_cell": "Table of Contents",
|
||
|
"title_sidebar": "Contents",
|
||
|
"toc_cell": false,
|
||
|
"toc_position": {},
|
||
|
"toc_section_display": true,
|
||
|
"toc_window_display": false
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|