2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Image classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## From dogs and cats to pet breeds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.vision.all import *\n",
"path = untar_data(URLs.PETS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#3) [Path('annotations'),Path('images'),Path('models')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#7394) [Path('images/great_pyrenees_173.jpg'),Path('images/wheaten_terrier_46.jpg'),Path('images/Ragdoll_262.jpg'),Path('images/german_shorthaired_3.jpg'),Path('images/american_bulldog_196.jpg'),Path('images/boxer_188.jpg'),Path('images/staffordshire_bull_terrier_173.jpg'),Path('images/basset_hound_71.jpg'),Path('images/staffordshire_bull_terrier_37.jpg'),Path('images/yorkshire_terrier_18.jpg')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/\"images\").ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fname = (path/\"images\").ls()[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['great_pyrenees']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"re.findall(r'(.+)_\\d+.jpg$', fname.name)"
]
},
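{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pattern captures everything before the final `_<digits>.jpg` as the breed name, because the greedy `(.+)` backtracks only as far as it must for the rest of the pattern to match. As a quick check outside fastai, the sketch below runs the same extraction with Python's `re` module on filenames taken from this notebook's listings. Escaping the dot (`\\.jpg`) is a tightening added here, not part of the original pattern, so that only a literal `.jpg` extension matches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Filenames follow the pattern <breed>_<index>.jpg (examples from the listings in this notebook)\n",
"fnames = ['great_pyrenees_173.jpg', 'Ragdoll_262.jpg', 'american_pit_bull_terrier_31.jpg']\n",
"# The escaped dot matches only a literal '.', unlike the bare '.' in the original pattern\n",
"pat = re.compile(r'(.+)_\\d+\\.jpg$')\n",
"labels = [pat.findall(f)[0] for f in fnames]\n",
"print(labels)  # ['great_pyrenees', 'Ragdoll', 'american_pit_bull_terrier']"
]
},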
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pets = DataBlock(blocks = (ImageBlock, CategoryBlock),\n",
" get_items=get_image_files, \n",
" splitter=RandomSplitter(seed=42),\n",
" get_y=using_attr(RegexLabeller(r'(.+)_\\d+.jpg$'), 'name'),\n",
" item_tfms=Resize(460),\n",
" batch_tfms=aug_transforms(size=224, min_scale=0.75))\n",
"dls = pets.dataloaders(path/\"images\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Presizing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dblock1 = DataBlock(blocks=(ImageBlock(), CategoryBlock()),\n",
" get_y=parent_label,\n",
" item_tfms=Resize(460))\n",
"dls1 = dblock1.dataloaders([(Path.cwd()/'images'/'grizzly.jpg')]*100, bs=8)\n",
"dls1.train.get_idxs = lambda: Inf.ones\n",
"x,y = dls1.valid.one_batch()\n",
"_,axs = subplots(1, 2)\n",
"\n",
"# Right: resize to 224px, then rotate, zoom and warp one at a time (each step resamples the image again)\n",
"x1 = TensorImage(x.clone())\n",
"x1 = x1.affine_coord(sz=224)\n",
"x1 = x1.rotate(draw=30, p=1.)\n",
"x1 = x1.zoom(draw=1.2, p=1.)\n",
"x1 = x1.warp(draw_x=-0.2, draw_y=0.2, p=1.)\n",
"\n",
"# Left: setup_aug_tfms merges the same augmentations so they are applied with a single interpolation\n",
"tfms = setup_aug_tfms([Rotate(draw=30, p=1, size=224), Zoom(draw=1.2, p=1., size=224),\n",
" Warp(draw_x=-0.2, draw_y=0.2, p=1., size=224)])\n",
"x = Pipeline(tfms)(x)\n",
"TensorImage(x[0]).show(ctx=axs[0])\n",
"TensorImage(x1[0]).show(ctx=axs[1]);"
]
},
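{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hidden cell above appears to compare two ways of augmenting the same presized images: on the left, fastai merges the rotation, zoom and warp into one transform; on the right, the image is resized to 224px and each augmentation is then applied in turn. The idea behind presizing is that items are first resized to something larger than the training size (460px here), and the batch transforms then apply all the spatial augmentations together on the GPU with a single interpolation, which avoids the loss of detail and the empty regions that repeated resampling can introduce. The sketch below illustrates the core idea in plain PyTorch: affine augmentations compose by matrix multiplication, so one `affine_grid`/`grid_sample` call can apply several of them at once. It is not fastai's implementation, and the helpers `rot_mat`, `zoom_mat` and `to3x3` are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def rot_mat(deg):\n",
"    # 2x3 affine matrix for a rotation, in the normalized [-1, 1] coordinates used by affine_grid\n",
"    t = math.radians(deg)\n",
"    return torch.tensor([[math.cos(t), -math.sin(t), 0.],\n",
"                         [math.sin(t),  math.cos(t), 0.]])\n",
"\n",
"def zoom_mat(scale):\n",
"    # Zooming in by `scale` means sampling a smaller region of the source image\n",
"    return torch.tensor([[1/scale, 0., 0.],\n",
"                         [0., 1/scale, 0.]])\n",
"\n",
"def to3x3(m):\n",
"    # Add the homogeneous row so that affine matrices can be composed by multiplication\n",
"    return torch.cat([m, torch.tensor([[0., 0., 1.]])])\n",
"\n",
"imgs = torch.rand(8, 3, 460, 460)  # stand-in for a batch already presized to 460px\n",
"theta = (to3x3(rot_mat(30)) @ to3x3(zoom_mat(1.2)))[:2]  # one matrix for both augmentations\n",
"grid = F.affine_grid(theta[None].repeat(8, 1, 1), size=(8, 3, 224, 224), align_corners=False)\n",
"out = F.grid_sample(imgs, grid, mode='bilinear', align_corners=False)  # a single interpolation\n",
"print(out.shape)  # torch.Size([8, 3, 224, 224])"
]
},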
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking and debugging a DataBlock"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(nrows=1, ncols=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting-up type transforms pipelines\n",
"Collecting items from /home/jhoward/.fastai/data/oxford-iiit-pet/images\n",
"Found 7390 items\n",
"2 datasets of sizes 5912,1478\n",
"Setting up Pipeline: PILBase.create\n",
"Setting up Pipeline: partial -> Categorize\n",
"\n",
"Building one sample\n",
" Pipeline: PILBase.create\n",
" starting from\n",
" /home/jhoward/.fastai/data/oxford-iiit-pet/images/american_pit_bull_terrier_31.jpg\n",
" applying PILBase.create gives\n",
" PILImage mode=RGB size=500x414\n",
" Pipeline: partial -> Categorize\n",
" starting from\n",
" /home/jhoward/.fastai/data/oxford-iiit-pet/images/american_pit_bull_terrier_31.jpg\n",
" applying partial gives\n",
" american_pit_bull_terrier\n",
" applying Categorize gives\n",
" TensorCategory(13)\n",
"\n",
"Final sample: (PILImage mode=RGB size=500x414, TensorCategory(13))\n",
"\n",
"\n",
"Setting up after_item: Pipeline: ToTensor\n",
"Setting up before_batch: Pipeline: \n",
"Setting up after_batch: Pipeline: IntToFloatTensor\n",
"\n",
"Building one batch\n",
"Applying item_tfms to the first sample:\n",
" Pipeline: ToTensor\n",
" starting from\n",
" (PILImage mode=RGB size=500x414, TensorCategory(13))\n",
" applying ToTensor gives\n",
" (TensorImage of size 3x414x500, TensorCategory(13))\n",
"\n",
"Adding the next 3 samples\n",
"\n",
"No before_batch transform to apply\n",
"\n",
"Collating items in a batch\n",
"Error! It's not possible to collate your items in a batch\n",
"Could not collate the 0-th members of your tuples because got the following shapes\n",
"torch.Size([3, 414, 500]),torch.Size([3, 375, 500]),torch.Size([3, 500, 281]),torch.Size([3, 203, 300])\n"
]
},
{
"ename": "RuntimeError",
"evalue": "invalid argument 0: Sizes of tensors must match except in dimension 0. Got 414 and 375 in dimension 2 at /opt/conda/conda-bld/pytorch_1579022060824/work/aten/src/TH/generic/THTensor.cpp:612",
"output_type": "error",
"traceback": [
"\u001b[0;31m----------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-8c0a3d421ca2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0msplitter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mRandomSplitter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m42\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m get_y=using_attr(RegexLabeller(r'(.+)_\\d+.jpg$'), 'name'))\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mpets1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;34m\"images\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/git/fastai2/fastai2/data/block.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(self, source, bs, **kwargs)\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0mwhy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_find_fail_collate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Make sure all parts of your samples are tensors of the same size\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwhy\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mwhy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mf\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mafter_batch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfs\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'noop'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m!=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/fastai2/fastai2/data/block.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(self, source, bs, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\nCollating items in a batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcreate_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mretain_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/fastai2/fastai2/data/load.py\u001b[0m in \u001b[0;36mcreate_batch\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mretain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mretain_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcreate_item\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mcreate_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mfa_collate\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfa_convert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprebatched\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdo_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcreate_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbefore_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mone_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/fastai2/fastai2/data/load.py\u001b[0m in \u001b[0;36mfa_collate\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m return (default_collate(t) if isinstance(b, _collate_types)\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32melse\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfa_collate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m else default_collate(t))\n\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/fastai2/fastai2/data/load.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m return (default_collate(t) if isinstance(b, _collate_types)\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32melse\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfa_collate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m else default_collate(t))\n\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/fastai2/fastai2/data/load.py\u001b[0m in \u001b[0;36mfa_collate\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfa_collate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m return (default_collate(t) if isinstance(b, _collate_types)\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfa_collate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m else default_collate(t))\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py\u001b[0m in \u001b[0;36mdefault_collate\u001b[0;34m(batch)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mstorage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melem\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_new_shared\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melem\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0melem_type\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__module__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'numpy'\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0melem_type\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'str_'\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0melem_type\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'string_'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 414 and 375 in dimension 2 at /opt/conda/conda-bld/pytorch_1579022060824/work/aten/src/TH/generic/THTensor.cpp:612"
]
}
],
"source": [
"pets1 = DataBlock(blocks = (ImageBlock, CategoryBlock),\n",
" get_items=get_image_files, \n",
" splitter=RandomSplitter(seed=42),\n",
" get_y=using_attr(RegexLabeller(r'(.+)_\\d+.jpg$'), 'name'))\n",
"pets1.summary(path/\"images\")"
]
},
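{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary shows why `pets1` fails: with no item transform, the images keep their original sizes (3x414x500, 3x375x500, and so on), but collation stacks the samples into a single tensor, which requires every sample to have the same shape. This minimal sketch reproduces the failure in plain PyTorch and shows that resizing every item to a common size first makes stacking possible. Here `F.interpolate` (which squishes) is only a stand-in for fastai's `Resize`, which crops by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Two samples with the shapes reported by summary above\n",
"a = torch.zeros(3, 414, 500)\n",
"b = torch.zeros(3, 375, 500)\n",
"try:\n",
"    torch.stack([a, b])  # collation stacks samples along a new batch dimension\n",
"except RuntimeError as e:\n",
"    print('cannot collate:', e)\n",
"\n",
"# Resize every item to the same size first (F.interpolate squishes, unlike fastai's default crop)\n",
"same = [F.interpolate(t[None], size=(460, 460), mode='bilinear', align_corners=False)[0] for t in (a, b)]\n",
"batch = torch.stack(same)\n",
"print(batch.shape)  # torch.Size([2, 3, 460, 460])"
]
},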
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.491732</td>\n",
" <td>0.337355</td>\n",
" <td>0.108254</td>\n",
" <td>00:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.503154</td>\n",
" <td>0.293404</td>\n",
" <td>0.096076</td>\n",
" <td>00:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.314759</td>\n",
" <td>0.225316</td>\n",
" <td>0.066306</td>\n",
" <td>00:23</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cross entropy loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Viewing activations and labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x,y = dls.one_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorCategory([11, 0, 0, 5, 20, 4, 22, 31, 23, 10, 20, 2, 3, 27, 18, 23, 33, 5, 24, 7, 6, 12, 9, 11, 35, 14, 10, 15, 3, 3, 21, 5, 19, 14, 12, 15, 27, 1, 17, 10, 7, 6, 15, 23, 36, 1, 35, 6,\n",
" 4, 29, 24, 32, 2, 14, 26, 25, 21, 0, 29, 31, 18, 7, 7, 17], device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([7.9069e-04, 6.2350e-05, 3.7607e-05, 2.9260e-06, 1.3032e-05, 2.5760e-05, 6.2341e-08, 3.6400e-07, 4.1311e-06, 1.3310e-04, 2.3090e-03, 9.9281e-01, 4.6494e-05, 6.4266e-07, 1.9780e-06, 5.7005e-07,\n",
" 3.3448e-06, 3.5691e-03, 3.4385e-06, 1.1578e-05, 1.5916e-06, 8.5567e-08, 5.0773e-08, 2.2978e-06, 1.4150e-06, 3.5459e-07, 1.4599e-04, 5.6198e-08, 3.4108e-07, 2.0813e-06, 8.0568e-07, 4.3381e-07,\n",
" 1.0069e-05, 9.1020e-07, 4.8714e-06, 1.2734e-06, 2.4735e-06])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds,_ = learn.get_preds(dl=[(x,y)])\n",
"preds[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(37, tensor(1.0000))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(preds[0]),preds[0].sum()"
]
},
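{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of `preds` holds 37 probabilities, one per breed, and they sum to 1. Cross-entropy loss is the negative log of the probability assigned to the correct class, averaged over the batch. The sketch below computes it by hand with random stand-in activations (not this model's outputs) and checks the result against PyTorch's `F.cross_entropy`, which expects raw activations and applies log-softmax internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"acts = torch.randn(4, 37)            # hypothetical activations: 4 images, 37 breeds\n",
"targ = torch.tensor([11, 0, 0, 5])   # target indices, like the first labels in y above\n",
"log_probs = F.log_softmax(acts, dim=1)\n",
"# Pick the log probability of the correct class in each row, negate, and average\n",
"manual = -log_probs[torch.arange(len(targ)), targ].mean()\n",
"print(torch.allclose(manual, F.cross_entropy(acts, targ)))  # True"
]
},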
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Softmax"
]
},
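{
"cell_type": "markdown",
"metadata": {},
"source": [
"Softmax turns the final activations into probabilities like those above: it exponentiates each activation and divides by the sum of the exponentials in its row, so every output is positive and each row sums to 1. A minimal sketch, checked against `torch.softmax`; subtracting the row maximum before exponentiating is a standard trick to avoid overflow and does not change the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def softmax(x):\n",
"    # Subtracting the row max avoids overflow in exp and cancels out in the ratio\n",
"    e = torch.exp(x - x.max(dim=1, keepdim=True).values)\n",
"    return e / e.sum(dim=1, keepdim=True)\n",
"\n",
"torch.manual_seed(42)\n",
"acts = torch.randn(6, 2) * 2      # hypothetical activations for 6 images and 2 classes\n",
"sm = softmax(acts)\n",
"print(torch.allclose(sm, torch.softmax(acts, dim=1)))  # True\n",
"print(sm.sum(dim=1))              # each row sums to 1"
]
},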
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(torch.sigmoid, min=-4,max=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"torch.random.manual_seed(42);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.6734, 0.2576],\n",
" [ 0.4689, 0.4607],\n",
" [-2.2457, -0.3727],\n",
" [ 4.4164, -1.2760],\n",
" [ 0.9233, 0.5347],\n",
" [ 1.0698, 1.6187]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acts = torch.randn((6,2))*2\n",
"acts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.6623, 0.5641],\n",
" [0.6151, 0.6132],\n",
" [0.0957, 0.4079],\n",
" [0.9881, 0.2182],\n",
" [0.7157, 0.6306],\n",
" [0.7446, 0.8346]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acts.sigmoid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.6025, 0.5021, 0.1332, 0.9966, 0.5959, 0.3661])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(acts[:,0]-acts[:,1]).sigmoid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.6025, 0.3975],\n",
" [0.5021, 0.4979],\n",
" [0.1332, 0.8668],\n",
" [0.9966, 0.0034],\n",
" [0.5959, 0.4041],\n",
" [0.3661, 0.6339]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm_acts = torch.softmax(acts, dim=1)\n",
"sm_acts"
]
},
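{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal from-scratch sketch of softmax to compare with `torch.softmax` (the helper `softmax_sketch` and the `demo_` names are ours, not library functions). It exponentiates each activation and divides by its row total. Subtracting each row's maximum first leaves the result unchanged but keeps `exp` from overflowing. With two columns it also reduces to the sigmoid of the difference computed earlier, because dividing the numerator and denominator by $e^{a_0}$ gives:\n",
"\n",
"$$\\frac{e^{a_0}}{e^{a_0}+e^{a_1}} = \\frac{1}{1+e^{-(a_0-a_1)}} = \\sigma(a_0-a_1)$$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def softmax_sketch(x):\n",
"    # Subtracting each row's max leaves the result unchanged but keeps exp() from overflowing\n",
"    x = x - x.max(dim=1, keepdim=True).values\n",
"    e = torch.exp(x)\n",
"    # Normalise each row so its entries sum to 1\n",
"    return e / e.sum(dim=1, keepdim=True)\n",
"\n",
"# First three rows of `acts`, copied from the output above\n",
"demo_acts = torch.tensor([[0.6734, 0.2576], [0.4689, 0.4607], [-2.2457, -0.3727]])\n",
"demo_sm = softmax_sketch(demo_acts)\n",
"print(torch.allclose(demo_sm, torch.softmax(demo_acts, dim=1)))\n",
"# With two columns, the first column equals sigmoid(a0 - a1)\n",
"print(torch.allclose(demo_sm[:, 0], (demo_acts[:, 0] - demo_acts[:, 1]).sigmoid()))"
]
},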
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Log likelihood"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"targ = tensor([0,1,0,1,1,0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.6025, 0.3975],\n",
" [0.5021, 0.4979],\n",
" [0.1332, 0.8668],\n",
" [0.9966, 0.0034],\n",
" [0.5959, 0.4041],\n",
" [0.3661, 0.6339]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm_acts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.6025, 0.4979, 0.1332, 0.0034, 0.4041, 0.3661])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx = range(6)\n",
"sm_acts[idx, targ]"
]
},
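{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sm_acts[idx, targ]` uses integer-array indexing: entry `i` of the result is `sm_acts[i, targ[i]]`. The small sketch below shows that an explicit loop and `torch.gather` select the same values. The `demo_` names are ours, and the values are copied from the output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"demo_sm = torch.tensor([[0.6025, 0.3975], [0.5021, 0.4979], [0.1332, 0.8668],\n",
"                        [0.9966, 0.0034], [0.5959, 0.4041], [0.3661, 0.6339]])\n",
"demo_targ = torch.tensor([0, 1, 0, 1, 1, 0])\n",
"\n",
"# Integer-array indexing: for each row i, take column demo_targ[i]\n",
"picked = demo_sm[torch.arange(6), demo_targ]\n",
"# The same selection written as an explicit loop\n",
"looped = torch.stack([demo_sm[i, demo_targ[i]] for i in range(6)])\n",
"# ...and with gather along dimension 1\n",
"gathered = demo_sm.gather(1, demo_targ.unsqueeze(1)).squeeze(1)\n",
"print(picked, torch.equal(picked, looped), torch.equal(picked, gathered))"
]
},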
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table ><thead> <tr> <th class=\"col_heading level0 col0\" >3</th> <th class=\"col_heading level0 col1\" >7</th> <th class=\"col_heading level0 col2\" >targ</th> <th class=\"col_heading level0 col3\" >idx</th> <th class=\"col_heading level0 col4\" >loss</th> </tr></thead><tbody>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow0_col0\" class=\"data row0 col0\" >0.602469</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow0_col1\" class=\"data row0 col1\" >0.397531</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow0_col2\" class=\"data row0 col2\" >0</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow0_col3\" class=\"data row0 col3\" >0</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow0_col4\" class=\"data row0 col4\" >0.602469</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow1_col0\" class=\"data row1 col0\" >0.502065</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow1_col1\" class=\"data row1 col1\" >0.497935</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow1_col2\" class=\"data row1 col2\" >1</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow1_col3\" class=\"data row1 col3\" >1</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow1_col4\" class=\"data row1 col4\" >0.497935</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow2_col0\" class=\"data row2 col0\" >0.133188</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow2_col1\" class=\"data row2 col1\" >0.866811</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow2_col2\" class=\"data row2 col2\" >0</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow2_col3\" class=\"data row2 col3\" >2</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow2_col4\" class=\"data row2 col4\" >0.133188</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow3_col0\" class=\"data row3 col0\" >0.99664</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow3_col1\" class=\"data row3 col1\" >0.00336017</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow3_col2\" class=\"data row3 col2\" >1</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow3_col3\" class=\"data row3 col3\" >3</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow3_col4\" class=\"data row3 col4\" >0.00336017</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow4_col0\" class=\"data row4 col0\" >0.595949</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow4_col1\" class=\"data row4 col1\" >0.404051</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow4_col2\" class=\"data row4 col2\" >1</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow4_col3\" class=\"data row4 col3\" >4</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow4_col4\" class=\"data row4 col4\" >0.404051</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow5_col0\" class=\"data row5 col0\" >0.366118</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow5_col1\" class=\"data row5 col1\" >0.633882</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow5_col2\" class=\"data row5 col2\" >0</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow5_col3\" class=\"data row5 col3\" >5</td>\n",
" <td id=\"T_74b4f58a_5696_11ea_af1d_53ec6ce9fc9drow5_col4\" class=\"data row5 col4\" >0.366118</td>\n",
" </tr>\n",
" </tbody></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import HTML\n",
"df = pd.DataFrame(sm_acts, columns=[\"3\",\"7\"])\n",
"df['targ'] = targ\n",
"df['idx'] = idx\n",
"df['loss'] = sm_acts[range(6), targ]\n",
"t = df.style.hide_index()\n",
"# Keep only the <table> markup (drop the <style> block and the generated table id) so the HTML is compatible with our script\n",
"html = t._repr_html_().split('</style>')[1]\n",
"html = re.sub(r'<table id=\"([^\"]+)\"\\s*>', r'<table >', html)\n",
"display(HTML(html))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.6025, -0.4979, -0.1332, -0.0034, -0.4041, -0.3661])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-sm_acts[idx, targ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.6025, -0.4979, -0.1332, -0.0034, -0.4041, -0.3661])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.nll_loss(sm_acts, targ, reduction='none')"
]
},
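{
"cell_type": "markdown",
"metadata": {},
"source": [
"Despite its name, `F.nll_loss` does not take a logarithm. It expects inputs that are already log-probabilities and simply negates the value picked out by the target. The quick check below confirms this, along with the default `reduction='mean'`, using `demo_` copies of the values above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"demo_sm = torch.tensor([[0.6025, 0.3975], [0.5021, 0.4979], [0.1332, 0.8668],\n",
"                        [0.9966, 0.0034], [0.5959, 0.4041], [0.3661, 0.6339]])\n",
"demo_targ = torch.tensor([0, 1, 0, 1, 1, 0])\n",
"\n",
"# nll_loss with no reduction is just the negated target entries\n",
"manual = -demo_sm[torch.arange(6), demo_targ]\n",
"print(torch.allclose(F.nll_loss(demo_sm, demo_targ, reduction='none'), manual))\n",
"# The default reduction='mean' averages the per-item values\n",
"print(torch.allclose(F.nll_loss(demo_sm, demo_targ), manual.mean()))"
]
},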
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Taking the `log`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(torch.log, min=0,max=4)"
]
},
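{
"cell_type": "markdown",
"metadata": {},
"source": [
"Taking the log of a softmax output can fail numerically. If a probability underflows to zero in float32, its log is `-inf`. Functions such as `torch.log_softmax` avoid this by computing the result in log space, following the identity $\\log \\mathrm{softmax}(x)_i = x_i - \\log \\sum_j e^{x_j}$. The small sketch below uses deliberately extreme activations (`extreme_acts`, chosen only for illustration)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"extreme_acts = torch.tensor([[0.0, 200.0]])\n",
"# Naive: exp(-200) underflows to 0 in float32, so the log of that probability is -inf\n",
"naive = torch.log(torch.softmax(extreme_acts, dim=1))\n",
"# Log space: log softmax(x) = x - logsumexp(x)\n",
"stable = extreme_acts - torch.logsumexp(extreme_acts, dim=1, keepdim=True)\n",
"print(naive)\n",
"print(stable)\n",
"print(torch.allclose(stable, torch.log_softmax(extreme_acts, dim=1)))"
]
},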
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_func = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.8045)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_func(acts, targ)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.8045)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.cross_entropy(acts, targ)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5067, 0.6973, 2.0160, 5.6958, 0.9062, 1.0048])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.CrossEntropyLoss(reduction='none')(acts, targ)"
]
},
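{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting the pieces together, PyTorch's cross-entropy loss is `log_softmax` followed by `nll_loss`, which equals the average of $-\\log p$ for the target class. The sketch below checks this on the values of `acts` printed earlier. It copies them (rounded to 4 decimals) into `demo_acts`, with the same targets in `demo_targ`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"demo_acts = torch.tensor([[0.6734, 0.2576], [0.4689, 0.4607], [-2.2457, -0.3727],\n",
"                          [4.4164, -1.2760], [0.9233, 0.5347], [1.0698, 1.6187]])\n",
"demo_targ = torch.tensor([0, 1, 0, 1, 1, 0])\n",
"\n",
"ce = F.cross_entropy(demo_acts, demo_targ)\n",
"# Step by step: log-probabilities, then nll_loss picks, negates and averages the target entries\n",
"via_nll = F.nll_loss(F.log_softmax(demo_acts, dim=1), demo_targ)\n",
"# By hand: mean of -log(probability of the target class)\n",
"by_hand = -torch.log(torch.softmax(demo_acts, dim=1)[torch.arange(6), demo_targ]).mean()\n",
"print(ce, via_nll, by_hand)"
]
},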
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model interpretation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"interp = ClassificationInterpretation.from_learner(learn)\n",
"interp.plot_confusion_matrix(figsize=(12,12), dpi=60)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('american_pit_bull_terrier', 'staffordshire_bull_terrier', 10),\n",
" ('Ragdoll', 'Birman', 6)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"interp.most_confused(min_val=5)"
]
},
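{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we understand it, `most_confused` reads the off-diagonal cells of the confusion matrix and lists the pairs whose count is at least `min_val`, largest first. Rows are actual classes and columns are predictions. The hypothetical re-implementation below shows the idea on a made-up 3x3 matrix. `most_confused_sketch` and the toy counts are ours, not fastai's code or this model's results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def most_confused_sketch(cm, vocab, min_val=1):\n",
"    off = cm.clone()\n",
"    off.fill_diagonal_(0)  # ignore correct predictions on the diagonal\n",
"    # (actual, predicted, count) for every off-diagonal cell with count >= min_val\n",
"    pairs = [(vocab[i], vocab[j], int(off[i, j]))\n",
"             for i in range(len(vocab)) for j in range(len(vocab))\n",
"             if off[i, j] >= min_val]\n",
"    # Most frequent confusions first\n",
"    return sorted(pairs, key=lambda p: p[2], reverse=True)\n",
"\n",
"toy_cm = torch.tensor([[5, 2, 0],\n",
"                       [3, 4, 1],\n",
"                       [0, 0, 6]])\n",
"print(most_confused_sketch(toy_cm, ['a', 'b', 'c'], min_val=2))"
]
},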
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Improving our model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Learning rate finder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>8.946717</td>\n",
" <td>47.954632</td>\n",
" <td>0.893775</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>7.231843</td>\n",
" <td>4.119265</td>\n",
" <td>0.954668</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(1, base_lr=0.1)"
]
},
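{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training with `base_lr=0.1` gives a very high loss and error rate. The toy example below, which is unrelated to the pet model, illustrates how too large a learning rate can make training diverge. Plain gradient descent on $f(x)=x^2$ updates $x \\leftarrow x - \\eta \\cdot 2x = (1-2\\eta)\\,x$. It therefore converges when $|1-2\\eta| < 1$ and diverges once the learning rate $\\eta$ exceeds 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gradient_descent_x_squared(lr, x0=1.0, steps=10):\n",
"    x = x0\n",
"    for _ in range(steps):\n",
"        x = x - lr * 2 * x  # the gradient of x**2 is 2*x\n",
"    return x\n",
"\n",
"# Small step: shrinks toward 0; lr=0.5: jumps straight to 0; large step: overshoots and grows\n",
"for lr in (0.1, 0.5, 1.1):\n",
"    print(lr, gradient_descent_x_squared(lr))"
]
},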
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"lr_min,lr_steep = learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Minimum/10: 8.32e-03, steepest point: 6.31e-03\n"
]
}
],
"source": [
"print(f\"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}\")"
]
},
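{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning rate finder trains briefly with a learning rate that grows exponentially, records the loss at each step, and suggests values from the resulting curve. The two values printed above are the learning rate with the lowest loss divided by 10, and the point where the loss falls fastest. The rough sketch below illustrates those two ideas. The helper names and toy loss values are ours, and fastai's implementation differs in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def exponential_lrs(start_lr=1e-7, end_lr=10.0, num_it=100):\n",
"    # Each learning rate is a constant factor larger than the previous one\n",
"    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]\n",
"\n",
"def suggest_lrs_sketch(lrs, losses):\n",
"    # Minimum/10: one order of magnitude below the learning rate with the lowest loss\n",
"    i_min = min(range(len(losses)), key=lambda i: losses[i])\n",
"    # Steepest point: most negative slope of the loss against log(lr)\n",
"    slopes = [(losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))\n",
"              for i in range(len(lrs) - 1)]\n",
"    i_steep = min(range(len(slopes)), key=lambda i: slopes[i])\n",
"    return lrs[i_min] / 10, lrs[i_steep]\n",
"\n",
"toy_lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]\n",
"toy_losses = [2.0, 1.9, 1.2, 1.0, 3.0]\n",
"print(suggest_lrs_sketch(toy_lrs, toy_losses))"
]
},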
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.071820</td>\n",
" <td>0.427476</td>\n",
" <td>0.133965</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.738273</td>\n",
" <td>0.541828</td>\n",
" <td>0.150880</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.401544</td>\n",
" <td>0.266623</td>\n",
" <td>0.081867</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(2, base_lr=3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Unfreezing and transfer learning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.188042</td>\n",
" <td>0.355024</td>\n",
" <td>0.102842</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.534234</td>\n",
" <td>0.302453</td>\n",
" <td>0.094723</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.325031</td>\n",
" <td>0.222268</td>\n",
" <td>0.074425</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fit_one_cycle(3, 3e-3)"
]
},
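{
"cell_type": "markdown",
"metadata": {},
"source": [
"`fit_one_cycle` does not keep the learning rate fixed. It warms up from a low value to the maximum and then anneals down to a much smaller one. The sketch below builds such a schedule with cosine interpolation. The defaults `div=25`, `div_final=1e5` and `pct_start=0.25` mirror what we believe fastai uses, but treat them and the helper names as assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def cos_interp(start, end, pct):\n",
"    # Cosine curve from `start` at pct=0 to `end` at pct=1\n",
"    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))\n",
"\n",
"def one_cycle_lr_sketch(step, total_steps, lr_max, div=25.0, div_final=1e5, pct_start=0.25):\n",
"    warmup = int(total_steps * pct_start)\n",
"    if step < warmup:\n",
"        # Phase 1: rise from lr_max/div to lr_max\n",
"        return cos_interp(lr_max / div, lr_max, step / warmup)\n",
"    # Phase 2: anneal from lr_max down to lr_max/div_final\n",
"    return cos_interp(lr_max, lr_max / div_final, (step - warmup) / max(1, total_steps - warmup - 1))\n",
"\n",
"sched = [one_cycle_lr_sketch(s, 100, 3e-3) for s in range(100)]\n",
"print(sched[0], max(sched), sched[-1])"
]
},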
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.unfreeze()"
]
},
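{
"cell_type": "markdown",
"metadata": {},
"source": [
"In plain PyTorch terms, freezing means excluding parameters from gradient updates, for example by setting `requires_grad` to `False`. Unfreezing makes the whole network trainable again. The minimal sketch below uses a tiny stand-in model. Its layers and the `set_trainable` and `count_trainable` helpers are hypothetical. fastai's own `freeze` and `unfreeze` work on parameter groups and may handle some details, such as batchnorm layers, differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"# Stand-ins for a pretrained body and a newly added head\n",
"demo_body = nn.Sequential(nn.Linear(10, 8), nn.ReLU())\n",
"demo_head = nn.Linear(8, 2)\n",
"demo_model = nn.Sequential(demo_body, demo_head)\n",
"\n",
"def set_trainable(module, trainable):\n",
"    for p in module.parameters():\n",
"        p.requires_grad_(trainable)\n",
"\n",
"def count_trainable(model):\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"set_trainable(demo_body, False)  # frozen: only the head's 8*2 + 2 = 18 parameters train\n",
"print(count_trainable(demo_model))\n",
"set_trainable(demo_body, True)   # unfrozen: the body's 10*8 + 8 = 88 parameters train too\n",
"print(count_trainable(demo_model))"
]
},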
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(1.0964782268274575e-05, 1.5848931980144698e-06)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.263579</td>\n",
" <td>0.217419</td>\n",
" <td>0.069012</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.253060</td>\n",
" <td>0.210346</td>\n",
" <td>0.062923</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.224340</td>\n",
" <td>0.207357</td>\n",
" <td>0.060217</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.200195</td>\n",
" <td>0.207244</td>\n",
" <td>0.061570</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.194269</td>\n",
" <td>0.200149</td>\n",
" <td>0.059540</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.173164</td>\n",
" <td>0.202301</td>\n",
" <td>0.059540</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, lr_max=1e-5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Discriminative learning rates"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.145300</td>\n",
" <td>0.345568</td>\n",
" <td>0.119756</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.533986</td>\n",
" <td>0.251944</td>\n",
" <td>0.077131</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.317696</td>\n",
" <td>0.208371</td>\n",
" <td>0.069012</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.257977</td>\n",
" <td>0.205400</td>\n",
" <td>0.067659</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.246763</td>\n",
" <td>0.205107</td>\n",
" <td>0.066306</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.240595</td>\n",
" <td>0.193848</td>\n",
" <td>0.062246</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.209988</td>\n",
" <td>0.198061</td>\n",
" <td>0.062923</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.194756</td>\n",
" <td>0.193130</td>\n",
" <td>0.064276</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.169985</td>\n",
" <td>0.187885</td>\n",
" <td>0.056157</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.153205</td>\n",
" <td>0.186145</td>\n",
" <td>0.058863</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.141480</td>\n",
" <td>0.185316</td>\n",
" <td>0.053451</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.128564</td>\n",
" <td>0.180999</td>\n",
" <td>0.051421</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.126941</td>\n",
" <td>0.186288</td>\n",
" <td>0.054127</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.130064</td>\n",
" <td>0.181764</td>\n",
" <td>0.054127</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>0.124281</td>\n",
" <td>0.181855</td>\n",
" <td>0.054127</td>\n",
" <td>00:25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fit_one_cycle(3, 3e-3)\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(12, lr_max=slice(1e-6,1e-4))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_loss()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Selecting the number of epochs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Deeper architectures"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.427505</td>\n",
" <td>0.310554</td>\n",
" <td>0.098782</td>\n",
" <td>00:21</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.606785</td>\n",
" <td>0.302325</td>\n",
" <td>0.094723</td>\n",
" <td>00:22</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.409267</td>\n",
" <td>0.294803</td>\n",
" <td>0.091340</td>\n",
" <td>00:21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.261121</td>\n",
" <td>0.274507</td>\n",
" <td>0.083897</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.296653</td>\n",
" <td>0.318649</td>\n",
" <td>0.084574</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.242356</td>\n",
" <td>0.253677</td>\n",
" <td>0.069012</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.150684</td>\n",
" <td>0.251438</td>\n",
" <td>0.065629</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.094997</td>\n",
" <td>0.239772</td>\n",
" <td>0.064276</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.061144</td>\n",
" <td>0.228082</td>\n",
" <td>0.054804</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastai2.callback.fp16 import *\n",
"learn = cnn_learner(dls, resnet50, metrics=error_rate).to_fp16()\n",
"learn.fine_tune(6, freeze_epochs=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Why do we first resize to a large size on the CPU, and then to a smaller size on the GPU?\n",
"1. If you are not familiar with regular expressions, find a regular expression tutorial and some problem sets, and complete them. Have a look at the book website for suggestions.\n",
"1. What are the two ways in which data is most commonly provided for deep learning datasets?\n",
"1. Look up the documentation for `L` and try using a few of the new methods that it adds.\n",
"1. Look up the documentation for the Python `pathlib` module and try using a few methods of the `Path` class.\n",
"1. Give two examples of ways that image transformations can degrade the quality of the data.\n",
"1. What method does fastai provide to view the data in a DataLoader?\n",
"1. What method does fastai provide to help you debug a DataBlock?\n",
"1. Should you hold off on training a model until you have thoroughly cleaned your data?\n",
"1. What are the two pieces that are combined into cross entropy loss in PyTorch?\n",
"1. What are the two properties of activations that softmax ensures? Why is this important?\n",
"1. When might you want your activations to not have these two properties?\n",
"1. Calculate the \"exp\" and \"softmax\" columns of <<bear_softmax>> yourself (e.g., in a spreadsheet, with a calculator, or in a notebook).\n",
"1. Why can't we use `torch.where` to create a loss function for datasets where our label can have more than two categories?\n",
"1. What is the value of log(-2)? Why?\n",
"1. What are two good rules of thumb for picking a learning rate from the learning rate finder?\n",
"1. What two steps does the `fine_tune` method do?\n",
"1. In a Jupyter notebook, how do you get the source code for a method or function?\n",
"1. What are discriminative learning rates?\n",
"1. How is a Python `slice` object interpreted when passed as a learning rate to fastai?\n",
"1. Why is early stopping a poor choice when using one-cycle training?\n",
"1. What is the difference between `resnet50` and `resnet101`?\n",
"1. What does `to_fp16` do?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Find the paper by Leslie Smith that introduced the learning rate finder, and read it.\n",
"1. See if you can improve the accuracy of the classifier in this chapter. What's the best accuracy you can achieve? Have a look at the forums and the book website to see what other students have achieved with this dataset, and how they did it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
2020-03-19 13:21:55 +00:00
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}