{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from fastbook import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other Computer Vision Problems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Label Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *\n",
"path = untar_data(URLs.PASCAL_2007)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fname</th>\n",
" <th>labels</th>\n",
" <th>is_valid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>000005.jpg</td>\n",
" <td>chair</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>000007.jpg</td>\n",
" <td>car</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>000009.jpg</td>\n",
" <td>horse person</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>000012.jpg</td>\n",
" <td>car</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>000016.jpg</td>\n",
" <td>bicycle</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fname labels is_valid\n",
"0 000005.jpg chair True\n",
"1 000007.jpg car True\n",
"2 000009.jpg horse person True\n",
"3 000012.jpg car False\n",
"4 000016.jpg bicycle True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'train.csv')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: Pandas and DataFrames"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[:,0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fname 000005.jpg\n",
"labels chair\n",
"is_valid True\n",
"Name: 0, dtype: object"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[0,:]\n",
"# Trailing :s are always optional (in numpy, pytorch, pandas, etc.),\n",
"# so this is equivalent:\n",
"df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['fname']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b\n",
"0 1 3\n",
"1 2 4"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df = pd.DataFrame({'a':[1,2], 'b':[3,4]})\n",
"tmp_df"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" <th>c</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b c\n",
"0 1 3 4\n",
"1 2 4 6"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df['c'] = tmp_df['a']+tmp_df['b']\n",
"tmp_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constructing a DataBlock"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"dsets = dblock.datasets(df)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4009, 1002)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dsets.train),len(dsets.valid)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(fname 008663.jpg\n",
" labels car person\n",
" is_valid False\n",
" Name: 4346, dtype: object,\n",
" fname 008663.jpg\n",
" labels car person\n",
" is_valid False\n",
" Name: 4346, dtype: object)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dsets.train[0]\n",
"x,y"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'008663.jpg'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x['fname']"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('005620.jpg', 'aeroplane')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(get_x = lambda r: r['fname'], get_y = lambda r: r['labels'])\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
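{
"cell_type": "markdown",
"metadata": {},
"source": [
"A likely reason the next cell replaces the lambdas with named functions is serialization: the standard `pickle` module saves a function by its importable name, and a lambda has none, so it cannot be pickled. That matters if you later export a `Learner` built from this `DataBlock`. The sketch below shows only the plain Python behavior; `get_fname`, `get_fname_lambda` and `can_pickle` are hypothetical names, not fastai functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Named, module-level function: pickle stores a reference to it by name.\n",
"def get_fname(r): return r['fname']\n",
"\n",
"# The same logic as a lambda: its name is '<lambda>', which pickle cannot look up.\n",
"get_fname_lambda = lambda r: r['fname']\n",
"\n",
"def can_pickle(obj):\n",
"    try:\n",
"        pickle.dumps(obj)\n",
"        return True\n",
"    except Exception:  # pickle raises when it cannot serialize the object\n",
"        return False\n",
"\n",
"named_ok = can_pickle(get_fname)\n",
"lambda_ok = can_pickle(get_fname_lambda)\n",
"print(named_ok, lambda_ok)  # when run as a script or in a notebook: True False"
]
},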
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('002549.jpg', 'tvmonitor')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return r['fname']\n",
"def get_y(r): return r['labels']\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Path('/home/jhoward/.fastai/data/pascal_2007/train/002844.jpg'), ['train'])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return path/'train'/r['fname']\n",
"def get_y(r): return r['labels'].split(' ')\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x375,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
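{
"cell_type": "markdown",
"metadata": {},
"source": [
"`MultiCategoryBlock` turns each list of labels into a vector with one entry per category in the vocab: 1 where the label is present, 0 elsewhere (one-hot encoding). That is why `dsets.train[0]` now shows a 20-element `TensorMultiCategory`. A minimal pure-Python sketch of the idea, using the label strings from `df.head()`; the alphabetical vocab and the helpers `build_vocab` and `one_hot` are assumptions of this sketch, not fastai's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(label_strings):\n",
"    # Every distinct label across the space-separated label strings, sorted.\n",
"    return sorted({lbl for s in label_strings for lbl in s.split(' ')})\n",
"\n",
"def one_hot(labels, vocab):\n",
"    # 1.0 where the category is present, 0.0 elsewhere.\n",
"    return [1.0 if v in labels else 0.0 for v in vocab]\n",
"\n",
"rows = ['chair', 'car', 'horse person', 'car', 'bicycle']  # labels column of df.head()\n",
"vocab = build_vocab(rows)\n",
"encoded = one_hot('horse person'.split(' '), vocab)\n",
"print(vocab)    # ['bicycle', 'car', 'chair', 'horse', 'person']\n",
"print(encoded)  # [0.0, 0.0, 0.0, 1.0, 1.0]"
]
},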
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) ['dog']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = torch.where(dsets.train[0][1]==1.)[0]\n",
"dsets.train.vocab[idxs]"
]
},
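{
"cell_type": "markdown",
"metadata": {},
"source": [
"Decoding goes the other way: `torch.where(t == 1.)[0]` returns the indices of the ones in a target `t`, and indexing the vocab with them recovers the label names. The same round trip in plain Python, with a hypothetical five-category vocab:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vocab = ['bicycle', 'car', 'chair', 'horse', 'person']  # hypothetical small vocab\n",
"target = [0.0, 0.0, 0.0, 1.0, 1.0]\n",
"\n",
"# Indices of the ones, like torch.where(target == 1.)[0]\n",
"idxs = [i for i, t in enumerate(target) if t == 1.0]\n",
"labels = [vocab[i] for i in idxs]\n",
"print(idxs, labels)  # [3, 4] ['horse', 'person']"
]
},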
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x333,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def splitter(df):\n",
" train = df.index[~df['is_valid']].tolist()\n",
" valid = df.index[df['is_valid']].tolist()\n",
" return train,valid\n",
"\n",
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" splitter=splitter,\n",
" get_x=get_x, \n",
" get_y=get_y)\n",
"\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
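{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `splitter` above returns two lists of row indices chosen by the boolean `is_valid` column, so the validation set is the one defined by the dataset rather than a random split. The same logic without pandas, over the five rows shown by `df.head()`; `split_by_flag` is a hypothetical helper:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rows = [\n",
"    {'fname': '000005.jpg', 'is_valid': True},\n",
"    {'fname': '000007.jpg', 'is_valid': True},\n",
"    {'fname': '000009.jpg', 'is_valid': True},\n",
"    {'fname': '000012.jpg', 'is_valid': False},\n",
"    {'fname': '000016.jpg', 'is_valid': True},\n",
"]\n",
"\n",
"def split_by_flag(rows):\n",
"    # Rows with is_valid False go to training, the rest to validation.\n",
"    train = [i for i, r in enumerate(rows) if not r['is_valid']]\n",
"    valid = [i for i, r in enumerate(rows) if r['is_valid']]\n",
"    return train, valid\n",
"\n",
"train_idx, valid_idx = split_by_flag(rows)\n",
"print(train_idx, valid_idx)  # [3] [0, 1, 2, 4]"
]
},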
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" splitter=splitter,\n",
" get_x=get_x, \n",
" get_y=get_y,\n",
" item_tfms = RandomResizedCrop(128, min_scale=0.35))\n",
"dls = dblock.dataloaders(df)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(nrows=1, ncols=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary Cross-Entropy"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 20])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = to_cpu(dls.train.one_batch())\n",
"activs = learn.model(x)\n",
"activs.shape"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0.7476, -1.1988, 4.5421, -1.5915, -0.6749, 0.0343, -2.4930, -0.8330, -0.3817, -1.4876, -0.1683, 2.1547, -3.4151, -1.1743, 0.1530, -1.6801, -2.3067, 0.7063, -1.3358, -0.3715],\n",
" grad_fn=<SelectBackward>)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"activs[0]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(inputs, targets):\n",
" inputs = inputs.sigmoid()\n",
" return -torch.where(targets==1, inputs, 1-inputs).log().mean()"
]
},
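{
"cell_type": "markdown",
"metadata": {},
"source": [
"`binary_cross_entropy` applies a sigmoid to every activation independently, takes `-log` of the probability given to the correct side (`p` where the target is 1, `1-p` where it is 0), and averages. `nn.BCEWithLogitsLoss` computes the same value directly from the logits. The sketch below compares the notebook's formulation with the standard numerically stable rewrite `max(x, 0) - x*y + log(1 + exp(-|x|))` in plain Python. The logits are the first four values of `activs[0]` above, the targets are made up, and this is not PyTorch's actual code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + math.exp(-x))\n",
"\n",
"def bce_naive(logits, targets):\n",
"    # Mirrors binary_cross_entropy above: sigmoid, pick p or 1-p, -log, mean.\n",
"    losses = [-math.log(sigmoid(x) if y == 1 else 1 - sigmoid(x))\n",
"              for x, y in zip(logits, targets)]\n",
"    return sum(losses) / len(losses)\n",
"\n",
"def bce_with_logits(logits, targets):\n",
"    # Stable per-element form: max(x, 0) - x*y + log(1 + exp(-|x|))\n",
"    losses = [max(x, 0) - x * y + math.log1p(math.exp(-abs(x)))\n",
"              for x, y in zip(logits, targets)]\n",
"    return sum(losses) / len(losses)\n",
"\n",
"logits = [0.7476, -1.1988, 4.5421, -1.5915]  # first four values of activs[0]\n",
"targets = [0.0, 0.0, 1.0, 0.0]               # made-up targets\n",
"print(bce_naive(logits, targets), bce_with_logits(logits, targets))  # equal up to rounding\n",
"\n",
"# For a large logit, 1 - sigmoid(40.0) rounds to 0.0, so bce_naive would call\n",
"# math.log(0.0) and raise ValueError; the stable form still returns about 40.0.\n",
"print(bce_with_logits([40.0], [0.0]))"
]
},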
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.0342, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_func = nn.BCEWithLogitsLoss()\n",
"loss = loss_func(activs, y)\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Hello Jeremy.', 'Ahoy! Jeremy.')"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def say_hello(name, say_what=\"Hello\"): return f\"{say_what} {name}.\"\n",
"say_hello('Jeremy'),say_hello('Jeremy', 'Ahoy!')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Bonjour Jeremy.', 'Bonjour Sylvain.')"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = partial(say_hello, say_what=\"Bonjour\")\n",
"f(\"Jeremy\"),f(\"Sylvain\")"
]
},
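{
"cell_type": "markdown",
"metadata": {},
"source": [
"The standard library's `functools.partial` freezes some arguments of a function. Two details worth knowing, shown below with a copy of `say_hello`: the frozen pieces can be inspected through `.func` and `.keywords`, and a keyword passed at call time still overrides the frozen one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"def say_hello(name, say_what='Hello'): return f'{say_what} {name}.'\n",
"\n",
"f = partial(say_hello, say_what='Bonjour')\n",
"print(f.func.__name__, f.keywords)  # say_hello {'say_what': 'Bonjour'}\n",
"print(f('Sylvain', say_what='Hi'))  # Hi Sylvain. (call-time keyword wins)"
]
},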
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.942663</td>\n",
" <td>0.703737</td>\n",
" <td>0.233307</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.823155</td>\n",
" <td>0.555462</td>\n",
" <td>0.298347</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.606124</td>\n",
" <td>0.202830</td>\n",
" <td>0.815060</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.360787</td>\n",
" <td>0.123490</td>\n",
" <td>0.942052</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.134407</td>\n",
" <td>0.118581</td>\n",
" <td>0.949661</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.117051</td>\n",
" <td>0.104169</td>\n",
" <td>0.950657</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.097517</td>\n",
" <td>0.101461</td>\n",
" <td>0.952789</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
"learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)"
]
},
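{
"cell_type": "markdown",
"metadata": {},
"source": [
"`accuracy_multi` thresholds each of the 20 per-class predictions, compares it with the matching target, and averages over every (image, class) pair; with most classes absent from most images, correctly predicted absences make up much of the score. A plain-Python sketch of that computation, assuming the signature `accuracy_multi(inp, targ, thresh=0.5, sigmoid=True)` that the calls in this notebook suggest; `accuracy_multi_sketch` and the toy probabilities are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def accuracy_multi_sketch(inp, targ, thresh=0.5, sigmoid=True):\n",
"    # Flatten so every (image, class) pair counts as one prediction.\n",
"    flat_inp = [x for row in inp for x in row]\n",
"    flat_targ = [t for row in targ for t in row]\n",
"    if sigmoid:\n",
"        flat_inp = [1 / (1 + math.exp(-x)) for x in flat_inp]\n",
"    correct = [(p > thresh) == (t == 1) for p, t in zip(flat_inp, flat_targ)]\n",
"    return sum(correct) / len(correct)\n",
"\n",
"# Two images, four classes; these are already probabilities, hence sigmoid=False.\n",
"probs = [[0.9, 0.1, 0.15, 0.05],\n",
"         [0.2, 0.6, 0.1, 0.4]]\n",
"targs = [[1, 0, 0, 0],\n",
"         [0, 1, 0, 1]]\n",
"for t in (0.2, 0.5):\n",
"    print(t, accuracy_multi_sketch(probs, targs, thresh=t, sigmoid=False))  # 1.0, then 0.875"
]
},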
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10146083682775497,0.9298606514930725]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.1)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10146083682775497,0.943486213684082]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.99)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"preds,targs = learn.get_preds()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.9575)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_multi(preds, targs, thresh=0.9, sigmoid=False)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = torch.linspace(0.05,0.95,29)\n",
"accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]\n",
"plt.plot(xs,accs);"
]
},
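{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plot evaluates 29 evenly spaced thresholds between 0.05 and 0.95. Picking the best one is an argmax over that sweep. The sketch below does it in plain Python on toy probabilities; `linspace` (standing in for `torch.linspace`) and `acc_at` are hypothetical helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def linspace(start, stop, num):\n",
"    # Evenly spaced values including both endpoints, like torch.linspace.\n",
"    step = (stop - start) / (num - 1)\n",
"    return [start + i * step for i in range(num)]\n",
"\n",
"def acc_at(probs, targs, thresh):\n",
"    # Element-wise accuracy over every (image, class) pair.\n",
"    pairs = [(p, t) for prow, trow in zip(probs, targs) for p, t in zip(prow, trow)]\n",
"    return sum((p > thresh) == (t == 1) for p, t in pairs) / len(pairs)\n",
"\n",
"probs = [[0.9, 0.1, 0.15, 0.05],\n",
"         [0.2, 0.6, 0.1, 0.4]]\n",
"targs = [[1, 0, 0, 0],\n",
"         [0, 1, 0, 1]]\n",
"\n",
"xs = linspace(0.05, 0.95, 29)\n",
"accs = [acc_at(probs, targs, t) for t in xs]\n",
"best = max(range(len(xs)), key=lambda i: accs[i])  # first index with the top accuracy\n",
"print(round(xs[best], 4), accs[best])"
]
},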
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Assemble the Data"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.BIWI_HEAD_POSE)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#50) [Path('01'),Path('01.obj'),Path('02'),Path('02.obj'),Path('03'),Path('03.obj'),Path('04'),Path('04.obj'),Path('05'),Path('05.obj')...]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls().sorted()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1000) [Path('01/depth.cal'),Path('01/frame_00003_pose.txt'),Path('01/frame_00003_rgb.jpg'),Path('01/frame_00004_pose.txt'),Path('01/frame_00004_rgb.jpg'),Path('01/frame_00005_pose.txt'),Path('01/frame_00005_rgb.jpg'),Path('01/frame_00006_pose.txt'),Path('01/frame_00006_rgb.jpg'),Path('01/frame_00007_pose.txt')...]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/'01').ls().sorted()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('13/frame_00349_pose.txt')"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img_files = get_image_files(path)\n",
"# Swap the trailing 'rgb.jpg' (7 characters) for 'pose.txt'\n",
"def img2pose(x): return Path(f'{str(x)[:-7]}pose.txt')\n",
"img2pose(img_files[0])"
]
},
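{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `[:-7]` slice in `img2pose` works because `'rgb.jpg'` is exactly 7 characters long. Removing those characters leaves the `frame_XXXXX_` prefix that an image and its pose file share. The cell below is a quick plain-Python check. It uses the image file name implied by the pose path printed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# File name implied by the pose path printed above\n",
"fname = '13/frame_00349_rgb.jpg'\n",
"assert len('rgb.jpg') == 7\n",
"# Dropping the last 7 characters leaves the prefix shared with the pose file\n",
"stem = fname[:-7]\n",
"print(stem)\n",
"print(stem + 'pose.txt')"
]
},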
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(480, 640)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im = PILImage.create(img_files[0])\n",
"im.shape"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=160x120 at 0x7F51A2492390>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im.to_thumb(160)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
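"# The first 3 rows of rgb.cal hold the camera's intrinsic matrix\n",
"cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)\n",
"def get_ctr(f):\n",
"    # After the 3x3 rotation matrix, the pose file stores the 3D head center\n",
"    ctr = np.genfromtxt(img2pose(f), skip_header=3)\n",
"    # Pinhole projection to pixels: u = x*fx/z + cx, v = y*fy/z + cy\n",
"    c1 = ctr[0] * cal[0][0]/ctr[2] + cal[0][2]\n",
"    c2 = ctr[1] * cal[1][1]/ctr[2] + cal[1][2]\n",
"    return tensor([c1,c2])"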
]
},
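{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_ctr` applies the standard pinhole camera projection. A 3D point (x, y, z) lands at pixel (x*fx/z + cx, y*fy/z + cy), where fx, fy, cx and cy come from the intrinsic matrix. The cell below is a minimal NumPy sketch of the same arithmetic. It also shows the equivalent matrix form: multiply by the intrinsic matrix, then divide by the depth. The matrix `K` and the point `ctr` are hypothetical values, not read from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical intrinsic matrix: fx, fy on the diagonal, cx, cy in the last column\n",
"K = np.array([[517.0,   0.0, 320.0],\n",
"              [  0.0, 517.0, 240.0],\n",
"              [  0.0,   0.0,   1.0]])\n",
"# Hypothetical 3D head center (x, y, z) in camera coordinates\n",
"ctr = np.array([50.0, 20.0, 1000.0])\n",
"\n",
"# Same arithmetic as get_ctr: u = x*fx/z + cx, v = y*fy/z + cy\n",
"u = ctr[0] * K[0][0] / ctr[2] + K[0][2]\n",
"v = ctr[1] * K[1][1] / ctr[2] + K[1][2]\n",
"# Matrix form: project with K, then divide by the depth component\n",
"p = K @ ctr\n",
"assert np.allclose([u, v], p[:2] / p[2])\n",
"print(u, v)"
]
},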
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([384.6370, 259.4787])"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_ctr(img_files[0])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"biwi = DataBlock(\n",
" blocks=(ImageBlock, PointBlock),\n",
" get_items=get_image_files,\n",
" get_y=get_ctr,\n",
" splitter=FuncSplitter(lambda o: o.parent.name=='13'),\n",
" batch_tfms=[*aug_transforms(size=(240,320)), \n",
" Normalize.from_stats(*imagenet_stats)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls = biwi.dataloaders(path)\n",
"dls.show_batch(max_n=9, figsize=(8,6))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3, 240, 320]), torch.Size([64, 1, 2]))"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = dls.one_batch()\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.3375, 0.2193]], device='cuda:5')"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb[0]"
]
},
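{
"cell_type": "markdown",
"metadata": {},
"source": [
"The target printed above is not in pixels. fastai rescales point coordinates so that the image spans [-1, 1] along each axis. The sketch below assumes the mapping `coord * 2 / size - 1`, which is our reading of fastai's point scaling, so treat it as an assumption. It applies that mapping to the center returned by `get_ctr` earlier, using the original 640x480 frame size. The helpers `to_unit` and `to_pixels` are hypothetical, not fastai functions. Because `yb` comes from an augmented batch, its values will not match this un-augmented example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_unit(pt, size):\n",
"    # Map a pixel coordinate (x, y) to [-1, 1]; size is (width, height)\n",
"    return tuple(c * 2 / s - 1 for c, s in zip(pt, size))\n",
"\n",
"def to_pixels(pt, size):\n",
"    # Inverse mapping from [-1, 1] back to pixels\n",
"    return tuple((c + 1) * s / 2 for c, s in zip(pt, size))\n",
"\n",
"size = (640, 480)  # width, height of the original frames (im.shape is (480, 640))\n",
"assert to_unit((320, 240), size) == (0.0, 0.0)   # image center\n",
"assert to_unit((0, 480), size) == (-1.0, 1.0)    # bottom-left corner\n",
"ctr = to_unit((384.6370, 259.4787), size)\n",
"print(ctr, to_pixels(ctr, size))"
]
},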
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a Model"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18, y_range=(-1,1))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi-lo) + lo"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(partial(sigmoid_range,lo=-1,hi=1), min=-4, max=4)"
]
},
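{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sigmoid_range` squeezes any activation into the open interval (lo, hi). That is why `y_range=(-1,1)` suits targets scaled to [-1, 1]. For lo=-1 and hi=1 the function reduces to tanh(x/2), because 2*sigmoid(x) - 1 = tanh(x/2). The cell below checks that identity in plain Python. The name `sigmoid_range_py` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def sigmoid_range_py(x, lo, hi):\n",
"    # Single-float version of sigmoid_range defined above\n",
"    return 1 / (1 + math.exp(-x)) * (hi - lo) + lo\n",
"\n",
"for x in [-10.0, -1.0, 0.0, 1.0, 10.0]:\n",
"    y = sigmoid_range_py(x, -1, 1)\n",
"    # The output stays inside (-1, 1) and matches tanh(x/2)\n",
"    assert -1 < y < 1 and math.isclose(y, math.tanh(x / 2), abs_tol=1e-12)\n",
"    print(x, round(y, 4))"
]
},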
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FlattenedLoss of MSELoss()"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.loss_func"
]
},
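{
"cell_type": "markdown",
"metadata": {},
"source": [
"`FlattenedLoss of MSELoss()` means the predicted and target tensors are flattened before the mean squared error is taken. As a result, every x and y coordinate in the batch contributes equally. The cell below sketches that computation in plain Python. It uses hypothetical predictions and targets with the same (batch, 1, 2) layout as `yb`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical predictions and targets, shaped (batch=2, points=1, coords=2)\n",
"preds = [[[0.20, 0.08]], [[-0.30, 0.22]]]\n",
"targs = [[[0.21, 0.07]], [[-0.34, 0.22]]]\n",
"\n",
"def flatten(t):\n",
"    # Collapse the nested lists into one flat list of coordinates\n",
"    return [c for item in t for pt in item for c in pt]\n",
"\n",
"p, t = flatten(preds), flatten(targs)\n",
"# Mean of the squared differences over all 4 coordinates\n",
"mse = sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)\n",
"print(len(p), mse)"
]
},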
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.005754399299621582, lr_steep=0.033113110810518265)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.050496</td>\n",
" <td>0.008238</td>\n",
" <td>00:36</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.007744</td>\n",
" <td>0.004763</td>\n",
" <td>00:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.003334</td>\n",
" <td>0.000388</td>\n",
" <td>00:48</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.001468</td>\n",
" <td>0.000044</td>\n",
" <td>00:48</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lr = 1e-2\n",
"learn.fine_tune(3, lr)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.01"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.sqrt(0.0001)"
]
},
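{
"cell_type": "markdown",
"metadata": {},
"source": [
"MSE averages squared errors, so its square root (the RMSE) is back on the scale of the coordinates themselves. A loss of 0.0001 therefore means a typical coordinate is off by about 0.01 in [-1, 1] units. That range spans the whole image, so one unit is half the image width (or height). The cell below gives a back-of-the-envelope conversion to pixels, assuming the error is similar along both axes. `rmse_to_pixels` is a hypothetical helper, not part of fastai."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def rmse_to_pixels(mse, size):\n",
"    # RMSE in the [-1, 1] coordinate system\n",
"    rmse = math.sqrt(mse)\n",
"    # Two units span the full image, so one unit is half of each dimension\n",
"    return tuple(rmse * s / 2 for s in size)\n",
"\n",
"# (width, height) of the original frames and of the resized training images\n",
"print(rmse_to_pixels(0.0001, (640, 480)))\n",
"print(rmse_to_pixels(0.0001, (320, 240)))"
]
},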
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 432x576 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.show_results(ds_idx=1, nrows=3, figsize=(6,8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. How could multi-label classification improve the usability of the bear classifier?\n",
"1. How do we encode the dependent variable in a multi-label classification problem?\n",
"1. How do you access the rows and columns of a DataFrame as if it were a matrix?\n",
"1. How do you get a column by name from a DataFrame?\n",
"1. What is the difference between a `Dataset` and a `DataLoader`?\n",
"1. What does a `Datasets` object normally contain?\n",
"1. What does a `DataLoaders` object normally contain?\n",
"1. What does `lambda` do in Python?\n",
"1. What are the methods to customize how the independent and dependent variables are created with the data block API?\n",
"1. Why is softmax not an appropriate output activation function when using a one-hot-encoded target?\n",
"1. Why is `nll_loss` not an appropriate loss function when using a one-hot-encoded target?\n",
"1. What is the difference between `nn.BCELoss` and `nn.BCEWithLogitsLoss`?\n",
"1. Why can't we use regular accuracy in a multi-label problem?\n",
"1. When is it okay to tune a hyperparameter on the validation set?\n",
"1. How is `y_range` implemented in fastai? (See if you can implement it yourself and test it without peeking!)\n",
"1. What is a regression problem? What loss function should you use for such a problem?\n",
"1. What do you need to do to make sure the fastai library applies the same data augmentation to your input images and your target point coordinates?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Read a tutorial about Pandas DataFrames and experiment with a few methods that look interesting to you. See the book's website for recommended tutorials.\n",
"1. Retrain the bear classifier using multi-label classification. See if you can make it work effectively with images that don't contain any bears, including showing that information in the web application. Try an image with two different kinds of bears. Check whether the accuracy on the single-label dataset is impacted using multi-label classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}