fastbook/clean/19_learner.ipynb

1377 lines
164 KiB
Plaintext
Raw Normal View History

2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A fastai Learner from Scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMAGENETTE_160)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('/home/jhoward/.fastai/data/imagenette2-160/val/n03417042/n03417042_3752.JPEG')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = get_image_files(path)\n",
"t[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('/home/jhoward/.fastai/data/imagenette2-160/val/n03417042/n03417042_3752.JPEG')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from glob import glob\n",
"files = L(glob(f'{path}/**/*.JPEG', recursive=True)).map(Path)\n",
"files[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAANUAAACgCAIAAAAUzb6mAADk7ElEQVR4nHT9WbAt6XUeiK31jznu8cznnjsPNQMoACQAkiAlkiJFSe6W2WyJVsshdfSDw9H2ix1uvyjC4SfbHR1+aClshUMd0WoP3W2FJtKkmrMAkAWABApAoepW3aHueOaz5xz/afkhzzl1ScsZO87dO+/euXNnfv+a17fw//R//i+FEEIIxhgAIKIQgnPOOUfE7i8ihhCIKITgnBNCEJExBgDiOE6ShHPeti0AMMa693vvjTHOubquGGPIiDHGGEMkRAQAAEYBERkRNrWp69b7wJks2znnnL2ydefTnQARIWL3LQAQQmhdGcdpFMU+gLPgfXAuGBdca6SUOpJKcGRA3hF5IkIBUkrGeQjgvQc4P2EA0FprqQC6N1P3yx16zjkRNU1T13UIQWsdx7ExhnMuhOiuiXOuOzERfPcyBIeIDAAgAIQQAgKVZZmm8U//1E/9F//Ff76zs3P//v22Cd57a6211ntHRIQAEACDkDJOdJqmaZomeZqmaRRFm7FujbPWAhdKayaVD8E511rjnHG29d4CBQaByEMgNRo45wBASimlVkJqHUVRJFAQUbGqptPp9as3rLUAAADOUVGthmsDAhvISnRtsdwcDtvlMpWxFJoYbwO0CAa4o9BySUQAgTHWoQUwEJGUEgAQ4bP9AESei34Ivq3rqlhZ0whEJCLvfQgBAIQQHYwukde9oQNfCKHDBBFd4rX7q7UGACLq/l6+v7sxHf4QESB0H6mqBogxxhG5d+ScZ4wLyYbJsDtC99kQgjGGMaaU6i7QJS47LOZxLoTqPoIMMCARedsqpYU8f1vwAUIgCgCAgCF053B+HM5ld26CcYAAgUIIHAmRcc5coBB8CCEED0AAFIK31rRtc3kyHYCJgnMeyRvTlmVZlmVVF6ZumqZyzk2mZ/PpDJF+7dd+raqq999//9vf/jbnPE36QgilVJpFURTpKJKKc85v3LjGxPlKDBi6i4yITbGy1rbWegNFUxKh9c45Nxj1leRJnEVaRUoIwRgSQMA4MsZ47xE4AEBAIYTkvGnq4MmZypvWu4YCdV+RRLEQGAm+LFZC+HyQq2BNU5FzjtrgvA/QEjgheJyqSDOuu5vOGGMMOlSEEIQ8hwdj2O0EACJRVCulVJrqLJWcobj8wKX06sDXHfQSTN12ecWJqJOCF8LsM8x1aDbGtG1rrS3qqrtwjMOl/ENEFUchBApIAT0ED0REnohdXOhL6Hdn1R25O9XuZfelnsCZ1rvgPTEmGAqAAABFuRRCKCWk4JyjYCiEYBw8hUtZTgGJCLFhjMVxbNExDkjgvbPOd5ePOHW/yDnnnSMiB4GCy9K0bduyLOu6bpqmLMvValVVVTGfGGPqum6aqmkaY4wzrXMuiiKkoLWOomhtff0LX/iCMWZnZ6fTPJwj55wQvHfe+0BusZoJIYRkKDiRv/y94ziOVJzwjEkhhGScI2eIyDkD8iE4hMAgkHfkrfdWC5DBSwQhGAIHQMZACIrTmAgVElk77KXBnyso17pBnulYx4rSTG2sD08P4PDp075OyBqOuLm5MdrahUifLVYn04nlsrvgHWwYY93ZIiAyuriJgIjd+o8i0etlg16uJOcIohN43bXmnEspuyvS/eDLZdfhoLsT3vtXRZQxhojqur7E3+V+a20Hmksxeala0zQOAYIH7wOivfy6c3l5sV2IKF5VVXdwROze030RegAAZ71zAZELrkKgENxw2Oecc9apaeeddd50eqH7LZxzFOc45hzqagUAQjKlFEe0YJ1zROTBW2vbtj0H07luDc+ePbPWdsvMGNNp57ZtR4OcMRYpPRj0+4MreZ6nUSyEcM4xxrx1aZoioo4j5AwAymrOOeeyM4HgXBeBX1sfCcGFVt2yuRS0YIzgqgNrCMEH6G4zY0CBBKFgTCulJReMISOFvFO+UZRwLkIIAR
hjwhofAq1WcSzF9vqo28sYWy0KKTnyEA3S8bg/7Gd2OZsAmGpF1omYlBBZogKXFKxpK573iboF0NlsFAICBCICIkQGBAT+Upz5YKxr6gaq0lnTCCnl5Z2+NPuIiHP+F+Rf97xpmktEdnLOWiulbJrmVah1QOGcZ3F+gWNijHFxbr2VdQXEGBMACAyRMwCGAgU7V6Yd0Dt8dKdHhCFQpwQvsa4jxRgTjHuw3tgQAgJHhLoqhRBScs45BE/giQIReWO994iglELEDpScc6VEXdfLeem9b01dVVXTNCGEqiov5FnToa1tW+dcv9+PoijLsq2NcZqmSZLEcayUCtSEELz3EAiAdecJPqxWZdu229vbgVyWJ1LoPOv3+31gmjEmpVRKKaWEOle6VV0Q666zD0CeAhFDQCkFl7KTGjZ4CAGRJONRrKTATOs8i3pZmsRSSc4Bpe/MKsWl8I7quq0bEwI1xjkX0PumlLEUzjmGDDgb97Omqaq6ivNIcVYXK980nIFvXCSkFGJydnI0OTXIeJKMh/0pyUvhxTlHBp3d5b29VKTnEhGRMZb1Yx1JDlQbU9Wrz+RfB4tX7bwOdh0ILs3/S0B0TkZnO3f38hIoHXY72zxw7G48ke+0ZwdNrXUIgQKEQD5Y5wwAAwgoZIe/S0nTHTaKokvcw8VGRNY7xZUQPITAATszDpGXZYnkkASJ7kd5wAAAWdIry7Ku6tVyXtdluVq1bY1I0+l0OpvMFgsfbHdkJrhSKo61lDJJkvWN4XA47Pf7cRxfClFEtNZ2kq9pi9Zg064u3DglpZRccMGA+HhtyBi7d/f17oevr6+PRqPFYrG5tUVEBOfX9/xXs6DjCM+tJ7pUTZzzYLE7sBACETlHIYQUDMknWg56yaiXJbHmGLxprLUREXlAREBmwLHgwTsi0JyBBwDw9lxTITjGGKcQghOSb61v5L346OBguVwuF0UihHUBGgOOmuAt5/041klMxbn4QuSICARAeOEockSOGLynS2CU1bxpOTlblUVTl4Ix1sHo0qXoQPbq80v5BwCdK3ppBV6aZZ05eGkgXkKksaa7vudWAgHnHDFsbe+1bVvXbVU1oW6sNyFAABmc7X5QCAGAOqnJGPPeXqKfMdaZnd3t735rcB5CYMBa0zrnBnnunGvbej5brlarqiqstY4CtmE2m62KhbfG2KauKyCvlHr9zde0gkGu43SY9/tpmmZZlmRpv5dfSvROsDnbWFMvl519qZRScSTzLO7sFmKjEALQpXPGOuNtsVgI5EkaJWk6XywGw954PG7bdrFavqp/gDPGGGciilMpZScXubp4wjkQf/VkQnAIEDxAcJ6HzpMLznME9MgIXGOdcy4QIvoAIYBSETJe1o0H8i5YHxrjrA2EgTHGjVWxioRMsp5QMhBDpgGlDax1ThHPkjyLohbRAJ4tVsD1K3ZapzMDECKc+xIADIAR+S6UIeUAkYIzDHiktOhk0rmLdP4BICKtdXddXo2GdPe7w9+lsu7Ux+VN6kQmAHSCs3sOSIwh51wIJiTjnGd5whiz1gKE0EVGEDjnHM7FML4S9HnVDLj0fM/dII0YyJrWNq23zllbrZZVUU6nU29t01Srqqzr2jlDCIyxO1fvtE3BwfcG+bC/k2XJYNgbDgebm+tnZ2fz1RIYEkDVNsbYVbFcLqZKqTiOoyjSWsdxorUWQly/fvXSPOjM5S6GomIdQujcHCIKPngKwXshxPPnz9fX19/+hbeV1NPZbLlaSalTlaLgl/pXSskER8ThcMjkZ3ExOlciQHT+EhlqyTlnnIFAAGc5hOCoWLRm1TAMEFxwVlBojXPOAbAADJAJpQUXxjMTfOuhcWQ8GBcIkTHMdSIiXdbV8dlcSl62Phusr+94DKSZSLJUxvHSmrquUWqd9cFHF8oXO73nPffehhCklFKJDgmdwyqEaF2JgEAA4ICkmM/n3Tru5PmlAdcd69Lsu0SAlLKLFYWL7VJrv+IBUfeVRC
Q4j6JoOp1kWYZAq+Wy38/LVfHxR/eTJKnrtq5ryRlQaNu2l+fOWMZ42zTOuTRNOedN0wRyABC6gCFnRL6q6rquvfd
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=213x160 at 0x7FA45AC27D50>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im = Image.open(files[0])\n",
"im"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([160, 213, 3])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im_t = tensor(im)\n",
"im_t.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#10) ['n03417042','n03445777','n03888257','n03394916','n02979186','n03000684','n03425413','n01440764','n03028079','n02102040']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lbls = files.map(Self.parent.name()).unique(); lbls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n03417042': 0,\n",
" 'n03445777': 1,\n",
" 'n03888257': 2,\n",
" 'n03394916': 3,\n",
" 'n02979186': 4,\n",
" 'n03000684': 5,\n",
" 'n03425413': 6,\n",
" 'n01440764': 7,\n",
" 'n03028079': 8,\n",
" 'n02102040': 9}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v2i = lbls.val2idx(); v2i"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Dataset:\n",
"    # Indexable dataset over image paths: item i is a (pixels, label index) pair.\n",
"    def __init__(self, fns):\n",
"        self.fns = fns\n",
"\n",
"    def __len__(self):\n",
"        return len(self.fns)\n",
"\n",
"    def __getitem__(self, i):\n",
"        fn = self.fns[i]\n",
"        # Resize to 64x64 first, then force three RGB channels.\n",
"        img = Image.open(fn).resize((64,64)).convert('RGB')\n",
"        # The label is looked up from the name of the file's parent directory.\n",
"        label = v2i[fn.parent.name]\n",
"        # Pixels are scaled to floats by dividing by 255.\n",
"        return tensor(img).float()/255, tensor(label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(9469, 3925)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_filt = L(o.parent.parent.name=='train' for o in files)\n",
"train,valid = files[train_filt],files[~train_filt]\n",
"len(train),len(valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 64, 3]), tensor(0))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ds,valid_ds = Dataset(train),Dataset(valid)\n",
"x,y = train_ds[0]\n",
"x.shape,y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAACMCAYAAABcUNbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO19eayc13Xf737r7MvbV27iIlGkSGuxZclK5CS2ZDuJjNhN6sQx7NRJ6hRokKBIWycp0hZtDBQBCnRH2zRwkzRI09rZvMixK3mRtVk7xU0kH/keH9++zD7zbf3j/O6d4ZM4VOyoBvrmAOSQ33zzfXc7557zO8tVSZJgQLuDrO93Awb0/44Gk72LaDDZu4gGk72LaDDZu4gGk72LaDDZu4i+r5OtlHqHUuoJpVRLKXVNKfXbSim75/tZpdSjSqlFpVSbn59VSs3c4HmWUuqrSqlEKfXRHd/9ulLqG0qpCr+f2fH9Pl5/oz//bse9H1dKnWWbziilfqZPH48qpepKqXDH9UeUUl9QSi0ppRpKqVNKqV9WSqm/zhj+dej7NtlKqVkAXwFwFsBdAD4F4BcB/Iue20IA/wvAjwE4BOAnARwG8Oc3eOw/AdC4wXc+gD/b8fxemgcwuePP3+J3f9TT7g8C+K8A/iOAEwD+M4DPKqXe9wZ9zAD4YwBfe4P3PQjg2wB+AsAxAL8D4DMAfu0G7fveKUmSt+QPgMcA/BcAvwlgCcAGgN8DkOX3/xLAAgCr5zd/D0Bd33OD5z4CIAFQ3HH93QCuABjm9x+9we8f5Pczb6IPfwDg1I5rTwD4wx3X/ieAx97g9/8Nsig+DiB8E+/7NwC+81bNyVvN2R8GMMQB/mkAH0R35d4P4NEkSeKe+78EIAPgbW/0MKXUCICfBfBckiTbPdfHAfx3AB9LkmT9b6LhfNeHAPynnmsegHvYzl76EoB7d2xBH+O9v/LXeG0RwNp32+ab0Vs92VeSJPmVJEnOJEnyJYg4fC+/m4RwfC8t9XxnSCn1P5RSDQCrAKYAPNzznQXhwN9NkuSxv8G2fxxADFlEmkYAODdotw9Z2FBK3QYRy387SZLmm3mZUupBAB8B8G+/l0b3o7d6sl/Y8f+rAMb73J/s+NT0KxBu1/viH/Vw0acBpAD80++hndcRlaRfAPDHSZJs/jV+miilfIhY/40kSV55k++7F8DnAfxWkiQ30ke+d3qr9+wd134DwBz//TiEG3u/PwCZ6Hf1ee4k73m45z0RRJnTfxJeO/Pd7NkAfpj33LvjugcggGwXvdd/DkALgA1gH3/b256o59qn36A91Z3X34o/zt/Mkvmu6FsAflYpZfXs2w9DtOnn+/xOSyOfn58AkN1xz8sAfh2iyX839IsAXkqS5Mnei0mSdJRSzwB4CMBne756GMCTSZJESqmrAI7veN4jEMlzEsCyvqiU+gBECvxmkiS/81229c3T95GzZwFUIGbM7QB+HMA6gM/03P8hAB+DmCZ7AfwIZJHMA8j3effrtHEAeyCD/Ul+/17+f2jHfWMAOgB+6QbP/iCEQ38ZwBEAv8r/v69Pez6OHdo4xKzrAPhnACZ6/oy+ZXPy/Zps/v9eiCnTgig5vw3A7vn+RwE8CWCL91wA8B8AzN7k3W802b/H6zv/fHzHff8IQA1A4SaTd46TdXbnu97kZD92g/bM9XvW9/JH8cUD2gU0wMZ3EQ0mexfRYLJ3EQ0mexfRYLJ3EfUFVVKTo2LHpMfgFicAANmcfNdpbAAA3PQwKtvik3DtAgDA3n8UrdFhAICnBPtwkzSipAIAsDxPXp4pshEKCZddYnXk3dubiGy56I4fkN8BcFtb8puUvMvx+dmpYOn8aQDAxB65P1soIe5UAQDp4ggAoL22gKBRBwAMT+4HAKgtgbrPPfY52FXxQ/gpaRsyaZQyGbnPTcm1chkAMD48gUxe/p2ypJ/loWFsBgKH11
YFad2sNrCnZHHcpD3raysAgCunn8TRh/8+AKB6+Kfk3ck8MlV5Z8WVZ1mdBtrXTgEAXv2GIKrzZwV7+sAjH8LKhnh2H/vDf31Df/iAs3cR9eVsNyUopMrmEDvid2h12gAEMgKAyG4jOyar0IZwbJRLoPyY1/iiJIYiKqo8se0d1+J3FmDJgkxsuce1begl6lhyn63kOgA45Hr+F77roDRUNv/Wz0gc+bdny2fiOIAlP/L4nUpJ+4vD40jlRXTlCiIJ3FIJDn0ue2b3yrtH5D1JK0Q7kPaW0iJhvv30U1isCEdHNRmrfKGAw6MibQKOwVB5CACw6PiobgqX5zkuKvLhp2Qssxx31y3Cn5Y+7/mA/PbFMXnm4pqDk3f9KG5GfSfby+UBAHEqhciRF7WDSAaD33WsDoJIRK/jSGMtVYFlcaEgx09AEb9RHGzFSVSx0nNtqFWtwsvKM1xG6jhKwVEU7ZxlS3+HBLm0iFKfA5RyLESR3O/pVWHZSPgyn5OduPK7qf2HMTss4ltPdttNw7ddAMC1hXkAwKXnRZx2tqsol0oAgHuOnwQArG9uImBfUmlpfyblYXxY7ktaspWtbYo4L5fHUVm6Iu/nNhe5GfgWtzqOkRUoZFISSZUbGZPfjsk2FAZruDz3MkfuR3AjGojxXUR9ObvREOXA9gLky7LSbYuKw+pVAEDYrsC1hNsjtAAAXrGIdGFU7nco4pMESUKOTnx+2nxmDJscalF4N2p1eFTkPKuHs7nSXX46tnznxt2tIE3O9pRCTI5OaUlg21DkaF8/Iy1tnDx0O7ZWRVm7dEG4uB53uXZxWZTSelskWSaVRRJQVKelrb5jo0POtmN5Zy7j48H73w4AOHdKuP3Rx58CAIyP78G5OeFsr74IAIhLh6Es2SjTHA+lQiSJjHPsyXPLJYnxCNZW8b+/xhiL3/oUbkQDzt5F1JezHXKIQoA4Ei5PItl4XUf2sXSUR9imSRXLyutsbsMabvMZAQAg5SRIuGlHXKE2OdFWCs4Ozk45Dlx+72kFDcr8eydne3HXfPP46TsWwlA/g893bKPVZbRpRx2i02xifklMLyctukY5n0LOld/mfOlzJaD+0mxjfFjMsff84DsAAE898yQurIuC5lFy5TM+sr4M9UhRnht0ZFyGciX4oZiMqIg08YdOIFDbpl8AEFshLN3XrMuxpY4Uu1BbNw9NGHD2LqK+yyFsi6GuXAvNhqwmn/uo1oaRKiCOZZUGbeHmVivAaIaathJOScIOXF++tyH3gxq7bXmmIbbqcq7mYo+cZ8Eye7aj9/EeDtfXXHKs59iw2E6tjduOA8tP8RqlSCKfLddFOkfOU9LfoFHF/ScPAQBW5y8BAC5ekz3Wt3yMU3sfzcu+PzFcRMA9tZAWE2l0KANeQiGTNn2R8cygnJZ/V1fm5BkHHYSUdHZHj6NvrJgkjPkM2ddfPfUKllfmcTPqz/tUrsJmCy5ENCkqM86YdAQJENRk8OKUKC7+5ChiRQUtqAEAGuESfCWvS8k4wlNjbLQHZWsxnmLnbFgOO+zKc11E8NhRT8lC8TlonuoqaDZtdduxjbjPQraOwPXh0fr3OaBF2uUnb7kFLz8vqNRmS55RcDuYGhMbemJY3jkzJYrR1MwMbpkUpNB1pVM538cwzb3hnDBG3gPWFmUyFhdECXNc+W6z3UE6L8pvsHBW+m5vIxPSdNUWY6KM6RoSz7hw8QwA4Et/9Xl4Q9K/fjQQ47uI+oMqeeHeqFNBVBUQIEnkJ7WmKGXFbAo2laBUXla3PVyCa4t4S6pirpRSCSxKbwSCTbtZiiPbASxyI8WnYyu4XMEpinsXLXi2PMShaaI5xAEQ1AlU5AXASMdAQs7OkC0algNLK3JkmzSRvCPTIyhpqUMF9MSth5BLi7QpZ+TaZFn6dmBmFBlKjEtnXpK2Jk2UqNjmIEptWuXwygvPAQC2tuVaIS/SYqVSw8joFABgde48ACAOVpF3DgIAWhCppi
IFW8kYdbjt6I5MjM9iM67gZjTg7F1EfTnbAnHqXA4O97VOWzgqqMkKrbc7SHuidGhzwm+2kCnJXt2KhLOd9VWojnB
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"show_image(x, title=lbls[y]);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def collate(idxs, ds):\n",
"    # Fetch every requested item, split the (x, y) pairs apart, and stack each side into a batch.\n",
"    items = [ds[i] for i in idxs]\n",
"    xs,ys = zip(*items)\n",
"    return torch.stack(xs),torch.stack(ys)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([2, 64, 64, 3]), tensor([0, 0]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = collate([1,2], train_ds)\n",
"x.shape,y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DataLoader:\n",
"    # Splits dataset indices into batches and collates each batch in a worker pool.\n",
"    def __init__(self, ds, bs=128, shuffle=False, n_workers=1):\n",
"        self.ds = ds\n",
"        self.bs = bs\n",
"        self.shuffle = shuffle\n",
"        self.n_workers = n_workers\n",
"\n",
"    def __len__(self):\n",
"        # Ceiling division: a trailing partial batch still counts as a batch.\n",
"        return (len(self.ds)-1)//self.bs+1\n",
"\n",
"    def __iter__(self):\n",
"        idxs = L.range(self.ds)\n",
"        if self.shuffle: idxs = idxs.shuffle()\n",
"        starts = range(0, len(self.ds), self.bs)\n",
"        chunks = [idxs[s:s+self.bs] for s in starts]\n",
"        # NOTE(review): map() is given a keyword argument, so this is presumably fastcore's\n",
"        # ProcessPoolExecutor rather than the one in concurrent.futures - confirm.\n",
"        with ProcessPoolExecutor(self.n_workers) as ex:\n",
"            yield from ex.map(collate, chunks, ds=self.ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([128, 64, 64, 3]), torch.Size([128]), 74)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_workers = min(16, defaults.cpus)\n",
"train_dl = DataLoader(train_ds, bs=128, shuffle=True, n_workers=n_workers)\n",
"valid_dl = DataLoader(valid_ds, bs=256, shuffle=False, n_workers=n_workers)\n",
"xb,yb = first(train_dl)\n",
"xb.shape,yb.shape,len(train_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([0.4544, 0.4453, 0.4141]), tensor([0.2812, 0.2766, 0.2981])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats = [xb.mean((0,1,2)),xb.std((0,1,2))]\n",
"stats"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Normalize:\n",
"    # Callable that standardizes its input with stored (mean, std) statistics.\n",
"    def __init__(self, stats):\n",
"        self.stats = stats\n",
"\n",
"    def __call__(self, x):\n",
"        # Move the statistics to the input's device the first time the devices differ.\n",
"        if self.stats[0].device != x.device:\n",
"            self.stats = to_device(self.stats, x.device)\n",
"        mean,std = self.stats[0],self.stats[1]\n",
"        return (x-mean)/std"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"norm = Normalize(stats)\n",
"def tfm_x(x):\n",
"    # Normalize while channels are still last, then move them to dim 1 (NHWC -> NCHW).\n",
"    normed = norm(x)\n",
"    return normed.permute((0,3,1,2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.3732, 0.4907, 0.5633]), tensor([1.0212, 1.0311, 1.0131]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = tfm_x(x)\n",
"t.mean((0,2,3)),t.std((0,2,3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Module and Parameter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Parameter(Tensor):\n",
"    # Tensor subclass marking a value as trainable; Module.__setattr__ recognizes it by type.\n",
"    def __new__(cls, x):\n",
"        # Build the instance as cls (not a hard-coded Parameter) so subclasses keep their own type.\n",
"        # The final True is the requires_grad flag.\n",
"        return Tensor._make_subclass(cls, x, True)\n",
"    def __init__(self, *args, **kwargs): self.requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3., requires_grad=True)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Parameter(tensor(3.))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Module:\n",
"    # Minimal stand-in for nn.Module: Parameters and child Modules assigned as attributes are\n",
"    # recorded automatically, and an optional hook is called with every forward result.\n",
"    def __init__(self):\n",
"        self.hook = None\n",
"        self.params = []\n",
"        self.children = []\n",
"        self._training = False\n",
"\n",
"    def register_parameters(self, *ps): self.params += ps\n",
"    def register_modules   (self, *ms): self.children += ms\n",
"\n",
"    @property\n",
"    def training(self): return self._training\n",
"\n",
"    @training.setter\n",
"    def training(self, v):\n",
"        # Setting the flag on a module also sets it on every descendant.\n",
"        self._training = v\n",
"        for child in self.children: child.training = v\n",
"\n",
"    def parameters(self):\n",
"        # Own parameters first, then each child's, in registration order.\n",
"        found = list(self.params)\n",
"        for child in self.children: found += child.parameters()\n",
"        return found\n",
"\n",
"    def __setattr__(self, k, v):\n",
"        super().__setattr__(k, v)\n",
"        # Registration happens on every assignment of a Parameter or Module value.\n",
"        if isinstance(v, Parameter): self.register_parameters(v)\n",
"        if isinstance(v, Module): self.register_modules(v)\n",
"\n",
"    def __call__(self, *args, **kwargs):\n",
"        out = self.forward(*args, **kwargs)\n",
"        hook = self.hook\n",
"        # The hook receives (output, positional inputs); its return value is ignored.\n",
"        if hook is not None: hook(out, args)\n",
"        return out\n",
"\n",
"    def cuda(self):\n",
"        # Moves each parameter's storage to the GPU in place; the Parameter objects stay the same.\n",
"        for p in self.parameters(): p.data = p.data.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConvLayer(Module):\n",
"    # 3x3 convolution with padding 1, an optional bias and an optional ReLU.\n",
"    def __init__(self, ni, nf, stride=1, bias=True, act=True):\n",
"        super().__init__()\n",
"        self.w = Parameter(torch.zeros(nf,ni,3,3))\n",
"        self.b = Parameter(torch.zeros(nf)) if bias else None\n",
"        self.act,self.stride = act,stride\n",
"        # Kaiming init when a ReLU follows, Xavier otherwise; the bias stays at zero.\n",
"        if act: nn.init.kaiming_normal_(self.w)\n",
"        else:   nn.init.xavier_normal_(self.w)\n",
"\n",
"    def forward(self, x):\n",
"        out = F.conv2d(x, self.w, self.b, stride=self.stride, padding=1)\n",
"        return F.relu(out) if self.act else out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l = ConvLayer(3, 4)\n",
"len(l.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([128, 4, 64, 64])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xbt = tfm_x(xb)\n",
"r = l(xbt)\n",
"r.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Linear(Module):\n",
"    # Fully connected layer: weight stored as (nf, ni) with Xavier init, bias initialized to zero.\n",
"    def __init__(self, ni, nf):\n",
"        super().__init__()\n",
"        self.w = Parameter(torch.zeros(nf,ni))\n",
"        self.b = Parameter(torch.zeros(nf))\n",
"        nn.init.xavier_normal_(self.w)\n",
"\n",
"    def forward(self, x):\n",
"        # The weight is transposed so the input's last dim (ni) contracts against it.\n",
"        wt = self.w.t()\n",
"        return x@wt + self.b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 2])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l = Linear(4,2)\n",
"r = l(torch.ones(3,4))\n",
"r.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class T(Module):\n",
"    # Throwaway module used to check that parameters of nested children are discovered.\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        conv,lin = ConvLayer(3,4),Linear(4,2)\n",
"        self.c = conv\n",
"        self.l = lin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = T()\n",
"len(t.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=5)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.cuda()\n",
"t.l.w.device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simple CNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Sequential(Module):\n",
"    # Applies its layers one after another.\n",
"    def __init__(self, *layers):\n",
"        super().__init__()\n",
"        # layers is a tuple, so __setattr__ does not register it; register the modules explicitly.\n",
"        self.layers = layers\n",
"        self.register_modules(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        out = x\n",
"        for layer in self.layers: out = layer(out)\n",
"        return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AdaptivePool(Module):\n",
"    # Global average pooling: dims 2 and 3 are collapsed by taking their mean.\n",
"    def forward(self, x):\n",
"        return x.mean(dim=(2,3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def simple_cnn():\n",
"    # Four stride-2 conv layers (the notebook's 64px inputs shrink 32 -> 16 -> 8 -> 4),\n",
"    # then global pooling and a 10-way linear head.\n",
"    chans = [3,16,32,64,128]\n",
"    convs = [ConvLayer(ni, nf, stride=2) for ni,nf in zip(chans[:-1], chans[1:])]\n",
"    return Sequential(*convs, AdaptivePool(), Linear(128, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = simple_cnn()\n",
"len(m.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5239089727401733 0.8776043057441711\n",
"0.43470510840415955 0.8347987532615662\n",
"0.4357188045978546 0.7621666193008423\n",
"0.46562111377716064 0.7416611313819885\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([128, 10])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def print_stats(outp, inp):\n",
"    # Hook with the (output, inputs) signature; only the output's mean and std are reported.\n",
"    mean,std = outp.mean().item(),outp.std().item()\n",
"    print (mean,std)\n",
"for i in range(4): m.layers[i].hook = print_stats\n",
"\n",
"r = m(xbt)\n",
"r.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def nll(input, target):\n",
"    # Pick each row's entry at its target index, then negate the mean of those entries.\n",
"    rows = range(target.shape[0])\n",
"    picked = input[rows, target]\n",
"    return -picked.mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1.2790, grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def log_softmax(x):\n",
"    # Direct form: exponentiate, normalize over the last dim, then take the log.\n",
"    e = x.exp()\n",
"    return (e/e.sum(-1,keepdim=True)).log()\n",
"\n",
"sm = log_softmax(r); sm[0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.5666, grad_fn=<NegBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss = nll(sm, yb)\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1.2790, grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def log_softmax(x):\n",
"    # log(a/b) = log(a) - log(b): the numerator's exp and log cancel, leaving x itself.\n",
"    log_norm = x.exp().sum(-1,keepdim=True).log()\n",
"    return x - log_norm\n",
"sm = log_softmax(r); sm[0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.rand(5)\n",
"a = x.max()\n",
"# Both sides are mathematically equal but are rounded differently in floating point, so an\n",
"# exact == can come out False for some random draws; compare with a tolerance instead.\n",
"torch.isclose(x.exp().sum().log(), a + (x-a).exp().sum().log())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3.9784, grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def logsumexp(x):\n",
"    # Stable log(sum(exp(x))) over the last dim: subtract the max before exponentiating.\n",
"    # keepdim=True lets the max broadcast against x for any rank (m[:,None] only worked for 2-D input).\n",
"    m = x.max(-1, keepdim=True)[0]\n",
"    return m.squeeze(-1) + (x-m).exp().sum(-1).log()\n",
"\n",
"logsumexp(r)[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def log_softmax(x):\n",
"    # Final version: the log-sum-exp over the last dim is delegated to PyTorch's built-in.\n",
"    log_norm = torch.logsumexp(x, -1, keepdim=True)\n",
"    return x - log_norm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1.2790, grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm = log_softmax(r); sm[0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def cross_entropy(preds, yb):\n",
"    # Log-softmax over the predictions followed by negative log likelihood.\n",
"    log_probs = log_softmax(preds)\n",
"    # nll already averages over the batch, so the trailing mean() leaves the scalar unchanged.\n",
"    return nll(log_probs, yb).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SGD:\n",
"    # Plain SGD; weight decay (wd) is added to the gradient before the step.\n",
"    def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
"    def step(self):\n",
"        for p in self.params:\n",
"            # A parameter that took no part in the last backward pass has grad None;\n",
"            # skip it instead of failing on the attribute access below.\n",
"            if p.grad is None: continue\n",
"            p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
"            # Gradients are zeroed here, so the training loop has no separate zero_grad call.\n",
"            p.grad.data.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DataLoaders:\n",
"    # Holds the training and validation loaders; exactly two must be passed.\n",
"    def __init__(self, *dls):\n",
"        train,valid = dls\n",
"        self.train,self.valid = train,valid\n",
"\n",
"dls = DataLoaders(train_dl,valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Learner:\n",
"    # Training loop that drives a model, its data loaders, a loss function and a list of callbacks.\n",
"    # Callbacks are notified by event name: before_fit, before_epoch, before_batch, after_batch,\n",
"    # after_epoch and after_fit.\n",
"    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
"        store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
"        # Give every callback a back-reference to this learner.\n",
"        for cb in cbs: cb.learner = self\n",
"\n",
"    def one_batch(self):\n",
"        # before_batch callbacks may replace self.batch before it is unpacked.\n",
"        self('before_batch')\n",
"        xb,yb = self.batch\n",
"        self.preds = self.model(xb)\n",
"        self.loss = self.loss_func(self.preds, yb)\n",
"        if self.model.training:\n",
"            self.loss.backward()\n",
"            self.opt.step()\n",
"        self('after_batch')\n",
"\n",
"    def one_epoch(self, train):\n",
"        # train=True runs the training loader with updates; train=False runs validation.\n",
"        self.model.training = train\n",
"        self('before_epoch')\n",
"        dl = self.dls.train if train else self.dls.valid\n",
"        # Validation never calls backward, so autograd is switched off for it to avoid\n",
"        # building graphs that would only be thrown away.\n",
"        with torch.set_grad_enabled(bool(train)):\n",
"            for self.num,self.batch in enumerate(progress_bar(dl, leave=False)):\n",
"                self.one_batch()\n",
"        self('after_epoch')\n",
"\n",
"    def fit(self, n_epochs):\n",
"        # before_fit runs before the optimizer is built, so a callback can still change self.lr.\n",
"        self('before_fit')\n",
"        self.opt = self.opt_func(self.model.parameters(), self.lr)\n",
"        self.n_epochs = n_epochs\n",
"        try:\n",
"            for self.epoch in range(n_epochs):\n",
"                self.one_epoch(True)\n",
"                self.one_epoch(False)\n",
"        # A callback raises CancelFitException to stop training early; after_fit still runs.\n",
"        except CancelFitException: pass\n",
"        self('after_fit')\n",
"\n",
"    def __call__(self,name):\n",
"        # Dispatch an event: call the method called name on each callback that defines one.\n",
"        for cb in self.cbs: getattr(cb,name,noop)()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Callbacks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Callback(GetAttr):\n",
"    # NOTE(review): attribute lookups that miss on the callback are presumably forwarded to\n",
"    # self.learner through GetAttr's _default mechanism - confirm against fastcore.\n",
"    _default = 'learner'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SetupLearnerCB(Callback):\n",
"    # Puts the model on the GPU once, and moves and transforms every batch before it is used.\n",
"    def before_fit(self): self.model.cuda()\n",
"\n",
"    def before_batch(self):\n",
"        xb,yb = to_device(self.batch)\n",
"        # Written through self.learner so the learner, not this callback, holds the new batch.\n",
"        self.learner.batch = (tfm_x(xb), yb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TrackResults(Callback):\n",
"    # Accumulates per-batch loss and accuracy and prints their per-sample averages each epoch.\n",
"    def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]\n",
"\n",
"    def after_epoch(self):\n",
"        n = sum(self.ns)\n",
"        # Printed columns: epoch, training flag, mean loss, accuracy.\n",
"        print(self.epoch, self.model.training,\n",
"              sum(self.losses).item()/n, sum(self.accs).item()/n)\n",
"\n",
"    def after_batch(self):\n",
"        xb,yb = self.batch\n",
"        acc = (self.preds.argmax(dim=1)==yb).float().sum()\n",
"        self.accs.append(acc)\n",
"        n = len(xb)\n",
"        # Store a detached value: the raw loss carries its autograd graph, and keeping one per\n",
"        # batch would hold those graphs alive until the lists are reset in before_epoch.\n",
"        # Weighting by n makes the epoch average per-sample even when the last batch is smaller.\n",
"        self.losses.append(self.loss.detach()*n)\n",
"        self.ns.append(n)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 True 2.1275552130636814 0.2314922378287042\n"
]
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 False 1.9942575636942674 0.2991082802547771\n"
]
}
],
"source": [
"cbs = [SetupLearnerCB(),TrackResults()]\n",
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs)\n",
"learn.fit(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scheduling the Learning Rate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LRFinder(Callback):\n",
"    # Learning-rate range test: grow the lr geometrically every training batch and record the\n",
"    # loss, stopping once the lr exceeds max_lr or the loss becomes NaN.\n",
"    # start_lr: lr the optimizer is created with; lr_mult: per-batch growth factor;\n",
"    # max_lr: stop threshold. The defaults reproduce the previously hard-coded values.\n",
"    def __init__(self, start_lr=1e-6, lr_mult=1.2, max_lr=10):\n",
"        self.start_lr,self.lr_mult,self.max_lr = start_lr,lr_mult,max_lr\n",
"\n",
"    def before_fit(self):\n",
"        self.losses,self.lrs = [],[]\n",
"        # Learner.fit builds the optimizer after before_fit, so this becomes its starting lr.\n",
"        self.learner.lr = self.start_lr\n",
"\n",
"    def before_batch(self):\n",
"        if not self.model.training: return\n",
"        self.opt.lr *= self.lr_mult\n",
"\n",
"    def after_batch(self):\n",
"        if not self.model.training: return\n",
"        # The stopping batch itself is not recorded: the check runs before the appends.\n",
"        if self.opt.lr>self.max_lr or torch.isnan(self.loss): raise CancelFitException\n",
"        self.losses.append(self.loss.item())\n",
"        self.lrs.append(self.opt.lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 True 2.6336045582954903 0.11014890695955222\n"
]
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 False 2.230653363853503 0.18318471337579617\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='12' class='' max='74', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 16.22% [12/74 00:02<00:12]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrfind = LRFinder()\n",
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[lrfind])\n",
"learn.fit(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD/CAYAAAD2Qb01AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3ic1ZX48e8ZjXqvtixbkiu2McGN4tBMC21TCCSB1N0UErIkm+wmYTdZFnaXDWx+yW42jQSWJCQhbAiQAAmQhGIMGAw2tsG9SnJX79LU8/tjikbSjDTqI8/5PM88jN736tUdWZy5c95z7xVVxRhjTPJwTHUHjDHGTC4L/MYYk2Qs8BtjTJKxwG+MMUnGAr8xxiQZC/zGGJNknFPdgXiUlJRodXX1VHfDGGOmlc2bNzeqaunA49Mi8FdXV7Np06ap7oYxxkwrIlIb7bileowxJslY4DfGmCRjgd8YY5KMBX5jjEkyFviNMSbJWOA3xpgkY4HfGGMSUEevhxNtvfj94790vgV+Y4xJQL96rY5z73oOt88/7te2wG+MMQnIEwz4qSnjH6Yt8BtjTALy+PykOIQUh4z7tS3wG2NMAnL7/DgnIOiDBX5jjElIHq+SNgFpHrDAb4wxCcnr95PqtMBvjDFJw+Pzk5piqR5jjEkabq9OSEUPWOA3xpiE5PH5pzbHLyK/EpHjItIuIntF5NMx2v1YRDojHi4R6Yg4v05EeiPO7xmvF2KMMacSj8+Pc4pTPXcB1aqaB7wHuFNEVg1spKqfU9Wc0AN4CPjtgGa3RLQ5bUy9N8aYU1Qgxz+FI35V3aGqrtCXwcf8ob5HRLKB64AHxtRDY4xJQh5fAuT4ReRHItIN7AaOA08N8y3XAQ3A+gHH7xKRRhF5RUTWjqSzxhiTLKY8xw+gqp8HcoELgMcA19DfwSeAX6hq5NJytwLzgArgXuBJEYn6yUFEbhKRTSKyqaGhId5uGmPMKcHj85PqTIByTlX1qerLwGzg5ljtRGQOcBHwiwHfv1FVO1TVpaoPAK8AV8f4Wfeq6mpVXV1aWjqSbhpjzLTnToRUzwBOhs7xfxzYoKoHh7mOAhPzlmaMMdOYx+vH6ZiiwC8iZSJyg4jkiEiKiFwB3Ag8P8S3fRz4+YDrFIjIFSKSISJOEfkIcCHwpzH03xhjTkken5+0CUr1OONoowTSOj8m8EZRC3xJVR8XkUpgJ7BUVesARGQNgVTQwDLOVOBOYDHgI3CT+H2qarX8xhgzgNc/cameYQO/qjYQyNdHO1cH5Aw49iqQHeM6Z42um8YYk1zc3imu4zfGGDO5pnwClzHGmMkVqONPgHJOY4wxk8PjU5w24jfGmOThtlSPMcYkF6+leowxJnn4/IpfsRG/McYkC4/PD2B77hpjTLJwhwK/jfiNMSY5eLyhwG85fmOMSQoeX2A1exvxG2NMkvBYqscYY5JLX+C3VI8xxiSFUKpnyrdeNMYYMzks1WOMMUkmVM7ptFSPMcYkh1A5p6V6jDEmSXj9wXJOm7lrjDHJwWbuGmNMkkmImbsi8isROS4i7SKyV0Q+HaPdX4uIT0Q6Ix5rI85Xi8gLItItIrtF5LJxeh3GGHPKSJSZu3cB1aqaB7wHuFNEVsVo+6qq5kQ81kWcewjYAhQD3wAeEZHSUfbdGGNOSQlRzqmqO1TVFfoy+Jg/kh8kIouAlcDtqtqjqo8CbwPXjeQ6xhhzqnMnysxdEfmRiHQDu4HjwFMxmq4QkcZgSug2EXEGj58OHFTVjoi224LHjTHGBIVG/FNezqmqnwdygQuAxwBXlGbrgWVAGYGR/I3AV4PncoC2Ae3bgtccRERuEpFNIrKpoaEh3m4aY8y0502QHD8AqupT1ZeB2cDNUc4fVNVDqupX1beBfwOuD57uBPIGfEse0EEUqnqvqq5W1dWlpXYbwBiTPBJ1By4n8eX4FQglqXYA80Qkco
R/ZvC4McaYoPCSDY4pyvGLSJmI3CAiOSKSIiJXEEjhPB+l7VUiMiP4fDFwG/A4gKruBbYCt4tIhohcC7wDeHT8Xo4xxkx/Hu/Up3qUQFrnCNACfBv4kqo+LiKVwVr9ymDbS4G3RKSLwM3fx4BvRlzrBmB18Dp3A9erqiXwjTEmgsfnJ8UhpEzQiN85XINgYL4oxrk6AjdtQ19/BfjKENeqAdaOtJPGGJNMPD7/hJVygi3ZYIwxCcfj0wlL84AFfmOMSTgen3/CavjBAr8xxiQcj88/YZuwgAV+Y4xJOG6f31I9xhiTTDw+tVSPMcYkE4/XRvzGGJNUvH4/qU7L8RtjTNJwWzmnMcYkF4/XT6rDAr8xxiQNj89SPcYYk1Q8Vs5pjDHJxXL8xhiTZLy2ZIMxxiQXW53TGGOSjMenOG3Eb4wxycPW6jHGmCQTWJbZUj3GGJM0bK0eY4xJMh6/kuq0wG+MMUlBVRNjApeI/EpEjotIu4jsFZFPx2j3CRHZHGx3RES+JSLOiPPrRKRXRDqDjz3j9UKMMeZU4PMrqpDqmPoc/11AtarmAe8B7hSRVVHaZQFfAkqAc4BLga8MaHOLquYEH6eNst/GGHNK8vgUYEJTPc7hm4Cq7oj8MviYD2we0O6eiC+PisiDwMVj7aQxxiQLt88PMPWpHgAR+ZGIdAO7gePAU3F824XAjgHH7hKRRhF5RUTWDvHzbhKRTSKyqaGhId5uGmPMtOYJBv6EKOdU1c8DucAFwGOAa6j2IvI3wGrg2xGHbwXmARXAvcCTIjI/xs+7V1VXq+rq0tLSeLtpjDHTmjeU6kmEET+AqvpU9WVgNnBzrHYi8j7gbuAqVW2M+P6Nqtqhqi5VfQB4Bbh6dF03xphTj2cSUj1x5fhjfF/UkbqIXAncB1yjqm8Pcx0FJu7zjDHGTDOhHL9zKlM9IlImIjeISI6IpIjIFcCNwPNR2l4CPAhcp6qvDzhXICJXiEiGiDhF5CME7gH8aXxeijHGTH99Of6pTfUogbTOEaCFQM7+S6r6uIhUBuvxK4NtbwPygaciavWfDp5LBe4EGoBG4AvA+1TVavmNMSbI4534HP+wqR5VbQAuinGuDsiJ+Dpm6WbwOmeNoo/GGJM0wuWctmSDMcYkB2/45m4ClHMaY4yZeKGZu1Od4zfGGDNJPOGqHgv8xhiTFNyW6jHGmOSSKOWcxhhjJslkzNy1wD/O9pzo4J51B6a6G8aYaWoylmW2wD/OHttyhP98Zjfdbu9Ud8UYMw15LMc//TR1ugFo6fZMcU+MMdORxxsM/A4b8U8bTZ2B1aqbg28AxhgzEpbqmYaaugIBv7nbAr8xZuSsnHMaCqd6uizwG2NGLpzjt1TP9KCqNHUFUz0W+I0xo+D1KU6H4HDYiH9a6Hb76PUE3q1bkjzVs7m2hdt+vx1VnequGDOteHz+Ca3hBwv846op4oZuso/4n9x2jF++Vhu+52GMiY/b55/Q3bfAAv+4auzq238+2Uf8h5u7Aaht6p7inhgzvXh8/gldrgEs8I+r0Ig/MzVl2oz43V4/9e29437d2mDgr2vuGvdrG3Mq83jVUj3TSaiGf0FZzrQJ/A9sqOGCb71AbdP4BWi/X8Mj/rqmnnG7rjHJwOPzk+q0VM+0EcpnLyzLoblreszcPdjYhcvr5+6nd4/bNRs6XbiCsw9rbcRvzIh4/Aky4heRX4nIcRFpF5G9IvLpIdp+WUROiEibiPxURNIjzlWLyAsi0i0iu0XksvF4EYmiqdNNTrqTmfkZtHS7p0VFSyjN8/T2E2w82DQu1wzl9Z0Ooc5y/MaMiMebODn+u4BqVc0D3gPcKSKrBjYSkSuAfwQuBaqBecC/RjR5CNgCFAPfAB4RkdJR9z7BNHW5KM5Joyg7DZ9fae9N/IXaTnb0smZeMbPyM/j3P+7E7x/7m1VdMM2zsrIw/NwYEx9PolT1qOoOVQ2VrGjwMT9K008A9wfbtwD/Dvw1gIgsAl
YCt6tqj6o+CrwNXDe2l5A4mjrdFGenUZiVBkyP2bsn211UFWdx61WL2X60nUffPDLma9Y1d+MQWDO/mPoOFz1u3zj
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(lrfind.lrs[:-2],lrfind.losses[:-2])\n",
"plt.xscale('log')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class OneCycle(Callback):\n",
" def __init__(self, base_lr): self.base_lr = base_lr\n",
" def before_fit(self): self.lrs = []\n",
"\n",
" def before_batch(self):\n",
" if not self.model.training: return\n",
" n = len(self.dls.train)\n",
" bn = self.epoch*n + self.num\n",
" mn = self.n_epochs*n\n",
" pct = bn/mn\n",
" pct_start,div_start = 0.25,10\n",
" if pct<pct_start:\n",
" pct /= pct_start\n",
" lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr\n",
" else:\n",
" pct = (pct-pct_start)/(1-pct_start)\n",
" lr = (1-pct)*self.base_lr\n",
" self.opt.lr = lr\n",
" self.lrs.append(lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onecyc = OneCycle(0.1)\n",
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[onecyc])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.fit(8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD7CAYAAACCEpQdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXiU9bn/8fednSQECIQdEtkRZI0soa7VVq2nsiqCLG4o2J7T6q+2PdaeVtu6tPacagHBWtkUcAH3ra1LlbAFWaOIggn7DoGEJZB8f3/MxI7pABPIZLbP67rm0jzPk8n9JTM3D8/c8xlzziEiIrElLtQFiIhI3VPzFxGJQWr+IiIxSM1fRCQGqfmLiMSghFAXEIgmTZq4nJycUJchIhJRVqxYsdc5l+VvX0Q0/5ycHAoKCkJdhohIRDGz4lPt02UfEZEYpOYvIhKD1PxFRGKQmr+ISAxS8xcRiUEBNX8zyzSzhWZWZmbFZjbqFMddZmbvm1mJmRX52Z/j3X/EzNab2RXnWL+IiJyFQM/8JwPlQDNgNDDVzLr5Oa4M+Cvwk1Pcz1xgJdAYuA940cz8zqCKiEjwnLH5m1kaMAy43zlX6pz7GHgVGFP9WOfcMufcbGCTn/vpBPQB/sc5d9Q59xKw1nvfEga+3F3K2+t2hroMEakDgZz5dwIqnHMbfLatBvyd+Z9ON2CTc+5wIPdjZhPMrMDMCvbs2VPDHyVn457nV3HnnBX86tVCKir1OQ8i0SyQ5p8OlFTbVgLUr+HPqtH9OOemO+dynXO5WVm6MhRsKzcfYPXWEnq2bsCM/CLumF3AkfKToS5LRIIkkOZfCmRU25YBHPZzbF3cjwTBjPwi6icn8NztA/j197vx3vrd3DBtCbsPHwt1aSISBIE0/w1Agpl19NnWEyis4c8qBNqZme+Z/tncj9Sy3YeO8ebaHYzIbUNacgLj8nKYPiaXL3eXMmRyPht26e9nkWhzxubvnCsDFgAPmFmamQ0CrgNmVz/WzOLMLAVI9HxpKWaW5L2fDcAq4H+824cAPYCXam85cjaeXbqZk5WOsQOzv952xfnNeP6OgZRXVDJsaj6LvtwbwgpFpLYFOuo5CagH7MYzrjnROVdoZheZWanPcRcDR4E3gbbe/3/XZ/9IIBc4ADwMDHfO6dXcECo/WcmzSzdzWeem5DRJ+8a+C1o3YOGkPFo0SGHcX5fxQsGWEFUpIrUtoEhn59x+YLCf7R/heSG36usPADvN/RQBl9awRgmiN9fuYG/pccbl5fjd37pRKi9OzGPinBX85MU1bNl/hB9f2QmzU/6aRSQCKN4hxs3IL6JdVhoXdWhyymMyUhJ5Znw/RvRtzePvfcndz6/m+MmKOqxSRGqbmn8MW7XlIKu2HGTcwBzi4k5/Jp+UEMejw3twz5WdWLhyG+P+uoySIyfqqFIRqW1q/jFsZn4R6ckJDOvbOqDjzYwffrsj/3dDLz4pPsjQqYvYsv9IkKsUkWBQ849Ruw8f4/U12xnetzXpyTX7NM/BvVsx69Z+7C0tZ8iURazacjBIVYpIsKj5x6i5S7dwouKb4501MaBdY16amEe9pHhGTl+sTCCRCKPmH4M8453FXNo5i3ZZ6Wf+hlPo0DSdhZMG0aV5BhOfXcFfPtqEc8oEEokEav4x6K11O9h9+NTjnTXRJD2ZubcP4LvnN+c3b3ymUDiRCKHmH4Nm5hdxXpM0LulYO4F59ZLimTK6D7dfdB4zFxcrFE4kAqj5x5g1Ww/yyeaDjB2YfcbxzpqIizPu+975PHCdTyjcIYXCiYQrNf8YMyO/iLSkeIYHON5ZU2MH5vDUWG8o3BSFwomEKzX/GLK39Divr97B8L6tqZ+SGLSf8+2uPqFwU/L5+AuFwomEGzX/GDJ36WbKKyoZWwsv9J7JBa0b8PJdg2jRMIXxzyzjeYXCiYQVNf8YcaKikjlLi7m4Uxbtz2
G8syZaNazHixPzGNCuMfe+uIbH3v1co6AiYULNP0a8vW4nuw4dZ3ze2b2p62xlpCTyzM0Xcn1ua55470t+PH+VQuFEwkDN3tcvEWtmfhHZjVO5tFPTOv/ZifFxPDKsB20zU/nDuxvYUXKM6WNyaZAavNcdROT0dOYfA9ZtK6Gg+ABjA0jvDBYz4weXd+RPI3uxcvNBhkxdxOZ9CoUTCRU1/xgwI7+I1KR4RuQGZ7yzJq7r1YrZt/ZjnzcUbuXmA6EuSSQmqflHuX2lx3l19XaG9WlNRhDHO2uif7vGLJiUR1pyAiOnL+HtdTtCXZJIzFHzj3Lzlm+h/GQl4+r4hd4zaZ+VzoJJeXRtkcHEZz9RKJxIHVPzj2InKiqZs6SYizo2oUPT+qEu5980SU9m3oQBXNXNEwr3P68WcrKiMtRlicQENf8o9m7hLnaUHGPcwJxQl3JKKYnxTB7VhwkXt2PW4mLumL2CsuMKhRMJNjX/KDYzv4i2malc1qXuxztrIi7O+O9ruvLgdd14//Pd3DB9sULhRIJMzT9KFW4vYVnRfsYOzCY+ROOdNTVmYA5/GZfLpj1lDJ68iM93KhROJFjU/KPUzPwi6iXGMyK3TahLqZHLu3hC4U5WOoZPVSicSLCo+Ueh/WXlvLJqO0P7tKJBvfAY76yJ7q08oXAtG9bzhMItVyicSG1T849C85Zv5vjJylr5mMZQadmwHi9MHMjA9o259yWFwonUNjX/KHOyopI5i4sZ1KExnZqF33hnTWSkJPLX8RdyQ24bnnjvS36kUDiRWqPmH2X+9ukutof5eGdNJMbH8fCwC/jJdzvzyqrtjHl6GQePlIe6LJGIp+YfZWbkF9G6UT2+3bVZqEupNWbGXZd14E8je7Fq80GGTs1XKJzIOVLzjyKf7TjE0q8ia7yzJq7r1Yo5t/X/OhTuE4XCiZy1gJq/mWWa2UIzKzOzYjMbdYrjzMweMbN93tujZmY++y83s0/M7JCZbTKzCbW1EPGMd6YkxnF9hI131kS/8zK/DoW7cfoS3lqrUDiRsxHomf9koBxoBowGpppZNz/HTQAGAz2BHsC1wB0AZpYILASmAQ2AG4A/mlnPc1mAeBwoK+flVdsY0rs1DVOTQl1OULXPSmfhpDzOb5nBpOcUCidyNs7Y/M0sDRgG3O+cK3XOfQy8Cozxc/g44DHn3Fbn3DbgMWC8d18mkAHMdh7Lgc+A8899GTK/YAvHToRfemewNE5PZu7t/wqF++UrCoUTqYlAzvw7ARXOuQ0+21YD/s78u3n3/dtxzrldwFzgZjOLN7OBQDbwsb8famYTzKzAzAr27NkTQJmx62RFJbMXFzOwXWO6NM8IdTl1pioU7o6L2zF7STETFAonErBAmn86UFJtWwngb4i8+rElQLrPdf+5wC+B48BHwH3OOb9v33TOTXfO5TrncrOysgIoM3b9/bPdbDt4NKLf1HW24uKMn1/TlQcHd+eDz3dz/bTF7FIonMgZBdL8S/FcrvGVAfhL3ap+bAZQ6pxzZtYFmA+MBZLw/IvgXjP7Xo2rlm+YmV9Eq4b1uKJreKd3BtOYAdk8Pe5CvtpbxpDJi1i/81CoSxIJa4E0/w1Agpl19NnWEyj0c2yhd5+/47oDnzvn3nHOVTrnPgfeAK6uedlSZf3OQyzetI8xA7NJiI/tyd3LujT9OhRuxNTFfPSFLheKnMoZu4VzrgxYADxgZmlmNgi4Dpjt5/BZwN1m1srMWgL3ADO8+1YCHb3jnmZm7fFMA632cz8SoJn5xSQnxHFDFI931kRVKFyrRvW4+ZnlCoUTOYVATxUnAfWA3Xiu2090zhWa2UVmVupz3DTgNWAtsA7Pmf00AOfcRuAW4HHgEPAh8BLwdC2sIyYdPFLOwpVbGdK7FY3Sonu8syZaNqzHC3f+KxTuD+8oFE6kuoRADnLO7cczv199+0d4XuSt+toB93pv/u7neeD5s6pU/s3zX4935oS6lLBT3xsKd/
/L6/jz+1+y5cARHh3eg+SE+FCXJhIWAmr+En4qKh2zFhfT/7xMuraInfHOmkiMj+OhoRfQJjOV37/zOTsOHmP62L5
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(onecyc.lrs);"
]
},
2020-04-23 18:24:16 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-05-19 23:56:41 +00:00
"> tip: Experiments: For the questions here that ask you to explain what some function or class is, you should also complete your own code experiments."
2020-03-18 00:34:07 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-05-19 23:56:41 +00:00
"1. What is `glob`?\n",
2020-03-18 00:34:07 +00:00
"1. How do you open an image with the Python imaging library?\n",
2020-05-19 23:56:41 +00:00
"1. What does `L.map` do?\n",
"1. What does `Self` do?\n",
"1. What is `L.val2idx`?\n",
"1. What methods do you need to implement to create your own `Dataset`?\n",
2020-03-18 00:34:07 +00:00
"1. Why do we call `convert` when we open an image from Imagenette?\n",
"1. What does `~` do? How is it useful for splitting training and validation sets?\n",
2020-05-19 23:56:41 +00:00
"1. Does `~` work with the `L` or `Tensor` classes? What about NumPy arrays, Python lists, or pandas DataFrames?\n",
"1. What is `ProcessPoolExecutor`?\n",
2020-03-18 00:34:07 +00:00
"1. How does `L.range(self.ds)` work?\n",
"1. What is `__iter__`?\n",
"1. What is `first`?\n",
"1. What is `permute`? Why is it needed?\n",
"1. What is a recursive function? How does it help us define the `parameters` method?\n",
2020-05-19 23:56:41 +00:00
"1. Write a recursive function that returns the first 20 items of the Fibonacci sequence.\n",
2020-03-18 00:34:07 +00:00
"1. What is `super`?\n",
2020-05-19 23:56:41 +00:00
"1. Why do subclasses of `Module` need to override `forward` instead of defining `__call__`?\n",
"1. In `ConvLayer`, why does `init` depend on `act`?\n",
2020-03-18 00:34:07 +00:00
"1. Why does `Sequential` need to call `register_modules`?\n",
2020-05-19 23:56:41 +00:00
"1. Write a hook that prints the shape of every layer's activations.\n",
"1. What is \"LogSumExp\"?\n",
"1. Why is `log_softmax` useful?\n",
"1. What is `GetAttr`? How is it helpful for callbacks?\n",
2020-03-18 00:34:07 +00:00
"1. Reimplement one of the callbacks in this chapter without inheriting from `Callback` or `GetAttr`.\n",
"1. What does `Learner.__call__` do?\n",
"1. What is `getattr`? (Note the case difference to `GetAttr`!)\n",
"1. Why is there a `try` block in `fit`?\n",
"1. Why do we check for `model.training` in `one_batch`?\n",
"1. What is `store_attr`?\n",
"1. What is the purpose of `TrackResults.before_epoch`?\n",
2020-05-19 23:56:41 +00:00
"1. What does `model.cuda` do? How does it work?\n",
2020-03-18 00:34:07 +00:00
"1. Why do we need to check `model.training` in `LRFinder` and `OneCycle`?\n",
"1. Use cosine annealing in `OneCycle`."
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-05-14 12:18:31 +00:00
"### Further Research"
2020-03-06 18:19:03 +00:00
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-05-19 23:56:41 +00:00
"1. Write `resnet18` from scratch (refer to <<chapter_resnet>> as needed), and train it with the `Learner` in this chapter.\n",
"1. Implement a batchnorm layer from scratch and use it in your `resnet18`.\n",
"1. Write a Mixup callback for use in this chapter.\n",
"1. Add momentum to SGD.\n",
2020-03-18 00:34:07 +00:00
"1. Pick a few features that you're interested in from fastai (or any other library) and implement them in this chapter.\n",
"1. Pick a research paper that's not yet implemented in fastai or PyTorch and implement it in this chapter.\n",
" - Port it over to fastai.\n",
2020-05-19 23:56:41 +00:00
" - Submit a pull request to fastai, or create your own extension module and release it. \n",
" - Hint: you may find it helpful to use [`nbdev`](https://nbdev.fast.ai/) to create and deploy your package."
2020-03-18 00:34:07 +00:00
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}