{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other Computer Vision Problems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Label Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *\n",
"path = untar_data(URLs.PASCAL_2007)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fname</th>\n",
" <th>labels</th>\n",
" <th>is_valid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>000005.jpg</td>\n",
" <td>chair</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>000007.jpg</td>\n",
" <td>car</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>000009.jpg</td>\n",
" <td>horse person</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>000012.jpg</td>\n",
" <td>car</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>000016.jpg</td>\n",
" <td>bicycle</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fname labels is_valid\n",
"0 000005.jpg chair True\n",
"1 000007.jpg car True\n",
"2 000009.jpg horse person True\n",
"3 000012.jpg car False\n",
"4 000016.jpg bicycle True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'train.csv')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: Pandas and DataFrames"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[:,0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fname 000005.jpg\n",
"labels chair\n",
"is_valid True\n",
"Name: 0, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[0,:]\n",
"# Trailing :s are always optional (in numpy, pytorch, pandas, etc.),\n",
"# so this is equivalent:\n",
"df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['fname']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b\n",
"0 1 3\n",
"1 2 4"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df = pd.DataFrame({'a':[1,2], 'b':[3,4]})\n",
"tmp_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" <th>c</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b c\n",
"0 1 3 4\n",
"1 2 4 6"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df['c'] = tmp_df['a']+tmp_df['b']\n",
"tmp_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constructing a DataBlock"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsets = dblock.datasets(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4009, 1002)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dsets.train),len(dsets.valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(fname 000224.jpg\n",
" labels tvmonitor bottle\n",
" is_valid True\n",
" Name: 113, dtype: object, fname 000224.jpg\n",
" labels tvmonitor bottle\n",
" is_valid True\n",
" Name: 113, dtype: object)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dsets.train[0]\n",
"x,y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'000224.jpg'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x['fname']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('009879.jpg', 'car person')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(get_x = lambda r: r['fname'], get_y = lambda r: r['labels'])\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('006350.jpg', 'aeroplane')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return r['fname']\n",
"def get_y(r): return r['labels']\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Path('/home/sgugger/.fastai/data/pascal_2007/train/008663.jpg'),\n",
" ['car', 'person'])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return path/'train'/r['fname']\n",
"def get_y(r): return r['labels'].split(' ')\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x374,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) ['chair']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = torch.where(dsets.train[0][1]==1.)[0]\n",
"dsets.train.vocab[idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x333,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def splitter(df):\n",
"    # Split on the dataset's own is_valid column instead of at random\n",
"    train = df.index[~df['is_valid']].tolist()\n",
"    valid = df.index[df['is_valid']].tolist()\n",
"    return train,valid\n",
"\n",
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   splitter=splitter,\n",
"                   get_x=get_x,\n",
"                   get_y=get_y)\n",
"\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
"                   splitter=splitter,\n",
"                   get_x=get_x,\n",
"                   get_y=get_y,\n",
"                   item_tfms = RandomResizedCrop(128, min_scale=0.35))\n",
"dls = dblock.dataloaders(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(nrows=1, ncols=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary Cross-Entropy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 20])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.train.one_batch()\n",
"activs = learn.model(x)\n",
"activs.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2.0258, -1.3543, 1.4640, 1.7754, -1.2820, -5.8053, 3.6130, 0.7193, -4.3683, -2.5001, -2.8373, -1.8037, 2.0122, 0.6189, 1.9729, 0.8999, -2.6769, -0.3829, 1.2212, 1.6073],\n",
" device='cuda:0', grad_fn=<SelectBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"activs[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(inputs, targets):\n",
"    # Map raw activations to probabilities in (0, 1)\n",
"    inputs = inputs.sigmoid()\n",
"    # Use p where the target is 1 and 1-p where it is 0, then average the negative log\n",
"    return -torch.where(targets==1, inputs, 1-inputs).log().mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.0082, device='cuda:5', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_func = nn.BCEWithLogitsLoss()\n",
"loss = loss_func(activs, y)\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Hello Jeremy.', 'Ahoy! Jeremy.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def say_hello(name, say_what=\"Hello\"): return f\"{say_what} {name}.\"\n",
"say_hello('Jeremy'),say_hello('Jeremy', 'Ahoy!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Bonjour Jeremy.', 'Bonjour Sylvain.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = partial(say_hello, say_what=\"Bonjour\")\n",
"f(\"Jeremy\"),f(\"Sylvain\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.903610</td>\n",
" <td>0.659728</td>\n",
" <td>0.263068</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.724266</td>\n",
" <td>0.346332</td>\n",
" <td>0.525458</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.415597</td>\n",
" <td>0.125662</td>\n",
" <td>0.937590</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.254987</td>\n",
" <td>0.116880</td>\n",
" <td>0.945418</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.123872</td>\n",
" <td>0.132634</td>\n",
" <td>0.940179</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.112387</td>\n",
" <td>0.113758</td>\n",
" <td>0.949343</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.092151</td>\n",
" <td>0.104368</td>\n",
" <td>0.951195</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
"learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10436797887086868,0.93057781457901]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.1)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10436797887086868,0.9416930675506592]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.99)\n",
"learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"preds,targs = learn.get_preds()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorMultiCategory(0.9554)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_multi(preds, targs, thresh=0.9, sigmoid=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = torch.linspace(0.05,0.95,29)\n",
"accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]\n",
"plt.plot(xs,accs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Assemble the Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.BIWI_HEAD_POSE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#50) [Path('01'),Path('01.obj'),Path('02'),Path('02.obj'),Path('03'),Path('03.obj'),Path('04'),Path('04.obj'),Path('05'),Path('05.obj')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls().sorted()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1000) [Path('01/depth.cal'),Path('01/frame_00003_pose.txt'),Path('01/frame_00003_rgb.jpg'),Path('01/frame_00004_pose.txt'),Path('01/frame_00004_rgb.jpg'),Path('01/frame_00005_pose.txt'),Path('01/frame_00005_rgb.jpg'),Path('01/frame_00006_pose.txt'),Path('01/frame_00006_rgb.jpg'),Path('01/frame_00007_pose.txt')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/'01').ls().sorted()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('13/frame_00349_pose.txt')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img_files = get_image_files(path)\n",
"# Swap the trailing 'rgb.jpg' (7 characters) for 'pose.txt' to get the matching label file\n",
"def img2pose(x): return Path(f'{str(x)[:-7]}pose.txt')\n",
"img2pose(img_files[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(480, 640)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im = PILImage.create(img_files[0])\n",
"im.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
2020-04-28 17:12:59 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAKAAAAB4CAIAAAD6wG44AABe9UlEQVR4nM39aaxl2XUmiH1r7X3OvfdNMeYUOWcymcwkk6Q4SKJUsqzS2C6pqlpuw3C7baAHN+A//tNA223YBRsFN2DDcBvucrddqEYZ1XbJrRooUUOVVJIocSZFMpOZJHOeIjIjMmN48eIN995z9l7r84+9z7n3RUQmSamq2geBiBv3nmGfvfaavjVs+Tt/5//+8Y9/nCQAFwJQQgCHYDhWnwCAa99z/C9lPFfEV1dQiJsOAhDc+v2tRzmHEIBlGFIfKxCSJB0wM6fnnNzcPbibWWdGd3d3M6fT3Zzm7kY4Sc9Og4EA3Qk6xWkEQYcTQJ0TIyAgOfwLkKSwniOQ9e9JA+gE6YJ6DgkCZALgJoQDYHlUhjOZ5Vnb7JzcdC93EsLK7cpAyDqe8kw6SQF8OAdexgzSHXASgJ4/fyH+2Cc+/lOf+akfPNf/bR8EHHDCjTRYtmxm2VL2vk8555QTmXO2nLIzWxYzz5bdDOZkds9mZurZMplJwtWNpJs5XACS5p7pBEHSzFI6DDGCNHdhY+6guzs9DxQlYfTxM0m6F0p3QP0NFCCSdEM2AHRPZtnEYOjMTp254yNPPdb3ybKXhezOYZkVApdlQVJAIYW0gboOEITTh6sEEFJiWU0jh5RP5T8ysOcaP4JeliWcMObxtRKUdApBJ6RwNMmYDYS7uxOidLhDIDQxJ0kz9jR3p9MGhjM6zLN15ubOnLN55+65nJFp5maWU7Ls7uaFSszmvZsTQiZzKywsZqQT7jRHItw9Ox0mJJwEPdIAkg7x7EXIwN1p68sslylxUlg/ExABK4UB1pkkCWiZM5DuBmhhcTdxN2N2T+YJCYvOZie2CqncnRRAK2nLdA6cSvpAXSkPLc9yL+Qv/xb5AoHGN14/PHvm3a7LZm5Myc3K+nMj6x1inwUghKSb5Cr+6Jnu7ubm5sxkdiS3npqHNwRZ3x4kEMgq9IbVRwLueXw39+S0YTEaCALuJjKIOhI0rljGyqJjFWW9u4OxTispoiiLHS5CKQzmJgIpIg6kU9kCEGHKfdf3sYkxNnSnUUREim4olIOTIlAJlQfEKyOQFAepKiKKgT0IUVFhJEihqAtUfeLSaFgGELB2GoEgDCqhTHXlL2G5/yCiV0cZVfm+fAZUle6VuiIav/H1L1+79o57gjA4ASur2K1cQJLwOnt1aVTNSEGRZUIQCIXoAIP0JJ1F1Ii7ZzM61QNAIkEc1HFwFigiObtAGzik6KpAiYUfBFImVoSgDwzjgEIFCjoBF9EilwRJCAFERRSqIztBpYUSwVQBgQi0nIQgAg1ilmbTuFgsQzsTBIEIEEIAEJBFVCQItQxKRSCCJgMiiICUp0gluYKiKhABlTAgE+amdKeTxgSDJ/O8c/Ys3LUoa4iNclMqLVWVdFKreBWwSE3AARQCi8MpUqSIExavXX5t1vaq0vddEyey4ieoSBHnWtZv/VtFVABVFU0iQaRRBJVUlq0IIBQiqpTXh6q2ASMfyBQQgapqua2Lxja65+VysdFOIBQIoMTaVWoiI0niSjNrEIFIIXAodyYhamW0w3SrSBCoKEWAQTENs1+eI0UexMCuT7GZaWhk+AmAM6gSMIC6prhUfZCK4gPbCTwn0r1pGlE1B2kQc5pl0IliQzhgKZm1zZbWB61z5E2MW36vBqqIQBWkuANlDAIoxKVYuyJxZ/PkHafumkymi8WinVBFNMS6PBWqQUQCJARFfdOoKgBURYQiUthE2K+GJSLlFjIcZbJDFSkCoaxEX59T27TzxSKEMGlUECEi4pWDKyXIeoVwWGhOEgqwrKrlso8hNs0UIGHjtXVFU0kh+mrqQwb7AygaT7hYHIlAhRFNbKaxmWhREiAAkwAUgenKNQJD6C4DI1Wh5wgBGqKoEBAFWNSqQhwwoxjhZrnrl31uN0RER2KO4rdMI+swII
BXVSEcXBgRkAZoWcdgz5FaO2cmJ87OptON2bLZmjZBQwgNoCKF2KqiqlJEhKCQG6qqoiJRlaJ2Y383xFNnzp6tukeDig5klSqkBFEaVYpS1FVFVFTV3V94/pmNjenZ02cvX979wGMPbm+fBgjJkLi2flUg7kKHax6MEZqjeAvm6dLFdzc2Nra2dsyKV8SVn0QjM0GypddFz2pYgXRxh5AQDUHAEBhiKxoFKwclSFkcSoZ1Di40YzHBpKwbhUJoZZGD5XsQAQQkQwlxJ9x78y6bcfBMpazWgXLlc1VTGKRDIf+ayzqYxJVxx/vERx96YGMSP/zkh15/48177j93//33QyRoDBpFNMYoqkGziIZQaCJlQYtAEEUMyNd2r8Rw4tTp0ytndW1AP/Bw950zp7rDyel7u0cfvifEeLuz6OZOmlk25mwpJZK5GNbZsuWtHTt16iSAnK0cxVwsDk/xiulWDS11upLVy3amQIYQtIkKMeQQY1At0r36vjLDwFZr/4JwgMIAAZEDBQxOepEUIYAIVCl+jhBiDnFVDao5KCLgoupQhxCEQgr5WARNU59TbHQpbi7BIigJSLGgpFJdQQABaOJDDz3+kY88dfbs2XPnHjx59mTTToZ76Q+mTLW28p13PgSEkdsGI+yHoy55x50Pbp7c3OPy1LnTIUbeYi6WQ1TFXVWVEoIUl1NEzVREIRJj0zTFGDZVUzVWtMNDCO5uZqBx+BZUFl+IDBAlVWIIjULMEENUDTIAMnUp1I9c/QMoQIpoJQqEpJIuEqRacJUFy+8qg36SYboEEoIUAVseSUH1k24zE1UTjjK8LIvx1EJ3UYHE2EzvuvschCdPn2radm1yVzdmNWNvD0ARkSyo1l/4iEaRoIM39Z5HmRVVIUOMxcGvh7urxhDKYldAB7jA3M2s2Fygl2VRVku5qvCxBneVELQRIChDaKSIV9anFMdnQB44EAjFDJUKXDVF64s4EFRVNRSlPJ5eXqQskGpWFWACQaDEwKiig7O0mpOV8YXqS0ldCWVabloTjBpUo7ibhML4txOsUh2rmw4KKqP/ZYgLUAqyoJTw/mcW+hazeSB2QQ8YQrEXQmGtECjSuFe6lqtJocDNVcUpChaTrfjp5ZwYI81FVDUAonXVce0+dSQrZiiojrPyqUDEzaiqg+EDiIkOdpgUHlZZHSpSMOIiOEfciWuA062zUWjKm08oRq4IgCjF05NQnJbb34i3V6cymJd/mUMAChowCCKcEt7/lqripGqZdC/+1sCUGoIWCayqpJUv3alagCmnA0HdrfpiBZkTFyeVDKIqQngIIQRzoTgoFc7gIBkxuOOFxsU7MIdAq4hTkUhmVagKnRQDHUKgkBYqFAkm4ioQBC1GX1kCuqYBKGoFngSEUA641sD9PvizVUi4K8uyEsSVxSW3p+K/+kOa5rZW1XtfIBgpV9bvyAhF1xaSlyXrXtixak2oirtIUDiFpLh78XQAhtiIKMRVJZuF0ErBF91XtF0Tg9WZKaPSanoN54hI5eCiIkWErA7rLe9EESk+uqxuM96/eFoVR1uJ6fHi0fqpiGb5Xaoz+yNN7r/0QwRt0xQlxNWy/UEXiQAIQUII1Y0TGZGT8Zvx+3K+qpbPWkGYY2eJBNWYkgmQc54fHakq1lyO8SjCcV3mrYY9OocixSstuDGKVoFiUNgVJijMVT7SCS9+9rE7UyFKCr1ch5ueO46kjhMqCIIAqILwYoD8xWn0lzwkpdR3nYh4CQr8wAtWi3I1m+s0vunLW7+RgcgDbVU1AkJn6lKZfg2hQDrD5Vi/2y0jud04a9Cl/FEUmFMCbjZVy4rRwZ6q0cXjN5eCLWPtdW73zNGur79GkG4eBtTph3Vu/mUehINQh9EixdFqMdffcyjig/UhsjZXIqIaSJj5SI91DV0CNZWboU4C1PLWQqUUKNihDq34JQIqfFjt1io2JaCKRAddBpur4ABVwKoQHjR6DRhK9YgkYAgKCYKICl
S14rIcvFvSi0FOUZJUBwFb0wrQ6h2xqgGSWmP5Thgkx+I/1LHevGr+NR3FxXR3g1NLiO1HkyhF9ha7A1VMKmByTGD
"text/plain": [
"<PIL.Image.Image image mode=RGB size=160x120 at 0x7F4213FAD350>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im.to_thumb(160)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)\n",
"def get_ctr(f):\n",
" ctr = np.genfromtxt(img2pose(f), skip_header=3)\n",
" c1 = ctr[0] * cal[0][0]/ctr[2] + cal[0][2]\n",
" c2 = ctr[1] * cal[1][1]/ctr[2] + cal[1][2]\n",
" return tensor([c1,c2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([384.6370, 259.4787])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_ctr(img_files[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"biwi = DataBlock(\n",
" blocks=(ImageBlock, PointBlock),\n",
" get_items=get_image_files,\n",
" get_y=get_ctr,\n",
" splitter=FuncSplitter(lambda o: o.parent.name=='13'),\n",
" batch_tfms=[*aug_transforms(size=(240,320)), \n",
" Normalize.from_stats(*imagenet_stats)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls = biwi.dataloaders(path)\n",
"dls.show_batch(max_n=9, figsize=(8,6))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3, 240, 320]), torch.Size([64, 1, 2]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = dls.one_batch()\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.0753, 0.0237]], device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18, y_range=(-1,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi-lo) + lo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(partial(sigmoid_range,lo=-1,hi=1), min=-4, max=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FlattenedLoss of MSELoss()"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.loss_func"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.005754399299621582, lr_steep=0.03981071710586548)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEQCAYAAABIqvhxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9bnH8c+TfU8IWdgJ+75JRNkEwaUu1N0rVq2t1lttq1Zrb7Wveq3V21Vbt+LSWq1W6waKWjdEdhHZggbCTlhCNkJWspDMc/+YCQ4hCRmS2ZLn/XrNK5Nzzpz5zkDmmd/vd875iapijDHGeCLE3wGMMcYEHysexhhjPGbFwxhjjMeseBhjjPGYFQ9jjDEes+JhjDHGY2H+DuArKSkpmpGR4e8YxhgTNNatW1esqqnNresyxSMjI4O1a9f6O4YxxgQNEcltaZ11WxljjPGYFQ9jjDEes+JhjDHGY1Y8jDHGeMyKhzHGGI9Z8TDGGOMxKx7GGNNJFZbXUF5z1Cv7tuJhjDGd1OOLtzPtd4s52uDo8H1b8TDGmE5IVVm0uZDJg7oTHtrxH/VWPIwxphP6+kA5+eU1nDuyh1f2b8XDGGM6oU+2FBAiMGt4mlf277PiISLJIrJARKpEJFdErm1l29NEZJmIVIpIgYjc4bYuQ0Q+E5EjIpIjIuf45hUYY0zw+GRzAZn9k0mOjfDK/n3Z8ngKqAPSge8A80RkVNONRCQF+BB4BugODAY+dtvkVWCDa90vgTdFpNmrPhpjTFe0//ARthws55yR3ml1gI+Kh4jEAlcAv1LVSlVdASwErm9m87uAj1T1X6paq6oVqrrFtZ+hwGnA/6pqtaq+BXzl2rcxxhhg0eYCAK+Nd4DvWh5DgQZV3ea2LAs4oeUBnAmUiMgqESkUkXdFpJ9r3Shgl6pWtGE/xhjTJS3aUsig1FgGpMR67Tl8VTzigLImy8qA+Ga27QN8F7gD6AfsxtlV5el+EJFbRGStiKwtKio6xejGGBM8yqqPsnrXIa+2OsB3xaMSSGiyLAGoaGbbamCBqn6pqjXAr4EpIpLo4X5Q1WdVNVNVM1NTbVjEGNP5Ld1WRL1DOdeL4x3gu+KxDQgTkSFuy8YB2c1suwlQt98b74tr+4Ei4t7SaGk/xhjT5XyyuYCUuAjG9+3m1efxSfFQ1SpgPvCgiMSKyFTgEuClZjb/B3CZiIwXkXDgV8AKVS11jZlsBP5XRKJE5DJgLPCWL16HMcYEsqMNDpZsLWTW8DRCQ8Srz+XLQ3VvA6KBQpxjGLeqaraITBeRysaNVHUxcB/wvmvbwYD7OSHXAJnAYeB3wJWqagMaxpgu78vdJVTU1HPOiHSvP1eY15/BRVVLgEubWb4c50C4+7J5wLwW9rMHmNnxCY0xJrh9sqWAiLAQpg1J8fpz2eVJjDGmE1BVFm0pYNrgFGIivN8usOJhjDGdwPbCSvaVVDN7hHePsmpkxcMYYzqBRVucZ5XPHu798Q6w4mGMMZ3Cos0FjOmdSI/EKJ88nxUPY4wJcsWVtWzYV+qzLiuw4mGMMUFvcU4hqvjkEN1GVjyMMSbIfbqlgJ6JUYzq1fTqTd5jxcMYY4JYzdEGlm8vZvaINES8e1a5OysexhgTxL7YXcKRugafHWXVyIqHMcYEsc9yCokKD2HyoO4+fV4rHsYYE6RUlcU5hUwZlEJUeKhPn9uKhzHGBKmdRVXsLTnC2cN9d4huIysexhgTpD7LKQRglhUPY4wxbbU4p5Bh6fH0Tor2+XNb8TDGmCBUXnOUL/eUMHO4f6bYtuJhjDFBaOX2Yuodyqxhvu+yAisexhgTlBbnFJIQFcbE/t6dq7wlVjyMMSbIOBzKZ1uLOGtoKmGh/vkYt+JhjDFB5uu8Moora/1ylFUjKx7GGBNklmwtQgTOGuqfwXKw4mGMMUFn6bYixvROJCUu0m8ZrHgYY0wQKTtylA
17DzPTj60OsOJhjDFBZcWOYhwKM4ZZ8TDGGNNGS7Y6D9Ed1yfJrzmseBhjTJBQVZZuK2K6Hw/RbWTFwxhjgkROfgWFFbXM8PN4B1jxMMaYoLFkaxGAFQ9jjDFtt3RbIcN7xJOeEOXvKFY8jDEmGFTW1rN2z2Fm+ulCiE1Z8TDGmCCwaofzKrqB0GUFVjyMMSYorNxRTExEqN+uotuUFQ9jjAkCG/aVMq5PEhFhgfGxHRgpjDHGtKjmaAOb88oZ38+/Jwa6s+JhjDEBLjuvjHqHMr6vFQ9jjDFttGFvKQATumLxEJFkEVkgIlUikisi17aw3QMiclREKt1uA93Wq2sfjev+5qvXYIwx/rBhXym9k6JJC4DzOxqF+fC5ngLqgHRgPPC+iGSpanYz276mqte1sq9xqrrDGyGNMSbQbNxbGlDjHeCjloeIxAJXAL9S1UpVXQEsBK73xfMbY0ywKqyo4UBpdUB1WYHvuq2GAg2qus1tWRYwqoXt54hIiYhki8itzaxfJiL5IjJfRDI6OKsxxgSMjY3jHV2x5QHEAWVNlpUB8c1s+zowAkgFfgDcLyJz3dbPADKA4UAe8J6INNv9JiK3iMhaEVlbVFTUvldgjDF+sHFfKWEhwqheif6OchxfFY9KIKHJsgSgoumGqrpZVfNUtUFVVwGPAVe6rV+mqnWqWgrcAQzAWWxOoKrPqmqmqmampgbGKf3GGOOJDXtLGdkrgajwUH9HOY6visc2IExEhrgtGwc0N1jelALSjvXGGBOUGhzKpv2lAXV+RyOfFA9VrQLmAw+KSKyITAUuAV5quq2IXCIi3cRpEnA78I5r3SgRGS8ioSISBzwCHAC2+OJ1GGOML20vrKCqriHgxjvAtycJ3gZEA4XAq8CtqpotItNFpNJtu2uAHTi7tP4J/F5VX3StSwdeA8qBXTjHPi5W1aO+eQnGGOM7jYPl4/sGxsUQ3fnsPA9VLQEubWb5cpwD6o2/z226jdu6xcAwrwQ0xpgAs3FfKUkx4WR0j/F3lBPY5UmMMSZArc09zPi+SYgE3rCuFQ9jjAlABeU17CisZMqg7v6O0iwrHsYYE4BW7SwGYMqgFD8naZ4VD2OMCUCrdhwiKSackT2bniIXGKx4GGNMgFFVVu08xOSB3QkJCbzxDrDiYYwxASf30BEOlFYzZXBgdlmBFQ9jjAk4K4+NdwTmYDlY8TDGmICzascheiREMTAl1t9RWmTFwxhjAojDoXy+6xBTBncPyPM7GlnxMMaYAJKTX0FJVV3AHqLbyIqHMcYEkMbzO6YODtzxDrDiYYwxAWXljmIGpsTSMzHa31FaZcXDGGMCRF29gzW7S5gS4K0OsOJhjDEBY9XOYqrqGpg5NM3fUU7KiocxxgSIj7LziY0IZdqQwB4sBysexhgTEBocysfZBZw9PC3g5itvjhUPY4wJAGv3lHCoqo4LRvf0d5Q2seJhjDEB4MPsfCLCQpg5LNXfUdrEiocxxviZqvLR1/mcNSSV2EifzQ7eLlY8jDHGz746UEZeWQ3fGt3D31HazIqHMcb42Qdf5xMWIpwzIvAP0W1kxcMYY/xIVfnw63wmD+pOUkyEv+O0mRUPY4zxo+2FlewuruL8UcHTZQVWPIwxxq8+2VwAwLkj0/2cxDNWPIwxxo8+2VzAuD6JpCdE+TuKR6x4GGOMnxRW1LBxX2nQtTrAiocxxvjN4i2FAJxjxcMYY0xbLdpSQJ9u0QxLj/d3FI9Z8TDGGD+ormtg+fZizhmRHtBzlbfEiocxxvjBih3F1NY7gnK8A6x4GGOMXyzaXEB8VBiTBiT7O8opaXPxEJG7RGS86/6ZIrJXRHaJyGTvxTPGmM7H4VA+zSlg5rA0wkOD8zu8J6l/Cux23f8t8CjwMPCXjg5ljDGd2cb9pRRX1gXVtaya8uTav4mqWiYi8cA44BxVbRCRR7yUzRhjOqWPswsIC5GgmKu8JZ4Uj30iMgUYBSxzFY4EoME70Ywxpv
NRVf7z1UGmDE4hMSbc33FOmSfdVvcAbwK/BH7jWnYxsKYtDxaRZBFZICJVIpIrIte2sN0DInJURCrdbgPd1o8XkXU
2020-03-06 18:19:03 +00:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.049488</td>\n",
" <td>0.022839</td>\n",
" <td>00:39</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.008415</td>\n",
" <td>0.005187</td>\n",
" <td>00:54</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.003400</td>\n",
" <td>0.000343</td>\n",
" <td>00:55</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.001462</td>\n",
" <td>0.000100</td>\n",
" <td>00:55</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lr = 1e-2\n",
"learn.fine_tune(3, lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.01"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.sqrt(0.0001)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAGzCAYAAACMxsRFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9ebBty13f9/l1r7WHM9x77n3z03tPM3pIQkxi1ICEBHJkEw9F7LIpk0o5phyIXR6CIdhxAimDHcr8kcIhxnZiOxAbbIwDxiAgGBmBZFkS6AkJ0Cw97hvvcO49w957re7fL39091pr77PPuc/E7x5VsX919z17r9Wrx2+v/k39azEzNrShDW1oQ3eG3HlXYEMb2tCGfj/R5qW7oQ1taEN3kDYv3Q1taEMbuoO0eeluaEMb2tAdpM1Ld0Mb2tCG7iBtXrob2tCGNnQHqTrrpprZjf0bXNq7BAIy9C4zQJZS54uFZCmB4dDBHScM0isQlws3P8gq5LRutdCct6FLZfuVBFVujxKjYQZN02BmNPOIGrRtIMZADNC2LapKjBEk0jQNbdtipph6zIwQAqaKV0UtIAINDVEDqiFVWx2qhqqmT2tghlokhJYYI6qKc44QWm7eeIrpdEqMkRgNoaZtW0IIhHCMWUpbngNQVcwMEUFEiDEiAuLafD9iBiJgFhHxmEHbRmKcEzUiYqn7zdOa461vfV3ujwjmEKm68gBijJhZul+62AwnFeAwi0t1K99VFRHBzBiNJhwdHfMt3/YXTg7oHaANtjfYPi9sn/nSDRbYvrBDsIAgeDyKJChkYAmSQBt81wFSirL+exOUGAOj0YgQIwtxeA+VA3AEdd2zIcB2qUMAJyNUIUbQmPJtmvS9aWHRNjQx0DQNi8UiD2zsBnW+OCTGyGI+T7+bfcwSYNp52w1oCAENqSNFJIOzAQlYwhTiEkBjVEIIVEgGDETX5HYrhuLEY2rdwNUORBRQRKANEe99KlcVT8XN69YNZpqQ+SttB8D0l76eMaZp6xwOiCEi1B0gAIK2CDUhHANGCEbUhhBagi6QYLStcCsEnPOIRLyrMbM8rpK/G96nySnOJTSY5UmbXkaqedwTVFJ7BMQJmAMz5vN5+n5OtMH2Btvnhe0zX7q/9aF9RqMRN2/epGka9o9u0rQtx/MZKqDNolsZfBNwHSJBqAghdJ1XeYGoHQgqV2MExLXE2CKT4y5toimQVmZE83hLXtVaDEU1plUQwTkjxrT6RF3kDkrPxTZ1ove+u5ZWd6NK/ZRXTcO0wUw7MIl4VGdoFCxzKGVlq6qKEPOzISKmoAmUoGD9qmlm3Fr0z8/nxxwcHbK3t8d4PKZtW2Qw6CCIVF16h+Cco6rKkKW+rCqPmVDVQgxGVVXUTrp+NGeYRWpfIYyweoRa4oQ0bqGqBFtQKYSoTLwkTszq9NJB6Dk9Q8Qw8sRQy8zZMmvoRFIaJE1Q51BTQDJIBbEqcSHnRBtsb7B9Xtg+86X7q7/6s4Qwx2WJptJUsagtVeVom9QRMUbavFKpal7henY7hICIomp5FU6dHULoxRBpu+dDCEk8ahpmsxmEKq243kACTqouLcAiizxtG9nZ2cHFgPiI4DF1mPPdSpoGXqnqBHJRwznXDaSvJxiZ5UBxrkoAMYdzDostIo66GuHM4Z2heJxPAPfeZVHHOnHJe5+eNUmDIUaMgRc8eA9PPf0s4scYQp2B6L1PICDgXZXqUAY911WqFnCd2FNEUJGY+9136c0MwWEELHMVph4NEdRYWMRZSxtaqCY4gcpZ4jLMEXPpPSeSy3ICmVsoYq+IoM4l8QswVyXYJraBnBgTBVkRu+8gbbC9wfZ5YfvMl+773v3LOG9cvHiRa9eusbN9AXFpdSurksur7LgeYWaMx+NcwTQYQE4bqKoKM4+TmrEL7GzXqBYd1ahrWOmAbvVz/W8QmqahqqpOZ+TqEWrG0d
EtxpMxI8lA0ZxXBnLJ11dpNXMehHrQ4YKawzkQB2kVdGCCcxUhtDifJlxpm4ghUoOllW+4wqmGDkhmhsMjztCsGxtVwkOP7IJUVNUI09CBGBMiHiMiEnADXZ6ZUUmvh8KEYGnFdj6LXREWiwXTaeKqNIuKURvACK2BGu2iQUwgtogLROqliZolbLqVfKDPTGs+g/oWLka6CZF0YIaZ5HspjXMlv/OhDbY32D4vbJ/50n300VdS1zX33H0vN27cYjSeYypUVQKR99INeO18xxnUdZ1XPun0JN4PDA8maJxnoCbAWe6IMuDOpdU3KagN53xWpEserL7sw9kx3jmuXxcuXbrEdOxxMkrAE0OpsuJb6UUKV8Z0acCjOYpeSpyxmLeMRmOSEt0w0W5CmGleuQVTwWi7QeuUPuWXCN4SRJpmjlGBGZXzmNRUdZ3WdDOKSOOKc8lAtC05YyOwiGKIKI4aI6KmqBm1rxiN0jhFszS5AXGjpItzgRBaDEnGj2bOoom4UZVX+KFYNWhL7nPVXj9omsqUlXqmPELua4fksj4XaIPtDbZPtOUOYfvMl+49L9imbSOjHdh1NTv1NBkQqHNhoa9MbkzlfXrTq1taNSs/oaocSMOtg2vsXrqHe+65h6qqUhqp8d53YKuqqnt+Wk0A2N4dIU4Zj3pRRc14/PGP8+EPf4QXPfxCjo7mvOWtbwKtEGeIRGIWR4pF0izVTxXMBYJGFotF1lEBWacWteXatX3GozEiaWK0UTqxMVmADdU2r7JjYjSceGK0ztKLJeuyhYBzNSoNpkrlXdL/+BpxPg8eeQIZlbSJyzCf7hSci2HiMRHEYgKOI1mfSZxY6buyKkcRjCQGmijke0FAbYFakya+TTpASrEY5HL7FT7/Fpf1WC7XXLBhPSHrv8jXXNaJSTZErTXu3hHaYHuD7fPC9pkv3W9429v56Z/+Gf7w29/Oj//Yv+DN//nb2Nvb63RbWDUAZpsXjYFZN1csVSnpctRaQphT15c6sWeJlvXWJ3+vkDO4eO+L+PoXvZSPfPAZ3vDmz6OayFLexY5Y57/F3UNVmS8i3ozKmrTKV0nE1BAwDYhbMJrsJLHGRXCxM1qUPFRrYhxjGqgkXXPeUCliWtb/+ZbKjCieOPGIGuDyZPY4TX1XLKniUo2HrildmyQbR7Jog0ScA9MRkQjeYRYw59NarJIU/kJyA0JoM3g9dbrmQKoakwoloGhnTEiiU+YArMa7Uh/QzMFYGX8cgmXj004Gb2qTIVn95VlynbrDtMH2Btvnhe0zX7oX9+7jT3/zn0Ej/Mk/9c34SRKJKl91mOnxU3XyTFe/4f3coaqC867r8CVgFhCu/j2LBIQpVT3i4qUZro6dov006kU7oa7r7nux3MYYu5W0qkaMx8W/MHaGkGK1TTqsNFFN3dI98F2eZkbtayqFuXiqyhMWDZWv8b5Kz0clC0oUV5TST0WUXe4TG7wDHIZhrp+UQ26AgS5KnCDWi7CrfTP8C8l9qUwOAVT6KiRDRhpvl8U8ioiqRpbUuheUSOZ4Sr3OiTbY3mAbzgfbZ750q9EYM/B1WQlcViwzUEJnWuIATqhqOhA7X+VibanxSx1+O25geM3AVTUmDqlqKIpyOz09QjcoS3UoSbIuznvfWWiH4CjXgOToXXRBaAfSocN0J7pI0mQ55xFfY87wPhsazPCuBwBm2eDRcwFJX1Vuu6UxMCu+oEqMLVDjfbYGQxK7lgbDurZ2LjjZYi7ik1+iSQZWMhKUqokV1xrpdYC5fsn4IP04ZlHL5SlXSo4hZk7mfGiD7Q22zwvbZ75029gmgwArQFsDoBNFDFd04LY+mevqOHj+1LQCXsh+gBWxV9GfmXfnYsNQvLETQJ1MJgP/wkSFm0jPjHodGGQLroIo4l0HaDMDjZgk55adeoS0Aec9SNH3BYr+UIm4gkznIPaiHwBeEIOoARC8FJQmfSTZh7CISELEssgkCM5qnAPnlCABqWu8Rs
ajasAdJQOBuMKR5B1CLmIKnUEJh1rqycQVlF0+gvcQQkTEkUTwrPdyltiKc6INtjfYPi9sn/nS7V1HZFWddftVevj
"text/plain": [
"<Figure size 432x576 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.show_results(ds_idx=1, max_n=3, figsize=(6,8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. How could multi-label classification improve the usability of the bear classifier?\n",
"1. How do we encode the dependent variable in a multi-label classification problem?\n",
"1. How do you access the rows and columns of a DataFrame as if it were a matrix?\n",
"1. How do you get a column by name from a DataFrame?\n",
"1. What is the difference between a `Dataset` and a `DataLoader`?\n",
"1. What does a `Datasets` object normally contain?\n",
"1. What does a `DataLoaders` object normally contain?\n",
"1. What does `lambda` do in Python?\n",
"1. What are the methods to customize how the independent and dependent variables are created with the data block API?\n",
"1. Why is softmax not an appropriate output activation function when using a one-hot-encoded target?\n",
"1. Why is `nll_loss` not an appropriate loss function when using a one-hot-encoded target?\n",
"1. What is the difference between `nn.BCELoss` and `nn.BCEWithLogitsLoss`?\n",
"1. Why can't we use regular accuracy in a multi-label problem?\n",
"1. When is it okay to tune a hyperparameter on the validation set?\n",
"1. How is `y_range` implemented in fastai? (See if you can implement it yourself and test it without peeking!)\n",
"1. What is a regression problem? What loss function should you use for such a problem?\n",
"1. What do you need to do to make sure the fastai library applies the same data augmentation to your input images and your target point coordinates?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Read a tutorial about Pandas DataFrames and experiment with a few methods that look interesting to you. See the book's website for recommended tutorials.\n",
"1. Retrain the bear classifier using multi-label classification. See if you can make it work effectively with images that don't contain any bears, including showing that information in the web application. Try an image with two different kinds of bears. Check whether the accuracy on the single-label dataset is affected by using multi-label classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}