fastbook/clean/06_multicat.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other computer vision problems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-label classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.vision.all import *\n",
"path = untar_data(URLs.PASCAL_2007)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fname</th>\n",
" <th>labels</th>\n",
" <th>is_valid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>000005.jpg</td>\n",
" <td>chair</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>000007.jpg</td>\n",
" <td>car</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>000009.jpg</td>\n",
" <td>horse person</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>000012.jpg</td>\n",
" <td>car</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>000016.jpg</td>\n",
" <td>bicycle</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"        fname        labels  is_valid\n",
"0  000005.jpg         chair      True\n",
"1  000007.jpg           car      True\n",
"2  000009.jpg  horse person      True\n",
"3  000012.jpg           car     False\n",
"4  000016.jpg       bicycle      True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'train.csv')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: Pandas and DataFrames"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fname       000005.jpg\n",
"labels           chair\n",
"is_valid          True\n",
"Name: 0, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[:,0]\n",
"df.iloc[0,:]\n",
"# Trailing :s are always optional (in NumPy, PyTorch, pandas, etc.),\n",
"# so this is equivalent:\n",
"df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0       000005.jpg\n",
"1       000007.jpg\n",
"2       000009.jpg\n",
"3       000012.jpg\n",
"4       000016.jpg\n",
"           ...    \n",
"5006    009954.jpg\n",
"5007    009955.jpg\n",
"5008    009958.jpg\n",
"5009    009959.jpg\n",
"5010    009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['fname']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"   a\n",
"0  1\n",
"1  2\n",
"2  3\n",
"3  4"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df1 = pd.DataFrame()\n",
"df1['a'] = [1,2,3,4]\n",
"df1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0    11\n",
"1    22\n",
"2    33\n",
"3    44\n",
"dtype: int64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df1['b'] = [10, 20, 30, 40]\n",
"df1['a'] + df1['b']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constructing a data block"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsets = dblock.datasets(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(fname       008663.jpg\n",
" labels      car person\n",
" is_valid         False\n",
" Name: 4346, dtype: object, fname       008663.jpg\n",
" labels      car person\n",
" is_valid         False\n",
" Name: 4346, dtype: object)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('005620.jpg', 'aeroplane')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(get_x = lambda r: r['fname'], get_y = lambda r: r['labels'])\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('002549.jpg', 'tvmonitor')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return r['fname']\n",
"def get_y(r): return r['labels']\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Path('/home/sgugger/.fastai/data/pascal_2007/train/008663.jpg'),\n",
" ['car', 'person'])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return path/'train'/r['fname']\n",
"def get_y(r): return r['labels'].split(' ')\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x375,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) ['dog']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = torch.where(dsets.train[0][1]==1.)[0]\n",
"dsets.train.vocab[idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x333,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def splitter(df):\n",
"    train = df.index[~df['is_valid']].tolist()\n",
"    valid = df.index[df['is_valid']].tolist()\n",
"    return train,valid\n",
"\n",
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   splitter=splitter,\n",
"                   get_x=get_x,\n",
"                   get_y=get_y)\n",
"\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   splitter=splitter,\n",
"                   get_x=get_x,\n",
"                   get_y=get_y,\n",
"                   item_tfms = RandomResizedCrop(128, min_scale=0.35))\n",
"dls = dblock.dataloaders(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(nrows=1, ncols=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary cross entropy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 20])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.train.one_batch()\n",
"activs = learn.model(x)\n",
"activs.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2.0258, -1.3543, 1.4640, 1.7754, -1.2820, -5.8053, 3.6130, 0.7193, -4.3683, -2.5001, -2.8373, -1.8037, 2.0122, 0.6189, 1.9729, 0.8999, -2.6769, -0.3829, 1.2212, 1.6073],\n",
" device='cuda:0', grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"activs[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(inputs, targets):\n",
"    # Squash raw activations into probabilities in (0, 1)\n",
"    inputs = inputs.sigmoid()\n",
"    # Use p where the target is 1 and 1-p where it is 0, then average the negative log\n",
"    return -torch.where(targets==1, inputs, 1-inputs).log().mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.0082, device='cuda:5', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_func = nn.BCEWithLogitsLoss()\n",
"loss = loss_func(activs, y)\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Hello Jeremy.', 'Ahoy! Jeremy.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def say_hello(name, say_what=\"Hello\"): return f\"{say_what} {name}.\"\n",
"say_hello('Jeremy'),say_hello('Jeremy', 'Ahoy!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Bonjour Jeremy.', 'Bonjour Sylvain.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = partial(say_hello, say_what=\"Bonjour\")\n",
"f(\"Jeremy\"),f(\"Sylvain\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.903610</td>\n",
" <td>0.659728</td>\n",
" <td>0.263068</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.724266</td>\n",
" <td>0.346332</td>\n",
" <td>0.525458</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.415597</td>\n",
" <td>0.125662</td>\n",
" <td>0.937590</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.254987</td>\n",
" <td>0.116880</td>\n",
" <td>0.945418</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.123872</td>\n",
" <td>0.132634</td>\n",
" <td>0.940179</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.112387</td>\n",
" <td>0.113758</td>\n",
" <td>0.949343</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.092151</td>\n",
" <td>0.104368</td>\n",
" <td>0.951195</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
"learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10436797887086868,0.93057781457901]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.1)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10436797887086868,0.9416930675506592]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.99)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"preds,targs = learn.get_preds()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorMultiCategory(0.9554)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_multi(preds, targs, thresh=0.9, sigmoid=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = torch.linspace(0.05,0.95,29)\n",
"accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]\n",
"plt.plot(xs,accs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Assemble the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.BIWI_HEAD_POSE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#50) [Path('13.obj'),Path('07.obj'),Path('06.obj'),Path('13'),Path('10'),Path('02'),Path('11'),Path('01'),Path('20.obj'),Path('17')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1000) [Path('01/frame_00281_pose.txt'),Path('01/frame_00078_pose.txt'),Path('01/frame_00349_rgb.jpg'),Path('01/frame_00304_pose.txt'),Path('01/frame_00207_pose.txt'),Path('01/frame_00116_rgb.jpg'),Path('01/frame_00084_rgb.jpg'),Path('01/frame_00070_rgb.jpg'),Path('01/frame_00125_pose.txt'),Path('01/frame_00324_rgb.jpg')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/'01').ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('13/frame_00349_pose.txt')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img_files = get_image_files(path)\n",
"def img2pose(x): return Path(f'{str(x)[:-7]}pose.txt')\n",
"img2pose(img_files[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(480, 640)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im = PILImage.create(img_files[0])\n",
"im.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=160x120 at 0x7FA45C869B10>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im.to_thumb(160)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)\n",
"def get_ctr(f):\n",
" ctr = np.genfromtxt(img2pose(f), skip_header=3)\n",
" c1 = ctr[0] * cal[0][0]/ctr[2] + cal[0][2]\n",
" c2 = ctr[1] * cal[1][1]/ctr[2] + cal[1][2]\n",
" return tensor([c1,c2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([384.6370, 259.4787])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_ctr(img_files[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"biwi = DataBlock(\n",
" blocks=(ImageBlock, PointBlock),\n",
" get_items=get_image_files,\n",
" get_y=get_ctr,\n",
" splitter=FuncSplitter(lambda o: o.parent.name=='13'),\n",
" batch_tfms=[*aug_transforms(size=(240,320)), \n",
" Normalize.from_stats(*imagenet_stats)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAckAAAFUCAYAAABPx8fsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZBlS17f9/ll5jl3qaqurn6v3z4bszEwZp9FWBIzaMTMgIURQgoZcODABqzwgnFIHogQAUiyUEgOAtthESySQFhIVngJBwjMogFbYh8JAcMyC8w+89Zeq+ree05m/vzHL/Occ29113tv3pvXENG/iOrquvesmb/8Ld/fkqKq3KW7dJfu0l26S3fpLLk7/QB36S7dpbt0l+7SH1W6qyTv0l26S3fpLt2l29BdJXmX7tJdukt36S7dhu4qybt0l+7SXbpLd+k2dFdJ3qW7dJfu0l26S7ehu0ryLt2lu3SX7tJdug2F876MmlQ0gyrgUVW890BCNaNkck4Epzz16KNomLE3n/PoY09y/epTPPH445yenhKzICI451BVYoyoKv2mo+96Nt2GLiXEKdeuXCPGSIwR5zxN8LSzhhBmNO0c1ySQgGogZVANOJlzsLfk6OgSBwcXcC2AAkrsO4SM9562CbRtIKXEBz/4Af7bv/otfMs3fwuve93r6bqOvu/p+54UBRGIsWOzOeXatSfo+zU5ZXLsQSPg0KzkmNh0PSlGgmtpZg3zeUtohaYNNM2M2WxB27YgAggqmeAdIkJWIAsKJM2QldRvSH1kHTtiSjTtjIfue5BNH0HsvgLU8h1FJv9Xu0edNs0AZPsD1XKeqs1htmf6O9/93Xzrt30biCLOIwJf9/X/qTzvHPcM6Hv/3ncrCLiMiPEOgJMGkYyjRVwuR+eBrwDEO7xmZk3g0es3ufrEx5m3Lb5p2VvM8K4hhBZ1xo92UoOIgnR0XcesXQIewSGSAOMmEETcMPYiYuM9qaJSTSAZlYZ+syZuVoR2DoD3HhGhlzTMgeBxWcr8QM6Tec22vnLOpJSQnMgKWSFlhdSTcwa191dVUrLnxQVUMzElUEEVUkqoQtdtUE0o9n3XRWLMiCghBEhrcnZobBDv6HOPE0fOmRAczgmz2YzQOhpmOOdo25amaYkIIQS897jWkUmIszHDtcbzGTuHtoyZzW+fI6qOb//2b37B+e7v/d2/reJsDIznPKIB5xjkhz2rDrIMGHhTNTFfNPyr/+8Xeerxj+Ik4P2C+SwQghAaD661cfEeEQcooXXklHBuhogjpVy+czRty2K+Tztv2d/fw7mAD0uCDzjnEAchBOMBlwHFeUGc0M4a1us1P//OX+D33/37vOc97+E1r3kNb3nrW/mcz/18Lly4CGRCCIgEkylqyy5nJRVed0UWADgXaJo5znnaZkbbzvAexHmapmW5WJJioo+RPidyghvXTtlsEserDavTFaenK9arFZvVKTF2rK9fw3tQerJLKD3iMsGZ3BYosuwAW2iZ6DpSiuScoFNyNl2k2tN3mZwhxshms6HvNzjn6COIOGJasV71BN+wXt8khBYh8E//+Y/clufOVZKSIs5jAlXCsFi7/gbdpuOxjz3GjSee5KMf+bAtBpmjAtmBaCKlDAjrbk2Mia7ryVnZbNbEGFmdnLLZbBBguVwy3/csL8yYzQ7xLpDSHCcLUk4ojsVij9mi4cEHHubw8JC9/T1TZOsNIhmRRAiepm3wTohxQ86Jv/k3/hZd1/Ed3/GdpJTYbDruvXwf3/d934+IcOPmTZwI165d4+bN66xOrpO1R1NCEJKeQvbEDmIX6eOK4Bva2QXmiwX7hyYA57Mlzntms5kJa+8QBHEO5xxZlay2EFHICE6ULJhApugGF8hOEVEcNm6I6VhFUQHZLW8VKVpRhoVsi7d8XVis/h6m32VyUh546CV2XdHJkXeGtNxb8JA9uADSlW89uF
QMDkXE29gUAwBMcIkTRhOiDp6MSrf8FkDFmVGUM84LTgTI5XhfHkrLOUVoZDVhVj63CcnDuSoBDQ0pRsR5uq5nsdwDBOdyUYRq864BgSLsTDiZAaNFqeTy5nZOigkl4kNA1QTqoFiL8s05ISJ2bFGgms2E6roOM5gSKUdiEprQsNlEvPek2A3jFGM0IaRKzoqkRMomkLMmUi/knDg5OSamRFKIKZpB2UdUM32MOIHQzFCU4APihGZmCjaEQNM0+DAaey88KWTB+4AtQylGjM2r6mSuyzir6mhoYV8//NIX03UrWh9I4pi3M0Dx3qHOEUKDOE/bzBERQuNo28De8hJt2zCbzZnOpxMPTvEeHn/icf71L76T9733vVy9epVXvfpV/Ok//af4zM/8TJJiRnvKaIYuJbzM+TNvehtvefNbydnWh7Sek9MVJ6cbu39oAY9mxfkEBMiePkWcCDH2xN6xWm1IKbI6PSXFjGoixp4sPaoZkYzXhpwj3juStsXA3OC94og4h30v9pzGRo6uE3LK9Jrp+sRm3ZHTmtRHcgIRz7rr7bopIn2PqhKK4WK846FxNE2DKnjvCC7TLmc4cYR2Qc4KMicwI6P06YDlYp8Q5udyxrlK8id/4idYr1ZoKotaARGi2iLpN5GkkGNPjsq6v0bf9fSbnpwiJ6cnBB9oZgcsli3LvRmhDTTzGYjn8GiJ5kDbLpkt9rjv/vu5dOmIh150H8uFI9PR96d4t0ScEoKwnC9AE5ozq9VNVqeOt73tK3n7l34p3/RN34SIsNr0bDY9V64e0206vuGb/ktSznz4o4+yWt3g6tUrrFYrtItAQvMaTYkUIzklurgmpsysmeF94OLRPl1MNIcLZmHBbD4v1nIwq5CilMRtKyId/8pl7EQcOSveCU6LKnK24DyBJBkciKseXU/sOrQKdwVxDpJ5WbtKcFjw9cNqBqIULWgKpSx+SCDwFV/x5cbsbiL87zA5aUEy4iKCHxX/cERRpuJAwAuIelQSSRzee5Q1sIcTUCeoDyTvCQKuKFPTsorLjSlM5zGv0Q93c0L5O6GacMEWoxNXPEpQfHkmU4QxBdS1NGFO10fadk5WRxi8FSn3Mc/EOSG7yd/SInh8CKBq3ptTrj35OM7BvQ8+iOAwwyiRUhyeV2M2ozaaMDNlZ4Zr3/U2+7mgJ11msWhZrW/S9ZGDCw+SUsY5JcdsAl4xnsgCkjEjQknJDQo4q6LZPNaUEhEPouQEOdu4AfRFyLmifOvxvSZyujN8p2prQovXXb1++1IGDz2EMCwrM2bq82b6vuf+B1/M4cH9BBzqIbiG973/ffzar/0q7//DP2Rvb4+Xv/zTeOMb/gQPPvgArhjDOTEgIs6JCXSFPmdy7skaWcyWfNGfehN/6gv/JDFGUlT6vuP33/0eYjbDJ0ZDCVBDElQzaKCantmZIeecH+ZUKPPrbF5FdJBjJmoiqaAWsUvEmNCUyVmZuzk5mzfX44mxt7HTTNM0iAg+BNbdCffef5mUHE1xBMDRttk838YzJ8AiIIczQriEcw7vg6EtwQ+GX5iZ8emc+ZqCrU+ixzlQ0uAoVAPP+YCIp+9XHK8jTeNQnfHEE49z7+WXnssb5yrJ61evkVXpu46UdGDutO7JMXF884RN3NDnRDObs394kSbMOTo6Qrzy0NyBCJmA5hZyw/7eES996au4/757Oby0R7OAZqGI9OV1QdSheDQtkPmCJiQ0wfd8z/fyD/7BD/HOd76Tw4uHNG1L1sBP/NTPcnzzmI9//ApJhZwS627NtWtXuXnzBjeuPkGM5m1qNmWomqHLxNTTdWsUTzufMVu0LJt9Zs2cZtbQtA17ywV9n3BhZjCpk1FoACJ1GKvXYhxW14/DIZpBFM2pwAvFixCQ6jGo4qB4ECAEUAfSF2F+1sMbHcgyeqqmrLUoxUGlOKbqxc71pnQFLlw4pGCxhS36cxnnU03bSmSENM37k5FXJuOhgw9q7+RcUaBgyl
Zk8lPv4QwOFCVnR8rJvMECedXxredU5enEk3V8ThUtsKEMx4SQEbeiaVs4tVCF14bsEyA47xAcTvzkmTzOiR3rgwm
"text/plain": [
"<Figure size 576x432 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls = biwi.dataloaders(path)\n",
"dls.show_batch(max_n=9, figsize=(8,6))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3, 240, 320]), torch.Size([64, 1, 2]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = dls.one_batch()\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0111, 0.1810]], device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18, y_range=(-1,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi-lo) + lo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9b3/8deHsG8JS4BAgIBQFgUDjIjbdbcoXlBrvVgXFC3Vq9XbWrfqbXu1ttQuLm2tUjfqRl1qpXUr7tUqEPYdIksIISQQkgCBhCSf3x8z+ktjAoQZcjKZ9/PxyCNzvud7Jm9lZj5zzvec8zV3R0REEleLoAOIiEiwVAhERBKcCoGISIJTIRARSXAqBCIiCa5l0AEOR/fu3T0jIyPoGCIicWXBggXb3T21dntcFoKMjAyysrKCjiEiElfMbFNd7To0JCKS4GJSCMzsSTMrMLPl9aw3M3vYzLLNbKmZja6xboqZrYv8TIlFHhEROXSx2iN4Ghh/gPXnAoMjP9OAPwCYWVfgx8DxwFjgx2bWJUaZRETkEMSkELj7R0DRAbpMAv7kYZ8BKWaWBnwdmOPuRe6+E5jDgQuKiIjEWGONEfQBNtdYzo201dcuIiKNpLEKgdXR5gdo/+oTmE0zsywzyyosLIxpOBGRRNZYhSAX6FtjOR3IO0D7V7j7DHcPuXsoNfUrp8GKiMhhaqzrCGYDN5rZLMIDwyXuvtXM3gZ+VmOA+BzgzkbKJCLSZFVUVlOwax/bSvdRUFoe/r2rnO+cehTJ7VrF9G/FpBCY2QvAaUB3M8slfCZQKwB3fxR4AzgPyAbKgKsj64rM7F5gfuSp7nH3Aw06i4g0CyV797O5qIzcnXvJ3VnGluK9bC3eR17JXvKK97FjTzm1p4tJamFcMKpPzAuBxePENKFQyHVlsYg0dbv27efzwj1s2L6bDdvL2LB9D5t27CGnqIzisv3/1rd96yT6pLQjLaUdvZPb0iu5Lb06t6Vnclt6dmpLj85t6Nq+NS1a1DW0emjMbIG7h2q3x+UtJkREmpLd5ZWsyS9lTf5u1m7bxbqCXWQX7GZbafmXfVoY9OnSjoxuHZgwIo3+3drTt0t7+nZtT3qXdiS3a4XZ4X/IR0OFQESkAbbvLmf5lhJW5JWyfEsJK7eWsmlH2Zfr27dOYnCPjpw8KJVBPTpyVGoHBqZ2pG/XdrRpmRRg8vqpEIiI1KO8sorlW0pZuGkni3OLWZxTzJbivV+uz+jWnqN7d+bi0ekMS+vMkF6d6JPSLqrDN0FQIRARidhdXsmCTTuZu34H8zYUsXRLCRWV1QD0SWlHZr8UrjoxgxHpyQzv3ZnObWM7aBsUFQIRSVgVldUsytnJJ9nb+Th7O0tyS6iqdpJaGCP6JDPlhP6M6d+F0f270KNT26DjHjEqBCKSUPJL9vHe6gI+WFPAvz7fwe7ySloYjExP4bpTBzJuYDdG9+tChzaJ8/GYOP+lIpKQ3J2123bz9op83lm1jaW5JUD4UM/EzN6c+rVUxg3sFvNz8+OJCoGINDvuzqqtu3h9WR5vLstn/fY9mMGovincNn4IZw3ryeAeHQM7XbOpUSEQkWZjc1EZs5fk8ddFW1hXsJukFsa4gV2ZevIAzjm6Z7M+zh8NFQIRiWtlFZW8uSyflxfk8un6HQAcl9GFey84hvOO6UW3jm0CTtj0qRCISFxamVfK8/M28ddFeewur6R/t/bccvbXuHB0H9K7tA86XlxRIRCRuLG/qpo3l+fz9CcbWJhTTOuWLTh/RBqTx/bjuIwuOuZ/mFQIRKTJKy6r4Lm5Ofzp041sKy0no1t77p4wjIvHpJPSvnXQ8eKeCoGINFlbivfyxD83MGt+DmUVVZwyuDs/v2gEp32tR9zdxqEpUyEQkSZn0449PPL+57yyMBeAicf25tv/MZBhaZ0DTtY8qRCISJORs6OMB99dy2uL80hqYV
w+rj/f/o+B9ElpF3S0Zk2FQEQCt610H799bx2z5m0mqYVx1YkZfOc/BtKjs877bwyxmqpyPPAQkAQ87u7Ta61/ADg9stge6OHuKZF1VcCyyLocd58Yi0wi0vTtLq/k0Q8+5/GP11NZ5Vw6th83njGInioAjSrqQmBmScDvgbOBXGC+mc1295Vf9HH379Xo/11gVI2n2OvumdHmEJH4UVXtzJqfwwNz1rJ9dwUTj+3ND84ZQr9uOv8/CLHYIxgLZLv7egAzmwVMAlbW0/9SwpPbi0gCytpYxI9eW8HKraUcl9GFx6ccR2bflKBjJbRYFII+wOYay7nA8XV1NLP+wADgvRrNbc0sC6gEprv7X+vZdhowDaBfv34xiC0ijalwVzk/f2MVf1m0hbTktvzuW6OYMCJNF4E1AbEoBHX9K3o9fScDL7t7VY22fu6eZ2YDgffMbJm7f/6VJ3SfAcwACIVC9T2/iDQx1dXOi1mb+dkbq9i3v5obTj+KG04fRPvWOlelqYjFv0Qu0LfGcjqQV0/fycANNRvcPS/ye72ZfUB4/OArhUBE4s/6wt3c8coy5m0s4vgBXbnvwhEM6tEx6FhSSywKwXxgsJkNALYQ/rD/Vu1OZjYE6AJ8WqOtC1Dm7uVm1h04Cbg/BplEJEBV1c6TH2/gV/9YQ9tWSdz/jZF8M5Suw0BNVNSFwN0rzexG4G3Cp48+6e4rzOweIMvdZ0e6XgrMcveah3WGAY+ZWTXQgvAYQX2DzCISBzZs38MtLy5mYU4xZw3ryc8uPEbXAzRx9u+fy/EhFAp5VlZW0DFEpAZ354V5m7n37ytp3bIFP5k4nAsy+2gvoAkxswXuHqrdrtEaEYla0Z4Kbn9lKXNWbuOkQd349Tcz6ZWsvYB4oUIgIlGZu34HN81axM49+7l7wjCmnjRAdwaNMyoEInJYqqudRz7I5jdz1tK/WweemHIcx/RJDjqWHAYVAhFpsOKyCm6etZgP1xYy8dje/OyiEXRso4+TeKV/ORFpkOVbSrju2QUUlJZz34XH8K2x/TQgHOdUCETkkP1lYS53/mUZXTu05sXrTtA9gpoJFQIROaiqauf+t1bz2EfrGTewK7/71mi6d2wTdCyJERUCETmgXfv2c/Osxby3uoArxvXnR/85nFZJLYKOJTGkQiAi9dpSvJepT80nu3A39046mitOyAg6khwBKgQiUqflW0qY+vR89lZUMfPqsZw8uHvQkeQIUSEQka94f3UBNzy/kC7tW/PM9cczpFenoCPJEaRCICL/5uUFudz+ylKG9urEU1cdpxvGJQAVAhH50mMffs7P31zNyYO68+gVY3SRWILQv7KI4O5MfzN8euj5I9P49SXH0qZlUtCxpJGoEIgkuOpq5+7XlvP83ByuGNefn0w8miTdNC6hqBCIJLDKqmpufXkpry7awvWnHcVtXx+i20UkIBUCkQRVUVnNTS8s4q0V+dz69SHccPqgoCNJQGJyeaCZjTezNWaWbWZ31LH+KjMrNLPFkZ9ra6ybYmbrIj9TYpFHRA6sorKaG55fyFsr8vnf84erCCS4qPcIzCwJ+D1wNpALzDez2XXMPfxnd7+x1rZdgR8DIcCBBZFtd0abS0TqVl5ZxQ3PLeSdVQX838SjmXJiRtCRJGCx2CMYC2S7+3p3rwBmAZMOcduvA3PcvSjy4T8HGB+DTCJSh4rK6i+LwD2TVAQkLBaFoA+wucZybqSttm+Y2VIze9nM+jZwW8xsmpllmVlWYWFhDGKLJJb9VdV894VwEbj3gmO4UvcNkohYFIK6TjHwWst/AzLcfSTwDjCzAduGG91nuHvI3UOpqamHHVYkEVVVO99/cQlvr9jGj84fzhXj+gcdSZqQWBSCXKBvjeV0IK9mB3ff4e7lkcU/AmMOdVsRiU51tXPHK0v525I8bh8/lKknDwg6kjQxsSgE84HBZjbAzFoDk4HZNTuYWVqNxYnAqsjjt4FzzKyLmXUBzom0iUgMuD
v3/H0lLy3I5aYzB3P9aUcFHUmaoKjPGnL3SjO7kfAHeBLwpLuvMLN7gCx3nw3cZGYTgUqgCLgqsm2Rmd1LuJgA3OP
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(partial(sigmoid_range,lo=-1,hi=1), min=-4, max=4)"
]
},
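{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of `sigmoid_range` (a minimal sketch that assumes only PyTorch; the sample inputs are arbitrary), the cell below confirms that outputs stay between `lo` and `hi` and that an input of 0 maps to the midpoint of the range. The function is repeated so the cell runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# repeated from the definition above so this cell is self-contained\n",
"def sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi-lo) + lo\n",
"\n",
"x = torch.tensor([-100., -4., 0., 4., 100.])  # arbitrary sample inputs\n",
"y = sigmoid_range(x, -1, 1)\n",
"\n",
"# every output stays within [lo, hi], even for very large inputs\n",
"assert -1 <= y.min().item() <= y.max().item() <= 1\n",
"# sigmoid(0) is 0.5, so an input of 0 lands on the midpoint of (-1, 1)\n",
"assert y[2].item() == 0\n",
"y"
]
},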
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FlattenedLoss of MSELoss()"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.loss_func"
]
},
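{
"cell_type": "markdown",
"metadata": {},
"source": [
"`dls.loss_func` reports a flattened mean squared error. To illustrate what that computes, here is a sketch with made-up predictions and targets, using plain PyTorch's `F.mse_loss` rather than fastai's wrapper. The loss is the average of the squared coordinate differences, so flattening the tensors first does not change the value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# made-up values: 2 items, each with 1 point of 2 coordinates (like one row of yb)\n",
"pred = torch.tensor([[[0.0, 0.2]], [[0.1, -0.1]]])\n",
"targ = torch.tensor([[[0.0, 0.1]], [[0.3, -0.1]]])\n",
"\n",
"loss = F.mse_loss(pred, targ)\n",
"# same quantity by hand: mean of the squared differences over all coordinates\n",
"manual = ((pred - targ) ** 2).mean()\n",
"assert torch.isclose(loss, manual)\n",
"# flattening both tensors first gives the same mean\n",
"assert torch.isclose(loss, F.mse_loss(pred.view(-1), targ.view(-1)))\n",
"loss"
]
},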
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEKCAYAAAAFJbKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU1fn48c+TFQIJWwKEJBB2CFuQsLihuOIGtC7gVqm2Lq211dZf5VXbfr/YVqutS/1aW7Wgte4bgqK4gQqyBQlbIBD2kEASEsKa/fn9MTd0iCHJQG5mJnner9e8mHvuPWeeGZJ5cu859xxRVYwxxpjGCvF3AMYYY4KLJQ5jjDE+scRhjDHGJ5Y4jDHG+MQShzHGGJ9Y4jDGGOOTMH8H0BxiY2M1OTnZ32EYY0xQWbVqVaGqxtUubxWJIzk5mfT0dH+HYYwxQUVEdtZVbpeqjDHG+MQShzHGGJ9Y4jDGGOMTSxzGGGN8YonDGGOMTyxxGGOM8YklDmOMaYEOHC3no3V5lFVWNXnbljiMMaYFmrsml7te+Zbs/MNN3rYlDmOMaYHeXpVDSnwMQ3p0aPK2LXEYY0wLk7X3EGtzSrg2LdGV9i1xGGNMC/P2qt2EhwqTUxNcad8ShzHGtCAVVdW8tzqXCwZ1pXO7CFdewxKHMca0IF9tLqDwcBnXjEpy7TUscRhjTAvyVnoOse0jOH/gd2ZDbzKWOIwxpoUoOlLO55v2MSU1gfBQ977eLXEYY0wLMTdjDxVVytWj3BlNVcMShzHGtBBvpucwNCGGwfExrr6OJQ5jjGkB1u8pITPvIFPT3OsUr+Fq4hCRiSKSJSLZIvJAHfuni0iBiGQ4jx855RO8yjJEpFREpjj7XhSR7V77Ut18D8YYEwzeTN9NRFgIk0a4c++GN9fWHBeRUOAZ4GIgB1gpInNVNbPWoW+o6t3eBaq6EEh12ukMZAOfeB1yv6q+7VbsxhgTTEorqpizeg+XDe1Oh6hw11/PzTOOMUC2qm5T1XLgdWDyKbRzDfCRqh5t0uiMMaaF+CRzHwdLK7muGS5TgbuJIwHY7bWd45TVdrWIrBWRt0Wkrnc9DXitVtkfnTpPiEhkXS8uIreLSLqIpBcUFJzSGzDGmGDw5srdJHZqy5l9ujTL67mZOKSOMq21PQ9IVtXhwGfASyc0IBIPDAMWeBXPAAYBo4HOwK/renFVfU5V01Q1LS7OvRthjDHGn3YXHWXJ1kKuHZVESEhdX7tNz83EkQN4n0EkArneB6jqflUtczafB0bVauM64D1VrfCqk6ceZcBsPJfEjDGmVXp7VQ4AV49yv1O8hpuJYyXQX0R6i0gEnktOc70PcM4oakwCNtZq43pqXaaqqSMiAkwB1jdx3MYYExRUlXe+zeGcfrEkdopqttd1bVSVqlaKyN14LjOFArNUdYOIzATSVXUucI+ITAIqgSJgek19EUnGc8byZa2mXxGRODyXwjKAO916D8YYE8i+3VVMTvEx7rt4QLO+rmuJA0BV5wPza5X9zuv5DDx9FnXV3UEdnemqekHTRmmMMcFpzupc2oSHcMmQ7s36unbnuDHGBKGKqmo+XJfHRYO70T7S1XOA77DEYYwxQWjxlkKKjpS7tspffSxxGGNMEJqTsYcObcM5b0Dz325gicMYY4LM0fJKPtmwj8uHxRMR1vxf45Y4jDEmyHyauY9jFVVMSe3hl9e3xGGMMUHm/YxcenRow+jkzn55fUscxhgTRIqOlPPV5gKuGtGj2aYYqc0ShzHGBJEP1+ZSWa1MGdn8o6lqWOIwxpggMicjl4Hdol1fHrY+ljiMMSZI7Np/lFU7i5k80j+d4jUscRhjTJB4P2MPAJNGWOIwxhjTAFVlTsYexiR3btaZcOtiicMYY4LAhtyDbC044tdO8RqWOIwxJgjMWb2H8FDh8mHNOxNuXSxxGGNMgKuqVuauyeX8gV3pGB
Xh73AscRhjTKBbvm0/+YfKmOKHmXDr4mriEJGJIpIlItki8kAd+6eLSIGIZDiPH3ntq/Iqn+tV3ltElovIFhF5w1mW1hhjWqwFG/bSJjyECwZ19XcogIuJQ0RCgWeAy4AU4HoRSanj0DdUNdV5vOBVfsyrfJJX+Z+BJ1S1P1AM3ObWezDGGH9TVT7J3Me5/eNoGxHq73AAd884xgDZqrpNVcuB14HJp9OgiAhwAfC2U/QSMOW0ojTGmAC2fs9B8kpKuSSlm79DOc7NxJEA7PbazqGONcSBq0VkrYi8LSJJXuVtRCRdRJaJSE1y6AIcUNXKBto0xpgW4dPMvYQIXDi4dSSOuqZt1Frb84BkVR0OfIbnDKJGT1VNA24AnhSRvo1s0/PiIrc7iSe9oKDA9+iNMSYAfJK5j7TkznRuFzjduW4mjhzA+wwiEcj1PkBV96tqmbP5PDDKa1+u8+82YBEwEigEOopIzcrs32nTq/5zqpqmqmlxcc2/tKIxxpyunfuPsGnvoYC6TAXuJo6VQH9nFFQEMA2Y632AiMR7bU4CNjrlnUQk0nkeC5wNZKqqAguBa5w6twDvu/gejDHGbz7N3AfAJSn+v+nPW1jDh5waVa0UkbuBBUAoMEtVN4jITCBdVecC94jIJKASKAKmO9UHA/8UkWo8ye0RVc109v0aeF1E/gCsBv7l1nswxhh/+iRzH4O6R9Ozi3/npqrNtcQBoKrzgfm1yn7n9XwGMKOOet8Aw07S5jY8I7aMMabF2n+4jPQdRdw9oZ+/Q/kOu3PcGGMC0Oeb8qlWuGRIYF2mAkscxhgTkD5cm0dCx7YM6eG/lf5OxhKHMcYEmP2Hy1icXchVI3rgue85sFjiMMaYADN/XR5V1crkVP+u9HcyljiMMSbAzF2Ty4Bu7RnUPdrfodTJEocxxgSQPQeOsXJHMZMC9DIVWOIwxpiAMm+NZzKMSSMCdxo+SxzGGBNA3s/IJTWpY8Dd9OfNEocxxgSILfsOsTHvYMB2itewxGGMMQFi7ppcQgSuGB7f8MF+ZInDGGMCxPx1eYzr04Wu0W38HUq9LHEYY0wA2Ln/CFsLjnBxgE2hXhdLHMYYEwC+2JQPwAWDuvo5koZZ4jDGmADwxaZ8+sS1o1eXdv4OpUGWOIwxxs+OlFWyfFsRFwbB2QZY4jDGGL9bnF1IeVU1EyxxGGOMaYwvNuYTHRnG6OTO/g6lUVxNHCIyUUSyRCRbRB6oY/90ESkQkQzn8SOnPFVElorIBhFZKyJTveq8KCLbveqkuvkejDHGTarKwqx8xg+IIzw0OP6Wd23pWBEJBZ4BLgZygJUiMtdr7fAab6jq3bXKjgI/UNUtItIDWCUiC1T1gLP/flV9263YjTGmuWzIPUj+obKguUwF7p5xjAGyVXWbqpYDrwOTG1NRVTer6hbneS6QD8S5FqkxxvjJ5xvzEYHzBwbPV5ybiSMB2O21neOU1Xa1cznqbRFJqr1TRMYAEcBWr+I/OnWeEJHIul5cRG4XkXQRSS8oKDiNt2GMMe75IiufEYkdiW1f51dZQHIzcdQ1kbzW2p4HJKvqcOAz4KUTGhCJB14Gfqiq1U7xDGAQMBroDPy6rhdX1edUNU1V0+LigieTG2NajwNHy1mbc4AJA4PnMhW4mzhyAO8ziEQg1/sAVd2vqmXO5vPAqJp9IhIDfAg8qKrLvOrkqUcZMBvPJTFjjAk66TuKUYVxfYJjNFUNNxPHSqC/iPQWkQhgGjDX+wDnjKLGJGCjUx4BvAf8W1XfqquOeJbGmgKsd+0dGGOMi1buKCIiNIQRSR39HYpPXBtVpaqVInI3sAAIBWap6gYRmQmkq+pc4B4RmQRUAkXAdKf6dcB4oIuI1JRNV9UM4BURicNzKSwDuNOt92CMMW5asaOI4YkdaBMe6u9QfOJa4gBQ1fnA/Fplv/N6PgNPn0Xtev8B/nOSNi9o4jCNMabZHSuvYl1OCT8e38ffofgsOO42McaYFm
b17mIqq5UxQXK3uDdLHMYY4wcrtxcjAmf06uTvUHxmicMYY/xg5Y4iBnWPoUPbcH+H4jNLHMYY08wqqqr5dlcxY5K
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.045840</td>\n",
" <td>0.012957</td>\n",
" <td>00:36</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.006369</td>\n",
" <td>0.001853</td>\n",
" <td>00:36</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.003000</td>\n",
" <td>0.000496</td>\n",
" <td>00:37</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.001963</td>\n",
" <td>0.000360</td>\n",
" <td>00:37</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.001584</td>\n",
" <td>0.000116</td>\n",
" <td>00:36</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lr = 2e-2\n",
"learn.fit_one_cycle(5, lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.01"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.sqrt(0.0001)"
]
},
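{
"cell_type": "markdown",
"metadata": {},
"source": [
"Taking the square root of the MSE gives a root mean squared error in the same units as the targets. Assuming that the point coordinates are scaled to the range -1 to 1 along each image axis (an assumption suggested by `y_range=(-1,1)`, not something shown in this notebook), and using the 240x320 image size from `xb.shape`, a rough conversion to pixels might look like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"valid_loss = 0.0001  # rounded final validation MSE from the training table (0.000116)\n",
"rmse = math.sqrt(valid_loss)  # error in normalized coordinate units\n",
"\n",
"height, width = 240, 320  # image size taken from xb.shape\n",
"# assumption: each axis spans 2 normalized units (-1 to 1), so 1 unit is half the axis length\n",
"rmse_px_x = rmse * width / 2\n",
"rmse_px_y = rmse * height / 2\n",
"rmse_px_x, rmse_px_y"
]
},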
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAGzCAYAAACMxsRFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9ebBty13f9/l1r7WHM9x77n3z03tPM3pIQkxi1ICEBHJkEw9F7LIpk0o5phyIXR6CIdhxAimDHcr8kcIhxnZiOxAbbIwDxiAgGBmBZFkS6AkJ0Cw97hvvcO49w957re7fL39091pr77PPuc/E7x5VsX919z17r9Wrx2+v/k39azEzNrShDW1oQ3eG3HlXYEMb2tCGfj/R5qW7oQ1taEN3kDYv3Q1taEMbuoO0eeluaEMb2tAdpM1Ld0Mb2tCG7iBtXrob2tCGNnQHqTrrpprZjf0bXNq7BAIy9C4zQJZS54uFZCmB4dDBHScM0isQlws3P8gq5LRutdCct6FLZfuVBFVujxKjYQZN02BmNPOIGrRtIMZADNC2LapKjBEk0jQNbdtipph6zIwQAqaKV0UtIAINDVEDqiFVWx2qhqqmT2tghlokhJYYI6qKc44QWm7eeIrpdEqMkRgNoaZtW0IIhHCMWUpbngNQVcwMEUFEiDEiAuLafD9iBiJgFhHxmEHbRmKcEzUiYqn7zdOa461vfV3ujwjmEKm68gBijJhZul+62AwnFeAwi0t1K99VFRHBzBiNJhwdHfMt3/YXTg7oHaANtjfYPi9sn/nSDRbYvrBDsIAgeDyKJChkYAmSQBt81wFSirL+exOUGAOj0YgQIwtxeA+VA3AEdd2zIcB2qUMAJyNUIUbQmPJtmvS9aWHRNjQx0DQNi8UiD2zsBnW+OCTGyGI+T7+bfcwSYNp52w1oCAENqSNFJIOzAQlYwhTiEkBjVEIIVEgGDETX5HYrhuLEY2rdwNUORBRQRKANEe99KlcVT8XN69YNZpqQ+SttB8D0l76eMaZp6xwOiCEi1B0gAIK2CDUhHANGCEbUhhBagi6QYLStcCsEnPOIRLyrMbM8rpK/G96nySnOJTSY5UmbXkaqedwTVFJ7BMQJmAMz5vN5+n5OtMH2Btvnhe0zX7q/9aF9RqMRN2/epGka9o9u0rQtx/MZKqDNolsZfBNwHSJBqAghdJ1XeYGoHQgqV2MExLXE2CKT4y5toimQVmZE83hLXtVaDEU1plUQwTkjxrT6RF3kDkrPxTZ1ove+u5ZWd6NK/ZRXTcO0wUw7MIl4VGdoFCxzKGVlq6qKEPOzISKmoAmUoGD9qmlm3Fr0z8/nxxwcHbK3t8d4PKZtW2Qw6CCIVF16h+Cco6rKkKW+rCqPmVDVQgxGVVXUTrp+NGeYRWpfIYyweoRa4oQ0bqGqBFtQKYSoTLwkTszq9NJB6Dk9Q8Qw8sRQy8zZMmvoRFIaJE1Q51BTQDJIBbEqcSHnRBtsb7B9Xtg+86X7q7/6s4Qwx2WJptJUsagtVeVom9QRMUbavFKpal7henY7hICIomp5FU6dHULoxRBpu+dDCEk8ahpmsxmEKq243kACTqouLcAiizxtG9nZ2cHFgPiI4DF1mPPdSpoGXqnqBHJRwznXDaSvJxiZ5UBxrkoAMYdzDostIo66GuHM4Z2heJxPAPfeZVHHOnHJe5+eNUmDIUaMgRc8eA9PPf0s4scYQp2B6L1PICDgXZXqUAY911WqFnCd2FNEUJGY+9136c0MwWEELHMVph4NEdRYWMRZSxtaqCY4gcpZ4jLMEXPpPSeSy3ICmVsoYq+IoM4l8QswVyXYJraBnBgTBVkRu+8gbbC9wfZ5YfvMl+773v3LOG9cvHiRa9eusbN9AXFpdSurksur7LgeYWaMx+NcwTQYQE4bqKoKM4+TmrEL7GzXqBYd1ahrWOmAbvVz/W8QmqahqqpOZ+TqEWrG0d
EtxpMxI8lA0ZxXBnLJ11dpNXMehHrQ4YKawzkQB2kVdGCCcxUhtDifJlxpm4ghUoOllW+4wqmGDkhmhsMjztCsGxtVwkOP7IJUVNUI09CBGBMiHiMiEnADXZ6ZUUmvh8KEYGnFdj6LXREWiwXTaeKqNIuKURvACK2BGu2iQUwgtogLROqliZolbLqVfKDPTGs+g/oWLka6CZF0YIaZ5HspjXMlv/OhDbY32D4vbJ/50n300VdS1zX33H0vN27cYjSeYypUVQKR99INeO18xxnUdZ1XPun0JN4PDA8maJxnoCbAWe6IMuDOpdU3KagN53xWpEserL7sw9kx3jmuXxcuXbrEdOxxMkrAE0OpsuJb6UUKV8Z0acCjOYpeSpyxmLeMRmOSEt0w0W5CmGleuQVTwWi7QeuUPuWXCN4SRJpmjlGBGZXzmNRUdZ3WdDOKSOOKc8lAtC05YyOwiGKIKI4aI6KmqBm1rxiN0jhFszS5AXGjpItzgRBaDEnGj2bOoom4UZVX+KFYNWhL7nPVXj9omsqUlXqmPELua4fksj4XaIPtDbZPtOUOYfvMl+49L9imbSOjHdh1NTv1NBkQqHNhoa9MbkzlfXrTq1taNSs/oaocSMOtg2vsXrqHe+65h6qqUhqp8d53YKuqqnt+Wk0A2N4dIU4Zj3pRRc14/PGP8+EPf4QXPfxCjo7mvOWtbwKtEGeIRGIWR4pF0izVTxXMBYJGFotF1lEBWacWteXatX3GozEiaWK0UTqxMVmADdU2r7JjYjSceGK0ztKLJeuyhYBzNSoNpkrlXdL/+BpxPg8eeQIZlbSJyzCf7hSci2HiMRHEYgKOI1mfSZxY6buyKkcRjCQGmijke0FAbYFakya+TTpASrEY5HL7FT7/Fpf1WC7XXLBhPSHrv8jXXNaJSTZErTXu3hHaYHuD7fPC9pkv3W9429v56Z/+Gf7w29/Oj//Yv+DN//nb2Nvb63RbWDUAZpsXjYFZN1csVSnpctRaQphT15c6sWeJlvXWJ3+vkDO4eO+L+PoXvZSPfPAZ3vDmz6OayFLexY5Y57/F3UNVmS8i3ozKmrTKV0nE1BAwDYhbMJrsJLHGRXCxM1qUPFRrYhxjGqgkXXPeUCliWtb/+ZbKjCieOPGIGuDyZPY4TX1XLKniUo2HrildmyQbR7Jog0ScA9MRkQjeYRYw59NarJIU/kJyA0JoM3g9dbrmQKoakwoloGhnTEiiU+YArMa7Uh/QzMFYGX8cgmXj004Gb2qTIVn95VlynbrDtMH2Btvnhe0zX7oX9+7jT3/zn0Ej/Mk/9c34SRKJKl91mOnxU3XyTFe/4f3coaqC867r8CVgFhCu/j2LBIQpVT3i4qUZro6dov006kU7oa7r7nux3MYYu5W0qkaMx8W/MHaGkGK1TTqsNFFN3dI98F2eZkbtayqFuXiqyhMWDZWv8b5Kz0clC0oUV5TST0WUXe4TG7wDHIZhrp+UQ26AgS5KnCDWi7CrfTP8C8l9qUwOAVT6KiRDRhpvl8U8ioiqRpbUuheUSOZ4Sr3OiTbY3mAbzgfbZ750q9EYM/B1WQlcViwzUEJnWuIATqhqOhA7X+VibanxSx1+O25geM3AVTUmDqlqKIpyOz09QjcoS3UoSbIuznvfWWiH4CjXgOToXXRBaAfSocN0J7pI0mQ55xFfY87wPhsazPCuBwBm2eDRcwFJX1Vuu6UxMCu+oEqMLVDjfbYGQxK7lgbDurZ2LjjZYi7ik1+iSQZWMhKUqokV1xrpdYC5fsn4IP04ZlHL5SlXSo4hZk7mfGiD7Q22zwvbZ75029gmgwArQFsDoBNFDFd04LY+mevqOHj+1LQCXsh+gBWxV9GfmXfnYsNQvLETQJ1MJgP/wkSFm0jPjHodGGQLroIo4l0HaDMDjZgk55adeoS0Aec9SNH3BYr+UIm4gkznIPaiHwBeEIOoARC8FJQmfSTZh7CISELEssgkCM5qnAPnlCABqWu8Rs
ajasAdJQOBuMKR5B1CLmIKnUEJh1rqycQVlF0+gvcQQkTEkUTwrPdyltiKc6INtjfYPi9sn/nS7V1HZFWddftVevj
"text/plain": [
"<Figure size 432x576 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.show_results(ds_idx=1, max_n=3, figsize=(6,8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. how could multi-label classification improve the usability of the bear classifier?\n",
"1. How do we encode the dependent variable in a multi-label classification problem?\n",
"1. How do you access the rows and columns of a DataFrame as if it was a matrix?\n",
"1. How do you get a column by name from a DataFrame?\n",
"1. What is the difference between a dataset and DataLoader?\n",
"1. What does a Datasets object normally contain?\n",
"1. What does a DataLoaders object normally contain?\n",
"1. What does lambda do in Python?\n",
"1. What are the methods to customise how the independent and dependent variables are created with the data block API?\n",
"1. Why is softmax not an appropriate output activation function when using a one hot encoded target?\n",
"1. Why is nll_loss not an appropriate loss function when using a one hot encoded target?\n",
"1. What is the difference between `nn.BCELoss` and `nn.BCEWithLogitsLoss`?\n",
"1. Why can't we use regular accuracy in a multi-label problem?\n",
"1. When is it okay to tune an hyper-parameter on the validation set?\n",
"1. How is `y_range` implemented in fastai? (See if you can implement it yourself and test it without peaking!)\n",
"1. What is a regression problem? What loss function should you use for such a problem?\n",
"1. What do you need to do to make sure the fastai library applies the same data augmentation to your inputs images and your target point coordinates?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Read a tutorial about pandas DataFrames and experiment with a few methods that look interesting to you. Have a look at the book website for recommended tutorials.\n",
"1. Retrain the bear classifier using multi-label classification. See if you can make it work effectively with images that don't contain any bears, including showing that information in the web application. Try an image with two different kinds of bears. Check whether the accuracy on the single label dataset is impacted using multi-label classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}