{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ResNets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Going Back to Imagenette"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_data(url, presize, resize):\n",
" path = untar_data(url)\n",
" return DataBlock(\n",
" blocks=(ImageBlock, CategoryBlock), get_items=get_image_files, \n",
" splitter=GrandparentSplitter(valid_name='val'),\n",
" get_y=parent_label, item_tfms=Resize(presize),\n",
" batch_tfms=[*aug_transforms(min_scale=0.5, size=resize),\n",
" Normalize.from_stats(*imagenet_stats)],\n",
" ).dataloaders(path, bs=128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = get_data(URLs.IMAGENETTE_160, 160, 128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x432 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(max_n=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def avg_pool(x): return x.mean((2,3))"
]
},
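{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, sketched in plain PyTorch with a random tensor standing in for a batch of activations: `avg_pool` should agree with `nn.AdaptiveAvgPool2d(1)`, since both average each channel over its full height and width. The module keeps two size-1 axes, giving shape `(batch, channels, 1, 1)`, which is why `get_model` follows it with `Flatten()` before the linear layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"acts = torch.randn(8, 256, 5, 5)  # hypothetical activations: 8 images, 256 channels, 5x5 grid\n",
"\n",
"by_mean = acts.mean((2, 3))              # what avg_pool computes\n",
"pooled = nn.AdaptiveAvgPool2d(1)(acts)   # shape (8, 256, 1, 1)\n",
"by_module = pooled.flatten(1)            # drop the two size-1 axes, as Flatten() does\n",
"\n",
"print(by_mean.shape, pooled.shape)  # torch.Size([8, 256]) torch.Size([8, 256, 1, 1])\n",
"print(torch.allclose(by_mean, by_module, atol=1e-6))"
]
},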
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def block(ni, nf): return ConvLayer(ni, nf, stride=2)\n",
"def get_model():\n",
" return nn.Sequential(\n",
" block(3, 16),\n",
" block(16, 32),\n",
" block(32, 64),\n",
" block(64, 128),\n",
" block(128, 256),\n",
" nn.AdaptiveAvgPool2d(1),\n",
" Flatten(),\n",
" nn.Linear(256, dls.c))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_learner(m):\n",
" return Learner(dls, m, loss_func=nn.CrossEntropyLoss(), metrics=accuracy\n",
" ).to_fp16()\n",
"\n",
"learn = get_learner(get_model())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.47863011360168456, 3.981071710586548)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.901582</td>\n",
" <td>2.155090</td>\n",
" <td>0.325350</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.559855</td>\n",
" <td>1.586795</td>\n",
" <td>0.507771</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.296350</td>\n",
" <td>1.295499</td>\n",
" <td>0.571720</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.144139</td>\n",
" <td>1.139257</td>\n",
" <td>0.639236</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.049770</td>\n",
" <td>1.092619</td>\n",
" <td>0.659108</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a Modern CNN: ResNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Skip Connections"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ResBlock(Module):\n",
" def __init__(self, ni, nf):\n",
" self.convs = nn.Sequential(\n",
" ConvLayer(ni,nf),\n",
" ConvLayer(nf,nf, norm_type=NormType.BatchZero))\n",
" \n",
" def forward(self, x): return x + self.convs(x)"
]
},
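{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second `ConvLayer` uses `NormType.BatchZero`, which, as we understand fastai, initializes that batchnorm layer's scale weights to zero. The sketch below re-creates the idea in plain PyTorch, assuming `ConvLayer` behaves like a bias-free 3x3 conv followed by `BatchNorm2d` and `ReLU`; the helper `conv_bn_relu` and the class `PlainResBlock` are hypothetical stand-ins. With the last scale at zero, the convolutional branch outputs zeros, so a freshly initialized block simply passes its input through."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"def conv_bn_relu(ni, nf, zero_bn=False):\n",
"    # hypothetical stand-in for ConvLayer: 3x3 conv (no bias) + batchnorm + ReLU\n",
"    bn = nn.BatchNorm2d(nf)\n",
"    if zero_bn: nn.init.zeros_(bn.weight)  # mimic a zero-initialized batchnorm scale\n",
"    return nn.Sequential(nn.Conv2d(ni, nf, 3, padding=1, bias=False), bn, nn.ReLU())\n",
"\n",
"class PlainResBlock(nn.Module):\n",
"    # same form as the ResBlock above: x + convs(x), so input and output channels must match\n",
"    def __init__(self, nf):\n",
"        super().__init__()\n",
"        self.convs = nn.Sequential(conv_bn_relu(nf, nf), conv_bn_relu(nf, nf, zero_bn=True))\n",
"    def forward(self, x): return x + self.convs(x)\n",
"\n",
"torch.manual_seed(0)\n",
"inp = torch.randn(4, 16, 8, 8)\n",
"out = PlainResBlock(16)(inp)\n",
"print(torch.equal(out, inp))  # the block starts out as the identity"
]
},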
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _conv_block(ni,nf,stride):\n",
" return nn.Sequential(\n",
" ConvLayer(ni, nf, stride=stride),\n",
" ConvLayer(nf, nf, act_cls=None, norm_type=NormType.BatchZero))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ResBlock(Module):\n",
" def __init__(self, ni, nf, stride=1):\n",
" self.convs = _conv_block(ni,nf,stride)\n",
" self.idconv = noop if ni==nf else ConvLayer(ni, nf, 1, act_cls=None)\n",
" self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
"\n",
" def forward(self, x):\n",
" return F.relu(self.convs(x) + self.idconv(self.pool(x)))"
]
},
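{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `stride=2` or `ni != nf`, the identity path must be reshaped to match the convolutional path. The sketch below, in plain PyTorch, checks only the shapes: bias-free `nn.Conv2d` layers stand in for `ConvLayer`, and batchnorm and activations are left out because they don't change shapes. A 2x2 average pool with `ceil_mode=True` followed by a 1x1 conv gives the same output size as the stride-2 3x3 conv, even for odd input sizes; with the default `ceil_mode=False`, odd sizes come out one pixel short."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"def main_shape(ni, nf, size):\n",
"    # stand-in for the first conv of _conv_block with stride=2: 3x3 kernel, padding 1\n",
"    conv = nn.Conv2d(ni, nf, 3, stride=2, padding=1, bias=False)\n",
"    return tuple(conv(torch.randn(1, ni, size, size)).shape)\n",
"\n",
"def shortcut_shape(ni, nf, size, ceil_mode=True):\n",
"    # stand-in for self.idconv(self.pool(x)): 2x2 average pool, then a 1x1 conv to change channels\n",
"    pool = nn.AvgPool2d(2, ceil_mode=ceil_mode)\n",
"    idconv = nn.Conv2d(ni, nf, 1, bias=False)\n",
"    return tuple(idconv(pool(torch.randn(1, ni, size, size))).shape)\n",
"\n",
"for size in (5, 7, 8, 15):\n",
"    print(size, main_shape(32, 64, size), shortcut_shape(32, 64, size),\n",
"          shortcut_shape(32, 64, size, ceil_mode=False))"
]
},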
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def block(ni,nf): return ResBlock(ni, nf, stride=2)\n",
"learn = get_learner(get_model())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.973174</td>\n",
" <td>1.845491</td>\n",
" <td>0.373248</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.678627</td>\n",
" <td>1.778713</td>\n",
" <td>0.439236</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.386163</td>\n",
" <td>1.596503</td>\n",
" <td>0.507261</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.177839</td>\n",
" <td>1.102993</td>\n",
" <td>0.644841</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.052435</td>\n",
" <td>1.038013</td>\n",
" <td>0.667771</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 3e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def block(ni, nf):\n",
" return nn.Sequential(ResBlock(ni, nf, stride=2), ResBlock(nf, nf))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.964076</td>\n",
" <td>1.864578</td>\n",
" <td>0.355159</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.636880</td>\n",
" <td>1.596789</td>\n",
" <td>0.502675</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.335378</td>\n",
" <td>1.304472</td>\n",
" <td>0.588535</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.089160</td>\n",
" <td>1.065063</td>\n",
" <td>0.663185</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.942904</td>\n",
" <td>0.963589</td>\n",
" <td>0.692739</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(get_model())\n",
"learn.fit_one_cycle(5, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A State-of-the-Art ResNet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _resnet_stem(*sizes):\n",
" return [\n",
" ConvLayer(sizes[i], sizes[i+1], 3, stride = 2 if i==0 else 1)\n",
" for i in range(len(sizes)-1)\n",
" ] + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[ConvLayer(\n",
" (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" ), ConvLayer(\n",
" (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" ), ConvLayer(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" ), MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_resnet_stem(3,32,32,64)"
]
},
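{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how much the stem shrinks its input, here is a plain PyTorch rebuild of `_resnet_stem(3,32,32,64)` based on the layers in the printout above: bias-free 3x3 convs with `BatchNorm2d` and `ReLU`, the first with stride 2, then `MaxPool2d(kernel_size=3, stride=2, padding=1)`. The helper `plain_stem` is hypothetical and only meant for inspecting shapes. With the 128-pixel images used earlier in this notebook, the stem should hand 64 channels at a quarter of the original width and height to the first group of ResNet blocks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"def plain_stem(*sizes):\n",
"    # mirrors the printed stem: conv (no bias) + BatchNorm2d + ReLU per step, stride 2 only on the first conv\n",
"    layers = []\n",
"    for i in range(len(sizes) - 1):\n",
"        layers += [nn.Conv2d(sizes[i], sizes[i+1], 3, stride=2 if i == 0 else 1, padding=1, bias=False),\n",
"                   nn.BatchNorm2d(sizes[i+1]), nn.ReLU()]\n",
"    layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"    return nn.Sequential(*layers)\n",
"\n",
"demo_stem = plain_stem(3, 32, 32, 64)\n",
"feats = demo_stem(torch.randn(2, 3, 128, 128))\n",
"print(feats.shape)  # torch.Size([2, 64, 32, 32])"
]
},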
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ResNet(nn.Sequential):\n",
" def __init__(self, n_out, layers, expansion=1):\n",
" stem = _resnet_stem(3,32,32,64)\n",
" self.block_szs = [64, 64, 128, 256, 512]\n",
" for i in range(1,5): self.block_szs[i] *= expansion\n",
" blocks = [self._make_layer(*o) for o in enumerate(layers)]\n",
" super().__init__(*stem, *blocks,\n",
" nn.AdaptiveAvgPool2d(1), Flatten(),\n",
" nn.Linear(self.block_szs[-1], n_out))\n",
" \n",
" def _make_layer(self, idx, n_layers):\n",
" stride = 1 if idx==0 else 2\n",
" ch_in,ch_out = self.block_szs[idx:idx+2]\n",
" return nn.Sequential(*[\n",
" ResBlock(ch_in if i==0 else ch_out, ch_out, stride if i==0 else 1)\n",
" for i in range(n_layers)\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rn = ResNet(dls.c, [2,2,2,2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.673882</td>\n",
" <td>1.828394</td>\n",
" <td>0.413758</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.331675</td>\n",
" <td>1.572685</td>\n",
" <td>0.518217</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.087224</td>\n",
" <td>1.086102</td>\n",
" <td>0.650701</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.900428</td>\n",
" <td>0.968219</td>\n",
" <td>0.684331</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.760280</td>\n",
" <td>0.782558</td>\n",
" <td>0.757197</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(rn)\n",
"learn.fit_one_cycle(5, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bottleneck Layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _conv_block(ni,nf,stride):\n",
" return nn.Sequential(\n",
" ConvLayer(ni, nf//4, 1),\n",
" ConvLayer(nf//4, nf//4, stride=stride), \n",
" ConvLayer(nf//4, nf, 1, act_cls=None, norm_type=NormType.BatchZero))"
]
},
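{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see why the bottleneck version is cheaper is to count convolution weights. The sketch below, in plain PyTorch, compares the convs of the basic `_conv_block` (two 3x3 convs) with those of the bottleneck `_conv_block` (1x1, 3x3, 1x1, with `nf//4` channels in the middle); batchnorm parameters are ignored and the convs are bias-free. At 256 channels the bottleneck has roughly 17 times fewer weights, and even at 1,024 channels it stays slightly below the basic block at 256, which suggests why the bottleneck `ResNet` can afford `expansion=4`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"def n_weights(module): return sum(p.numel() for p in module.parameters())\n",
"\n",
"def basic_convs(ni, nf):\n",
"    # the two 3x3 convs of the basic _conv_block (stride doesn't change the weight count)\n",
"    return nn.Sequential(nn.Conv2d(ni, nf, 3, padding=1, bias=False),\n",
"                         nn.Conv2d(nf, nf, 3, padding=1, bias=False))\n",
"\n",
"def bottleneck_convs(ni, nf):\n",
"    # 1x1 squeeze to nf//4 channels, 3x3 at the reduced width, 1x1 expand back to nf\n",
"    return nn.Sequential(nn.Conv2d(ni, nf//4, 1, bias=False),\n",
"                         nn.Conv2d(nf//4, nf//4, 3, padding=1, bias=False),\n",
"                         nn.Conv2d(nf//4, nf, 1, bias=False))\n",
"\n",
"basic = n_weights(basic_convs(256, 256))\n",
"bottleneck = n_weights(bottleneck_convs(256, 256))\n",
"wide_bottleneck = n_weights(bottleneck_convs(1024, 1024))\n",
"print(basic, bottleneck, wide_bottleneck)  # 1179648 69632 1114112"
]
},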
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = get_data(URLs.IMAGENETTE_320, presize=320, resize=224)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rn = ResNet(dls.c, [3,4,6,3], 4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.613448</td>\n",
" <td>1.473355</td>\n",
" <td>0.514140</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.359604</td>\n",
" <td>2.050794</td>\n",
" <td>0.397452</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.253112</td>\n",
" <td>4.511735</td>\n",
" <td>0.387006</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.133450</td>\n",
" <td>2.575221</td>\n",
" <td>0.396178</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.054752</td>\n",
" <td>1.264525</td>\n",
" <td>0.613758</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.927930</td>\n",
" <td>2.670484</td>\n",
" <td>0.422675</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.838268</td>\n",
" <td>1.724588</td>\n",
" <td>0.528662</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.748289</td>\n",
" <td>1.180668</td>\n",
" <td>0.666497</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.688637</td>\n",
" <td>1.245039</td>\n",
" <td>0.650446</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.645530</td>\n",
" <td>1.053691</td>\n",
" <td>0.674904</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.593401</td>\n",
" <td>1.180786</td>\n",
" <td>0.676433</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>0.536634</td>\n",
" <td>0.879937</td>\n",
" <td>0.713885</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>0.479208</td>\n",
" <td>0.798356</td>\n",
" <td>0.741656</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>0.440071</td>\n",
" <td>0.600644</td>\n",
" <td>0.806879</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>0.402952</td>\n",
" <td>0.450296</td>\n",
" <td>0.858599</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>0.359117</td>\n",
" <td>0.486126</td>\n",
" <td>0.846369</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>0.313642</td>\n",
" <td>0.442215</td>\n",
" <td>0.861911</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>0.294050</td>\n",
" <td>0.485967</td>\n",
" <td>0.853503</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>0.270583</td>\n",
" <td>0.408566</td>\n",
" <td>0.875924</td>\n",
" <td>00:32</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>0.266003</td>\n",
" <td>0.411752</td>\n",
" <td>0.872611</td>\n",
" <td>00:33</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(rn)\n",
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. How did we get to a single vector of activations in the CNNs used for MNIST in previous chapters? Why isn't that suitable for Imagenette?\n",
"1. What do we do for Imagenette instead?\n",
"1. What is \"adaptive pooling\"?\n",
"1. What is \"average pooling\"?\n",
"1. Why do we need `Flatten` after an adaptive average pooling layer?\n",
"1. What is a \"skip connection\"?\n",
"1. Why do skip connections allow us to train deeper models?\n",
"1. What does <<resnet_depth>> show? How did that lead to the idea of skip connections?\n",
"1. What is \"identity mapping\"?\n",
"1. What is the basic equation for a ResNet block (ignoring batchnorm and ReLU layers)?\n",
"1. What do ResNets have to do with residuals?\n",
"1. How do we deal with the skip connection when there is a stride-2 convolution? How about when the number of filters changes?\n",
"1. How can we express a 1\\*1 convolution in terms of a vector dot product?\n",
"1. Create a 1\\*1 convolution with `F.conv2d` or `nn.Conv2d` and apply it to an image. What happens to the `shape` of the image?\n",
"1. What does the `noop` function return?\n",
"1. Explain what is shown in <<resnet_surface>>.\n",
"1. When is top-5 accuracy a better metric than top-1 accuracy?\n",
"1. What is the \"stem\" of a CNN?\n",
"1. Why do we use plain convolutions in the CNN stem, instead of ResNet blocks?\n",
"1. How does a bottleneck block differ from a plain ResNet block?\n",
"1. Why is a bottleneck block faster?\n",
"1. How do fully convolutional nets (and nets with adaptive pooling in general) allow for progressive resizing?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Try creating a fully convolutional net with adaptive average pooling for MNIST (note that you'll need fewer stride-2 layers). How does it compare to a network without such a pooling layer?\n",
"1. In <<chapter_foundations>> we introduce *Einstein summation notation*. Skip ahead to see how this works, and then write an implementation of the 1\\*1 convolution operation using `torch.einsum`. Compare it to the same operation using `torch.conv2d`.\n",
"1. Write a \"top-5 accuracy\" function using plain PyTorch or plain Python.\n",
"1. Train a model on Imagenette for more epochs, with and without label smoothing. Take a look at the Imagenette leaderboards and see how close you can get to the best results shown. Read the linked pages describing the leading approaches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}