fastbook/06_multicat.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from fastbook import *"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"[[chapter_multicat]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other Computer Vision Problems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the previous chapter you learned some important practical techniques for training models. Considerations like selecting learning rates and the number of epochs are very important to getting good results.\n",
"\n",
"In this chapter we are going to look at two other types of computer vision problems: multi-label classification and regression. The first one is when you want to predict more than one label per image (or sometimes none at all), and the second is when your labels are one or several numbers—a quantity instead of a category.\n",
"\n",
"In the process we will study more deeply the output activations, targets, and loss functions in deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Label Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Multi-label classification refers to the problem of identifying the categories of objects in images that may not contain exactly one type of object. There may be more than one kind of object, or there may be no objects at all in the classes that you are looking for.\n",
"\n",
"For instance, this would have been a great approach for our bear classifier. One problem with the bear classifier that we rolled out in <<chapter_production>> was that if a user uploaded something that wasn't any kind of bear, the model would still say it was either a grizzly, black, or teddy bear—it had no ability to predict \"not a bear at all.\" In fact, after we have completed this chapter, it would be a great exercise for you to go back to your image classifier application, and try to retrain it using the multi-label technique, then test it by passing in an image that is not of any of your recognized classes.\n",
"\n",
"In practice, we have not seen many examples of people training multi-label classifiers for this purpose—but we very often see both users and developers complaining about this problem. It appears that this simple solution is not at all widely understood or appreciated! Because it is probably more common to have some images with zero matches or more than one match, we should probably expect multi-label classifiers to be more widely applicable than single-label classifiers.\n",
"\n",
"First, let's see what a multi-label dataset looks like, then we'll explain how to get it ready for our model. You'll see that the architecture of the model does not change from the last chapter; only the loss function does. Let's start with the data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For our example we are going to use the PASCAL dataset, which can have more than one kind of classified object per image.\n",
"\n",
"We begin by downloading and extracting the dataset as per usual:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *\n",
"path = untar_data(URLs.PASCAL_2007)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This dataset is different from the ones we have seen before, in that it is not structured by filename or folder but instead comes with a CSV (comma-separated values) file telling us what labels to use for each image. We can inspect the CSV file by reading it into a Pandas DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fname</th>\n",
" <th>labels</th>\n",
" <th>is_valid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>000005.jpg</td>\n",
" <td>chair</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>000007.jpg</td>\n",
" <td>car</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>000009.jpg</td>\n",
" <td>horse person</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>000012.jpg</td>\n",
" <td>car</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>000016.jpg</td>\n",
" <td>bicycle</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fname labels is_valid\n",
"0 000005.jpg chair True\n",
"1 000007.jpg car True\n",
"2 000009.jpg horse person True\n",
"3 000012.jpg car False\n",
"4 000016.jpg bicycle True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'train.csv')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, the list of categories in each image is shown as a space-delimited string."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: Pandas and DataFrames"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No, it's not actually a panda! *Pandas* is a Python library that is used to manipulate and analyze tabular and time series data. The main class is `DataFrame`, which represents a table of rows and columns. You can get a DataFrame from a CSV file, a database table, Python dictionaries, and many other sources. In Jupyter, a DataFrame is output as a formatted table, as shown here.\n",
"\n",
"You can access rows and columns of a DataFrame with the `iloc` property, as if it were a matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[:,0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fname 000005.jpg\n",
"labels chair\n",
"is_valid True\n",
"Name: 0, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.iloc[0,:]\n",
"# Trailing :s are always optional (in numpy, pytorch, pandas, etc.),\n",
"# so this is equivalent:\n",
"df.iloc[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also grab a column by name by indexing into a DataFrame directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 000005.jpg\n",
"1 000007.jpg\n",
"2 000009.jpg\n",
"3 000012.jpg\n",
"4 000016.jpg\n",
" ... \n",
"5006 009954.jpg\n",
"5007 009955.jpg\n",
"5008 009958.jpg\n",
"5009 009959.jpg\n",
"5010 009961.jpg\n",
"Name: fname, Length: 5011, dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['fname']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can create new columns and do calculations using columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
 <th>b</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
 <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b\n",
"0 1 3\n",
"1 2 4"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df = pd.DataFrame({'a':[1,2], 'b':[3,4]})\n",
"tmp_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" <th>c</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b c\n",
"0 1 3 4\n",
"1 2 4 6"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_df['c'] = tmp_df['a']+tmp_df['b']\n",
"tmp_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pandas is a fast and flexible library, and an important part of every data scientist's Python toolbox. Unfortunately, its API can be rather confusing and surprising, so it takes a while to get familiar with it. If you haven't used Pandas before, we'd suggest going through a tutorial; we are particularly fond of the book [*Python for Data Analysis*](http://shop.oreilly.com/product/0636920023784.do) by Wes McKinney, the creator of Pandas (O'Reilly). It also covers other important libraries like `matplotlib` and `numpy`. We will try to briefly describe Pandas functionality we use as we come across it, but will not go into the level of detail of McKinney's book."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have seen what the data looks like, let's make it ready for model training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constructing a DataBlock"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How do we convert from a `DataFrame` object to a `DataLoaders` object? We generally suggest using the data block API for creating a `DataLoaders` object, where possible, since it provides a good mix of flexibility and simplicity. Here we will show you the steps that we take to use the data blocks API to construct a `DataLoaders` object in practice, using this dataset as an example.\n",
"\n",
"As we have seen, PyTorch and fastai have two main classes for representing and accessing a training set or validation set:\n",
"\n",
"- `Dataset`:: A collection that returns a tuple of your independent and dependent variable for a single item\n",
"- `DataLoader`:: An iterator that provides a stream of mini-batches, where each mini-batch is a tuple of a batch of independent variables and a batch of dependent variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On top of these, fastai provides two classes for bringing your training and validation sets together:\n",
"\n",
"- `Datasets`:: An object that contains a training `Dataset` and a validation `Dataset`\n",
"- `DataLoaders`:: An object that contains a training `DataLoader` and a validation `DataLoader`\n",
"\n",
"Since a `DataLoader` builds on top of a `Dataset` and adds additional functionality to it (collating multiple items into a mini-batch), it's often easiest to start by creating and testing `Datasets`, and then look at `DataLoaders` after that's working."
]
},
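{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these four classes concrete, here is a minimal sketch (a toy example, not fastai's actual implementation) of how a `Dataset` and a `DataLoader` relate:\n",
"\n",
"```python\n",
"# A Dataset is anything indexable that returns (x, y) tuples\n",
"dset = list(zip('abcd', [1, 2, 3, 4]))  # toy Dataset: [('a', 1), ('b', 2), ...]\n",
"dset[0]                                 # ('a', 1)\n",
"# A DataLoader collates those items into mini-batches\n",
"dl = DataLoader(dset, bs=2)\n",
"first(dl)  # the x's collated together, and the y's collated together\n",
"```"
]
},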
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we create a `DataBlock`, we build up gradually, step by step, and use the notebook to check our data along the way. This is a great way to make sure that you maintain momentum as you are coding, and that you keep an eye out for any problems. It's easy to debug, because you know that if a problem arises, it is in the line of code you just typed!\n",
"\n",
"Let's start with the simplest case, which is a data block created with no parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a `Datasets` object from this. The only thing needed is a source—in this case, our DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsets = dblock.datasets(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This contains a `train` and a `valid` dataset, which we can index into:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4009, 1002)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dsets.train),len(dsets.valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(fname 008663.jpg\n",
" labels car person\n",
" is_valid False\n",
" Name: 4346, dtype: object,\n",
" fname 008663.jpg\n",
" labels car person\n",
" is_valid False\n",
" Name: 4346, dtype: object)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dsets.train[0]\n",
"x,y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, this simply returns a row of the DataFrame, twice. This is because by default, the data block assumes we have two things: input and target. We are going to need to grab the appropriate fields from the DataFrame, which we can do by passing `get_x` and `get_y` functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'008663.jpg'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x['fname']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('005620.jpg', 'aeroplane')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(get_x = lambda r: r['fname'], get_y = lambda r: r['labels'])\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, rather than defining a function in the usual way, we are using Python's `lambda` keyword. This is just a shortcut for defining and then referring to a function. The following more verbose approach is identical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('002549.jpg', 'tvmonitor')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return r['fname']\n",
"def get_y(r): return r['labels']\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lambda functions are great for quickly iterating, but they are not compatible with serialization, so we advise you to use the more verbose approach if you want to export your `Learner` after training (lambdas are fine if you are just experimenting)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the independent variable will need to be converted into a complete path, so that we can open it as an image, and the dependent variable will need to be split on the space character (which is the default for Python's `split` method) so that it becomes a list:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Path('/home/jhoward/.fastai/data/pascal_2007/train/002844.jpg'), ['train'])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def get_x(r): return path/'train'/r['fname']\n",
"def get_y(r): return r['labels'].split(' ')\n",
"dblock = DataBlock(get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To actually open the image and do the conversion to tensors, we will need to use a set of transforms; block types will provide us with those. We can use the same block types that we have used previously, with one exception: the `ImageBlock` will work fine again, because we have a path that points to a valid image, but the `CategoryBlock` is not going to work. The problem is that that block returns a single integer, but we need to be able to have multiple labels for each item. To solve this, we use a `MultiCategoryBlock`. This type of block expects to receive a list of strings, as we have in this case, so let's test it out:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x375,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" get_x = get_x, get_y = get_y)\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, our list of categories is not encoded in the same way that it was for the regular `CategoryBlock`. In that case, we had a single integer representing which category was present, based on its location in our vocab. In this case, however, we instead have a list of zeros, with a one in any position where that category is present. For example, if there is a one in the second and fourth positions, then that means that vocab items two and four are present in this image. This is known as *one-hot encoding*. The reason we can't easily just use a list of category indices is that each list would be a different length, and PyTorch requires tensors, where everything has to be the same length."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> jargon: One-hot encoding: Using a vector of zeros, with a one in each location that is represented in the data, to encode a list of integers."
]
},
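{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of that encoding, using made-up indices rather than labels from this dataset:\n",
"\n",
"```python\n",
"idxs = [2, 4]        # indices of the categories present in an image\n",
"t = torch.zeros(20)  # one position per item in the vocab\n",
"t[idxs] = 1.         # a one in each position that is present\n",
"t\n",
"```"
]
},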
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check what the categories represent for this example (we are using the convenient `torch.where` function, which tells us all of the indices where our condition is true or false):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) ['dog']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = torch.where(dsets.train[0][1]==1.)[0]\n",
"dsets.train.vocab[idxs]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With NumPy arrays, PyTorch tensors, and fastai's `L` class, we can index directly using a list or vector, which makes a lot of code (such as this example) much clearer and more concise.\n",
"\n",
"We have ignored the column `is_valid` up until now, which means that `DataBlock` has been using a random split by default. To explicitly choose the elements of our validation set, we need to write a function and pass it to `splitter` (or use one of fastai's predefined functions or classes). It will take the items (here our whole DataFrame) and must return two (or more) lists of integers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PILImage mode=RGB size=500x333,\n",
" TensorMultiCategory([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def splitter(df):\n",
" train = df.index[~df['is_valid']].tolist()\n",
" valid = df.index[df['is_valid']].tolist()\n",
" return train,valid\n",
"\n",
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" splitter=splitter,\n",
" get_x=get_x, \n",
" get_y=get_y)\n",
"\n",
"dsets = dblock.datasets(df)\n",
"dsets.train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we have discussed, a `DataLoader` collates the items from a `Dataset` into a mini-batch. This is a tuple of tensors, where each tensor simply stacks the items from that location in the `Dataset` item.\n",
"\n",
"Now that we have confirmed that the individual items look okay, there's one more step we need before we can create our `DataLoaders`: we have to ensure that every item is of the same size. To do this, we can use `RandomResizedCrop`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),\n",
" splitter=splitter,\n",
" get_x=get_x, \n",
" get_y=get_y,\n",
" item_tfms = RandomResizedCrop(128, min_scale=0.35))\n",
"dls = dblock.dataloaders(df)"
]
},
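{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the collation step does concretely, here is a minimal sketch using random tensors of the shapes we just created (not fastai's actual code): four `(x,y)` items are stacked into a single tuple of batched tensors:\n",
"\n",
"```python\n",
"items = [(torch.randn(3, 128, 128), torch.zeros(20)) for _ in range(4)]\n",
"xb = torch.stack([x for x, _ in items])  # batched images:  [4, 3, 128, 128]\n",
"yb = torch.stack([y for _, y in items])  # batched targets: [4, 20]\n",
"xb.shape, yb.shape\n",
"```"
]
},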
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we can display a sample of our data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(nrows=1, ncols=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember that if anything goes wrong when you create your `DataLoaders` from your `DataBlock`, or if you want to view exactly what happens with your `DataBlock`, you can use the `summary` method we presented in the last chapter."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our data is now ready for training a model. As we will see, nothing is going to change when we create our `Learner`, but behind the scenes, the fastai library will pick a new loss function for us: binary cross-entropy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary Cross-Entropy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll create our `Learner`. We saw in <<chapter_mnist_basics>> that a `Learner` object contains four main things: the model, a `DataLoaders` object, an `Optimizer`, and the loss function to use. We already have our `DataLoaders`, we can leverage fastai's `resnet` models (which we'll learn how to create from scratch later), and we know how to create an `SGD` optimizer. So let's focus on ensuring we have a suitable loss function. To do this, let's use `cnn_learner` to create a `Learner`, so we can look at its activations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also saw that the model in a `Learner` is generally an object of a class inheriting from `nn.Module`, and that we can call it using parentheses and it will return the activations of a model. You should pass it your independent variable, as a mini-batch. We can try it out by grabbing a mini-batch from our `DataLoader` and then passing it to the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 20])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = to_cpu(dls.train.one_batch())\n",
"activs = learn.model(x)\n",
"activs.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Think about why `activs` has this shape—we have a batch size of 64, and we need to calculate the probability of each of 20 categories. Here's what one of those activations looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorImage([ 0.7476, -1.1988, 4.5421, -1.5915, -0.6749, 0.0343, -2.4930, -0.8330, -0.3817, -1.4876, -0.1683, 2.1547, -3.4151, -1.1743, 0.1530, -1.6801, -2.3067, 0.7063, -1.3358, -0.3715],\n",
" grad_fn=<AliasBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"activs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> note: Getting Model Activations: Knowing how to manually get a mini-batch and pass it into a model, and look at the activations and loss, is really important for debugging your model. It is also very helpful for learning, so that you can see exactly what is going on."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They aren't yet scaled to between 0 and 1, but we learned how to do that in <<chapter_mnist_basics>>, using the `sigmoid` function. We also saw how to calculate a loss based on this—this is our loss function from <<chapter_mnist_basics>>, with the addition of `log` as discussed in the last chapter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(inputs, targets):\n",
" inputs = inputs.sigmoid()\n",
" return -torch.where(targets==1, 1-inputs, inputs).log().mean()"
2020-02-28 19:44:06 +00:00
]
},
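{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (with made-up activations, not values from our model), confident correct predictions should give a loss near zero, and confident wrong ones a large loss:\n",
"\n",
"```python\n",
"targs = tensor([1., 0.])\n",
"binary_cross_entropy(tensor([ 10., -10.]), targs)  # right and confident: ~0.0\n",
"binary_cross_entropy(tensor([-10.,  10.]), targs)  # wrong and confident: ~10.0\n",
"```"
]
},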
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that because we have a one-hot-encoded dependent variable, we can't directly use `nll_loss` or `softmax` (and therefore we can't use `cross_entropy`):\n",
"\n",
"- `softmax`, as we saw, requires that all predictions sum to 1, and tends to push one activation to be much larger than the others (due to the use of `exp`); however, we may well have multiple objects that we're confident appear in an image, so restricting the maximum sum of activations to 1 is not a good idea. By the same reasoning, we may want the sum to be *less* than 1, if we don't think *any* of the categories appear in an image.\n",
"- `nll_loss`, as we saw, returns the value of just one activation: the single activation corresponding with the single label for an item. This doesn't make sense when we have multiple labels.\n",
"\n",
"On the other hand, the `binary_cross_entropy` function, which is just `mnist_loss` along with `log`, provides just what we need, thanks to the magic of PyTorch's elementwise operations. Each activation will be compared to each target for each column, so we don't have to do anything to make this function work for multiple columns."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> j: One of the things I really like about working with libraries like PyTorch, with broadcasting and elementwise operations, is that quite frequently I find I can write code that works equally well for a single item or a batch of items, without changes. `binary_cross_entropy` is a great example of this. By using these operations, we don't have to write loops ourselves, and can rely on PyTorch to do the looping we need as appropriate for the rank of the tensors we're working with."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch already provides this function for us. In fact, it provides a number of versions, with rather confusing names!\n",
"\n",
"`F.binary_cross_entropy` and its module equivalent `nn.BCELoss` calculate cross-entropy on a one-hot-encoded target, but do not include the initial `sigmoid`. Normally for one-hot-encoded targets you'll want `F.binary_cross_entropy_with_logits` (or `nn.BCEWithLogitsLoss`), which do both sigmoid and binary cross-entropy in a single function, as in the preceding example.\n",
"\n",
"The equivalent for single-label datasets (like MNIST or the Pet dataset), where the target is encoded as a single integer, is `F.nll_loss` or `nn.NLLLoss` for the version without the initial softmax, and `F.cross_entropy` or `nn.CrossEntropyLoss` for the version with the initial softmax.\n",
"\n",
"Since we have a one-hot-encoded target, we will use `BCEWithLogitsLoss`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorImage(1.0342, grad_fn=<AliasBackward>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_func = nn.BCEWithLogitsLoss()\n",
"loss = loss_func(activs, y)\n",
"loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We don't actually need to tell fastai to use this loss function (although we can if we want) since it will be automatically chosen for us. fastai knows that the `DataLoaders` has multiple category labels, so it will use `nn.BCEWithLogitsLoss` by default.\n",
"\n",
"One change compared to the last chapter is the metric we use: because this is a multi-label problem, we can't use the accuracy function. Why is that? Well, accuracy was comparing our outputs to our targets like so:\n",
"\n",
"```python\n",
"def accuracy(inp, targ, axis=-1):\n",
" \"Compute accuracy with `targ` when `pred` is bs * n_classes\"\n",
" pred = inp.argmax(dim=axis)\n",
" return (pred == targ).float().mean()\n",
"```\n",
"\n",
"The class predicted was the one with the highest activation (this is what `argmax` does). Here it doesn't work because we could have more than one prediction on a single image. After applying the sigmoid to our activations (to make them between 0 and 1), we need to decide which ones are 0s and which ones are 1s by picking a *threshold*. Each value above the threshold will be considered as a 1, and each value lower than the threshold will be considered a 0:\n",
"\n",
"```python\n",
"def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):\n",
" \"Compute accuracy when `inp` and `targ` are the same size.\"\n",
" if sigmoid: inp = inp.sigmoid()\n",
" return ((inp>thresh)==targ.bool()).float().mean()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we pass `accuracy_multi` directly as a metric, it will use the default value for `thresh`, which is 0.5. We might want to adjust that default and create a new version of `accuracy_multi` that has a different default. To help with this, there is a function in Python called `partial`. It allows us to *bind* a function with some arguments or keyword arguments, making a new version of that function that, whenever it is called, always includes those arguments. For instance, here is a simple function taking two arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Hello Jeremy.', 'Ahoy! Jeremy.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def say_hello(name, say_what=\"Hello\"): return f\"{say_what} {name}.\"\n",
"say_hello('Jeremy'),say_hello('Jeremy', 'Ahoy!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can switch to a French version of that function by using `partial`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('Bonjour Jeremy.', 'Bonjour Sylvain.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = partial(say_hello, say_what=\"Bonjour\")\n",
"f(\"Jeremy\"),f(\"Sylvain\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now train our model. Let's try setting the accuracy threshold to 0.2 for our metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.942663</td>\n",
" <td>0.703737</td>\n",
" <td>0.233307</td>\n",
 <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.821548</td>\n",
" <td>0.550827</td>\n",
" <td>0.295319</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.604189</td>\n",
" <td>0.202585</td>\n",
" <td>0.816474</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.359258</td>\n",
" <td>0.123299</td>\n",
" <td>0.944283</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.135746</td>\n",
" <td>0.123404</td>\n",
" <td>0.944442</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.118443</td>\n",
" <td>0.107534</td>\n",
" <td>0.951255</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.098525</td>\n",
" <td>0.104778</td>\n",
" <td>0.951554</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))\n",
"learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Picking a threshold is important. If you pick a threshold that's too low, you'll often be selecting labels that aren't actually present. We can see this by changing our metric, and then calling `validate`, which returns the validation loss and metrics:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10477833449840546,0.9314740300178528]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.1)\n",
"learn.validate()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you pick a threshold that's too high, you'll only be selecting the objects for which your model is very confident:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(#2) [0.10477833449840546,0.9429482221603394]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.metrics = partial(accuracy_multi, thresh=0.99)\n",
"learn.validate()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can find the best threshold by trying a few levels and seeing what works best. This is much faster if we just grab the predictions once:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"preds,targs = learn.get_preds()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can call the metric directly. Note that by default `get_preds` applies the output activation function (sigmoid, in this case) for us, so we'll need to tell `accuracy_multi` to not apply it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorImage(0.9567)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_multi(preds, targs, thresh=0.9, sigmoid=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now use this approach to find the best threshold level:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"xs = torch.linspace(0.05,0.95,29)\n",
"accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]\n",
"plt.plot(xs,accs);"
]
},
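{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want a single number rather than reading it off the plot, you could pick the threshold with the highest accuracy (a hypothetical follow-up, not something we rely on later):\n",
"\n",
"```python\n",
"best_thresh = xs[torch.stack(accs).argmax()]\n",
"best_thresh\n",
"```"
]
},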
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we're using the validation set to pick a hyperparameter (the threshold), which is the purpose of the validation set. Sometimes students have expressed their concern that we might be *overfitting* to the validation set, since we're trying lots of values to see which is the best. However, as you see in the plot, changing the threshold in this case results in a smooth curve, so we're clearly not picking some inappropriate outlier. This is a good example of where you have to be careful of the difference between theory (don't try lots of hyperparameter values or you might overfit the validation set) versus practice (if the relationship is smooth, then it's fine to do this).\n",
"\n",
"This concludes the part of this chapter dedicated to multi-label classification. Next, we'll take a look at a regression problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's easy to think of deep learning models as being classified into domains, like *computer vision*, *NLP*, and so forth. And indeed, that's how fastai classifies its applications—largely because that's how most people are used to thinking of things.\n",
"\n",
"But really, that's hiding a more interesting and deeper perspective. A model is defined by its independent and dependent variables, along with its loss function. That means that there's really a far wider array of models than just the simple domain-based split. Perhaps we have an independent variable that's an image, and a dependent that's text (e.g., generating a caption from an image); or perhaps we have an independent variable that's text and dependent that's an image (e.g., generating an image from a caption—which is actually possible for deep learning to do!); or perhaps we've got images, texts, and tabular data as independent variables, and we're trying to predict product purchases... the possibilities really are endless.\n",
"\n",
"To be able to move beyond fixed applications, to crafting your own novel solutions to novel problems, it helps to really understand the data block API (and maybe also the mid-tier API, which we'll see later in the book). As an example, let's consider the problem of *image regression*. This refers to learning from a dataset where the independent variable is an image, and the dependent variable is one or more floats. Often we see people treat image regression as a whole separate application—but as you'll see here, we can treat it as just another CNN on top of the data block API.\n",
"\n",
"We're going to jump straight to a somewhat tricky variant of image regression, because we know you're ready for it! We're going to do a key point model. A *key point* refers to a specific location represented in an image—in this case, we'll use images of people and we'll be looking for the center of the person's face in each image. That means we'll actually be predicting *two* values for each image: the row and column of the face center."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Assemble the Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the [Biwi Kinect Head Pose dataset](https://icu.ee.ethz.ch/research/datsets.html) for this section. We'll begin by downloading the dataset as usual:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.BIWI_HEAD_POSE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what we've got!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#50) [Path('01'),Path('01.obj'),Path('02'),Path('02.obj'),Path('03'),Path('03.obj'),Path('04'),Path('04.obj'),Path('05'),Path('05.obj')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls().sorted()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are 24 directories numbered from 01 to 24 (they correspond to the different people photographed), and a corresponding *.obj* file for each (we won't need them here). Let's take a look inside one of these directories:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1000) [Path('01/depth.cal'),Path('01/frame_00003_pose.txt'),Path('01/frame_00003_rgb.jpg'),Path('01/frame_00004_pose.txt'),Path('01/frame_00004_rgb.jpg'),Path('01/frame_00005_pose.txt'),Path('01/frame_00005_rgb.jpg'),Path('01/frame_00006_pose.txt'),Path('01/frame_00006_rgb.jpg'),Path('01/frame_00007_pose.txt')...]"
]
},
2020-12-26 17:25:42 +00:00
"execution_count": null,
2020-02-28 19:44:06 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(path/'01').ls().sorted()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inside the subdirectories, we have different frames, each of which comes with an image (*\\_rgb.jpg*) and a pose file (*\\_pose.txt*). We can easily get all the image files recursively with `get_image_files`, then write a function that converts an image filename to its associated pose file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Path('13/frame_00349_pose.txt')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img_files = get_image_files(path)\n",
"def img2pose(x): return Path(f'{str(x)[:-7]}pose.txt')\n",
"img2pose(img_files[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at our first image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(480, 640)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im = PILImage.create(img_files[0])\n",
"im.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAKAAAAB4CAIAAAD6wG44AABe9UlEQVR4nM39aaxl2XUmiH1r7X3OvfdNMeYUOWcymcwkk6Q4SKJUsqzS2C6pqlpuw3C7baAHN+A//tNA223YBRsFN2DDcBvucrddqEYZ1XbJrRooUUOVVJIocSZFMpOZJHOeIjIjMmN48eIN995z9l7r84+9z7n3RUQmSamq2geBiBv3nmGfvfaavjVs+Tt/5//+8Y9/nCQAFwJQQgCHYDhWnwCAa99z/C9lPFfEV1dQiJsOAhDc+v2tRzmHEIBlGFIfKxCSJB0wM6fnnNzcPbibWWdGd3d3M6fT3Zzm7kY4Sc9Og4EA3Qk6xWkEQYcTQJ0TIyAgOfwLkKSwniOQ9e9JA+gE6YJ6DgkCZALgJoQDYHlUhjOZ5Vnb7JzcdC93EsLK7cpAyDqe8kw6SQF8OAdexgzSHXASgJ4/fyH+2Cc+/lOf+akfPNf/bR8EHHDCjTRYtmxm2VL2vk8555QTmXO2nLIzWxYzz5bdDOZkds9mZurZMplJwtWNpJs5XACS5p7pBEHSzFI6DDGCNHdhY+6guzs9DxQlYfTxM0m6F0p3QP0NFCCSdEM2AHRPZtnEYOjMTp254yNPPdb3ybKXhezOYZkVApdlQVJAIYW0gboOEITTh6sEEFJiWU0jh5RP5T8ysOcaP4JeliWcMObxtRKUdApBJ6RwNMmYDYS7uxOidLhDIDQxJ0kz9jR3p9MGhjM6zLN15ubOnLN55+65nJFp5maWU7Ls7uaFSszmvZsTQiZzKywsZqQT7jRHItw9Ox0mJJwEPdIAkg7x7EXIwN1p68sslylxUlg/ExABK4UB1pkkCWiZM5DuBmhhcTdxN2N2T+YJCYvOZie2CqncnRRAK2nLdA6cSvpAXSkPLc9yL+Qv/xb5AoHGN14/PHvm3a7LZm5Myc3K+nMj6x1inwUghKSb5Cr+6Jnu7ubm5sxkdiS3npqHNwRZ3x4kEMgq9IbVRwLueXw39+S0YTEaCALuJjKIOhI0rljGyqJjFWW9u4OxTispoiiLHS5CKQzmJgIpIg6kU9kCEGHKfdf3sYkxNnSnUUREim4olIOTIlAJlQfEKyOQFAepKiKKgT0IUVFhJEihqAtUfeLSaFgGELB2GoEgDCqhTHXlL2G5/yCiV0cZVfm+fAZUle6VuiIav/H1L1+79o57gjA4ASur2K1cQJLwOnt1aVTNSEGRZUIQCIXoAIP0JJ1F1Ii7ZzM61QNAIkEc1HFwFigiObtAGzik6KpAiYUfBFImVoSgDwzjgEIFCjoBF9EilwRJCAFERRSqIztBpYUSwVQBgQi0nIQgAg1ilmbTuFgsQzsTBIEIEEIAEJBFVCQItQxKRSCCJgMiiICUp0gluYKiKhABlTAgE+amdKeTxgSDJ/O8c/Ys3LUoa4iNclMqLVWVdFKreBWwSE3AARQCi8MpUqSIExavXX5t1vaq0vddEyey4ieoSBHnWtZv/VtFVABVFU0iQaRRBJVUlq0IIBQiqpTXh6q2ASMfyBQQgapqua2Lxja65+VysdFOIBQIoMTaVWoiI0niSjNrEIFIIXAodyYhamW0w3SrSBCoKEWAQTENs1+eI0UexMCuT7GZaWhk+AmAM6gSMIC6prhUfZCK4gPbCTwn0r1pGlE1B2kQc5pl0IliQzhgKZm1zZbWB61z5E2MW36vBqqIQBWkuANlDAIoxKVYuyJxZ/PkHafumkymi8WinVBFNMS6PBWqQUQCJARFfdOoKgBURYQiUthE2K+GJSLlFjIcZbJDFSkCoaxEX59T27TzxSKEMGlUECEi4pWDKyXIeoVwWGhOEgqwrKrlso8hNs0UIGHjtXVFU0kh+mrqQwb7AygaT7hYHIlAhRFNbKaxmWhREiAAkwAUgenKNQJD6C4DI1Wh5wgBGqKoEBAFWNSqQhwwoxjhZrnrl31uN0RER2KO4rdMI+swIIBXVSEcXBgRkAZoWcdgz5FaO2cmJ87OptON2bLZmjZBQwgNoCKF2KqiqlJEhKCQG6qqoiJRlaJ2Y383xFNnzp6tukeDig5klSqkBFEaVYpS1FVFVFTV3V94/pmNjenZ02cvX979wGMPbm+fBgjJkLi2flUg7kKHax6MEZqjeAvm6dLFdzc2Nra2dsyKV8SVn0QjM0GypddFz2pYgXRxh5AQDUHAEBhiKxoFKwclSFkcSoZ1Di40YzHBpKwbhUJoZZGD5XsQAQQkQwlxJ9x78y6bcfBMpazWgXLlc1VTGKRDIf+ayzqYxJVxx/vERx96YGMSP/zkh15/48177j93//33QyRoDBpFNMYoqkGziIZQaCJlQYtAEEUMyNd2r8Rw4tTp0ytndW1AP/Bw950zp7rDyel7u0cfvifEeLuz6OZOmlk25mwpJZK5GNbZsuWtHTt16iSAnK0cxVwsDk/xiulWDS11upLVy3amQIYQtIkKMeQQY1At0r36vjLDwFZr/4JwgMIAAZEDBQxOepEUIYAIVCl+jhBiDnFVDao5KCLgoupQhxCEQgr5WARNU59TbHQpbi7BIigJSLGgpFJdQQABaOJDDz3+kY88dfbs2XPnHjx59mTTToZ76Q+mTLW28p13PgSEkdsGI+yHoy55x50Pbp7c3OPy1LnTIUbeYi6WQ1TFXVWVEoIUl1NEzVREIRJj0zTFGDZVUzVWtMNDCO5uZqBx+BZUFl+IDBAlVWIIjULMEENUDTIAMnUp1I9c/QMoQIpoJQqEpJIuEqRacJUFy+8qg36SYboEEoIUAVseSUH1k24zE1UTjjK8LIvx1EJ3UYHE2EzvuvschCdPn2radm1yVzdmNWNvD0ARkSyo1l/4iEaRoIM39Z5HmRVVIUOMxcGvh7urxhDKYldAB7jA3M2s2Fygl2VRVku5qvCxBneVELQRIChDaKSIV9anFMdnQB44EAjFDJUKXDVF64s4EFRVNRSlPJ5eXqQskGpWFWACQaDEwKiig7O0mpOV8YXqS0ldCWVabloTjBpUo7ibhML4txOsUh2rmw4KKqP/ZYgLUAqyoJTw/mcW+hazeSB2QQ8YQrEXQmGtECjSuFe6lqtJocDNVcUpChaTrfjp5ZwYI81FVDUAonXVce0+dSQrZiiojrPyqUDEzaiqg+EDiIkOdpgUHlZZHSpSMOIiOEfciWuA062zUWjKm08oRq4IgCjF05NQnJbb34i3V6cymJd/mUMAChowCCKcEt7/lqripGqZdC/+1sCUGoIWCayqpJUv3alagCmnA0HdrfpiBZkTFyeVDKIqQngIIQRzoTgoFc7gIBkxuOOFxsU7MIdAq4hTkUhmVagKnRQDHUKgkBYqFAkm4ioQBC1GX1kCuqYBKGoFngSEUA641sD9PvizVUi4K8uyEsSVxSW3p+K/+kOa5rZW1XtfIBgpV9bvyAhF1xaSlyXrXtixak2oirtIUDiFpLh78XQAhtiIKMRVJZuF0ErBF91XtF0Tg9WZKaPSanoN54hI5eCiIkWErA7rLe9EESk+uqxuM96/e
FoVR1uJ6fHi0fqpiGb5Xaoz+yNN7r/0QwRt0xQlxNWy/UEXiQAIQUII1Y0TGZGT8Zvx+3K+qpbPWkGYY2eJBNWYkgmQc54fHakq1lyO8SjCcV3mrYY9OocixSstuDGKVoFiUNgVJijMVT7SCS9+9rE7UyFKCr1ch5ueO46kjhMqCIIAqILwYoD8xWn0lzwkpdR3nYh4CQr8wAtWi3I1m+s0vunLW7+RgcgDbVU1AkJn6lKZfg2hQDrD5Vi/2y0jud04a9Cl/FEUmFMCbjZVy4rRwZ6q0cXjN5eCLWPtdW73zNGur79GkG4eBtTph3Vu/mUehINQh9EixdFqMdffcyjig/UhsjZXIqIaSJj5SI91DV0CNZWboU4C1PLWQqUUKNihDq34JQIqfFjt1io2JaCKRAddBpur4ABVwKoQHjR6DRhK9YgkYAgKCYKIClS14rIcvFvSi0FOUZJUBwFb0wrQ6h2xqgGSWmP5Thgkx+I/1LHevGr+NR3FxXR3g1NLiO1HkyhF9ha7A1VMKmByTGD
"text/plain": [
"<PIL.Image.Image image mode=RGB size=160x120 at 0x7F2DF0A49690>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"im.to_thumb(160)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Biwi dataset website used to explain the format of the pose text file associated with each image; it records, among other things, the location of the center of the head. The details aren't important for our purposes, so we'll just show the function we use to extract the head center point:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)\n",
"def get_ctr(f):\n",
" ctr = np.genfromtxt(img2pose(f), skip_header=3)\n",
" c1 = ctr[0] * cal[0][0]/ctr[2] + cal[0][2]\n",
" c2 = ctr[1] * cal[1][1]/ctr[2] + cal[1][2]\n",
" return tensor([c1,c2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function returns the coordinates as a tensor of two items:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([384.6370, 259.4787])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_ctr(img_files[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can pass this function to `DataBlock` as `get_y`, since it is responsible for labeling each item. We'll resize the images to half their input size, just to speed up training a bit.\n",
"\n",
"One important point to note is that we should not just use a random splitter. The reason for this is that the same people appear in multiple images in this dataset, but we want to ensure that our model can generalize to people that it hasn't seen yet. Each folder in the dataset contains the images for one person. Therefore, we can create a splitter function that returns true for just one person, resulting in a validation set containing just that person's images.\n",
"\n",
"The only other difference from the previous data block examples is that the second block is a `PointBlock`. This is necessary so that fastai knows that the labels represent coordinates; that way, it knows that when doing data augmentation, it should do the same augmentation to these coordinates as it does to the images:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"biwi = DataBlock(\n",
" blocks=(ImageBlock, PointBlock),\n",
" get_items=get_image_files,\n",
" get_y=get_ctr,\n",
" splitter=FuncSplitter(lambda o: o.parent.name=='13'),\n",
" batch_tfms=[*aug_transforms(size=(240,320)), \n",
" Normalize.from_stats(*imagenet_stats)]\n",
")"
]
},
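{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the `DataLoaders`, we can sanity-check the splitter. Here's a quick sketch (assuming the `items` attribute of each split holds the original file paths) asserting that every validation item comes from person 13:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the (unbatched) datasets and verify the person-based split\n",
"dsets = biwi.datasets(path)\n",
"assert all(o.parent.name=='13' for o in dsets.valid.items)\n",
"len(dsets.train.items), len(dsets.valid.items)"
]
},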
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> important: Points and Data Augmentation: We're not aware of other libraries (except for fastai) that automatically and correctly apply data augmentation to coordinates. So, if you're working with another library, you may need to disable data augmentation for these kinds of problems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before doing any modeling, we should look at our data to confirm it seems okay:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAckAAAFUCAYAAABPx8fsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9V7AlSZrfif1cRMQRV2Xe1JVZuqpFTevRMxjRmCHkEiBhAIYAbbkG7i7JJcGHfaQZjEYjH/iyRpotaUY+0IhdcjmwxXJBAAR2MArTg0HPTI/q7unualVdMrW48pwTwt0/Prh7RJyborq7qitnaelpN8+958QJ4f75J/6fUiLCk/FkPBlPxpPxZDwZ9w/9uG/gyXgynown48l4Mv6sjidC8sl4Mp6MJ+PJeDIeMp4IySfjyXgynown48l4yHgiJJ+MJ+PJeDKejCfjIeOJkHwynown48l4Mp6Mh4wnQvLJeDKejCfjyXgyHjLsoz48OjoWCZoQBO893/72twB48cWX0FqhtY6vRrFarZjNZnjv0cpQlBZjNEo9+Nz3J54IMnpXiwICwTe889brvP36W2zvnOHw6JDDg3ssF0fUdU3XtXSdp20dTd2yWtUcLxa03QoIWDvFlhPKyQRtC6ydMJud5vz5S1y4cIGz586ztbPJbFZRTRTWEJ8JjeBxXUvXeL7+9a9z594d6rrlwvmL+NBx8eIFUIG68RwvliyOl+zt73Pv3h0O7+2xONynWe3RtktcFwBPURZszLfY2NxiOptRTQq0NQQXsEYjAkUxRRuLAlScDFQ/S8Q5FYWIAhUgKFBCoCP4gDEGROE9iAjBd7iupWlr6uUxtfc8e+U5RMALSIjHiaj4GoQgwn/wH/39h6zeD3a89u1viy0sRVlSliVFUWCtxjnHcrVitaxZLVfUdU1dL6nrBXXT0DYNTdPRNC3OOQpbUVYTqqqgrErKScVkMmEymTCfTJlOpsxmM2azKQDL1YrlYslqVbNarVitaurViqZp4k/b0LYNnXNYW1CVFWVlKYqCsppSFiVFWWC0wZiCophgClBacB20raPrOpxzdF2L61qc83Suw3UO7z3OeUJwtK5DJO47nCMETxDBB4MED8ETQkcIgRCE4AOd63o6CV5wzhFCQKQleEfwQhBPEIeIJ4SAl4AESevv4u8BRDwSXKSfdJz3If7uA77z6LLiJ378M2ilcE4STcY9LCIEht972hJBxEMQRHQ6zvefEYT/7f/+P/nA6e6X/4v/XIy2GFPQtQ3WahRxv0kACSHOjwjegxcHAkoE8ULnPCFA1wV8CATicwoBkUCQgFIaFCjifChlEAFtFEobtIr8VAHKGLQ2KKUxRtCmQGuLVhowEDRKKYy2KG0wVqOtQilNURRxzxQWYw3GGIqiQBuDURatDMbEexEVEAJag8agtSKEAAhKgVYqPgOBxeKQN15/nTfeeJOdnW1efPFFTp3a5s033uSbr72G1vD80y9x7vw5Vsslt2/foalbNrc2mU0rvOto2462aXC+xbuGrmvwweF9G2nc+UhfzuG9wzmHdx3ee7xbIcFz+YUP86lPfpq2bREh0lZQCB5E4n5IPC0I8ZgQooSRQPCBo6MFEmBra442mr/+N37poTT3SCEZbyARL0JdN2gdjU+lVPqJxF4WE7Q2SBCMMaiHScc01pi+xHfU6ENRoNBoM+PS5Ze48vRLKF0QCDjX0q5qjo8OONzfY3/vNnv37rJcHFM3HXViRk3bUi866tWS+uiAzrcgnn1tuPF2yavVnHKyQTXZYmPzFDunznB6d5fd3dNsbk2ZziaUlaawmpc/+mFecI6u8xwfHXP9+lW+/vVXEfGsljWXr1xhWhl2Ll/i4tmzHC2XHB8vONrf5/DgNkeH96gXxzTNgr29Pfb39rDFhGpasbExYT6bUU503AxKE4JDaY1CQxC0UqAUSsV1QXScMQEhQM+AQmI+Ic5j/75EoRhCnFelccEjooHhcxF5gALzwQ5jDEYbtDIodLrfgtl0ymy6CafjJo5DoqKQ/lSi+k+UkkRoiaB4ME16FZ95a1aytbuTThtZpE6nFhGcgPM+CrMOnDOD8lE7urajaRuaxtEuOur6iKauaboVXdfRtQ3O1zjncK7Ddy0uMwLf4n2XhFBkEBIcPnRI5/u1DSKE4CB4RKLwjMIJgndp3w0CSwCC6xk2gIREN6IQJDLwPF8CiUf2SmsUxMOPl4B4QY8IRecpI3/34SMuiQLR8TgR4qpJf48f9Dh7bpOuFU6fOs/du3ucP38GrXV6ZoZ5F0UICo8jeI94jwrxWUJI+1AJEBDxON8lxh1QKgmgpHBIIJ07KkgheDzxO851eN8gXtM5F4/zcS2sNtR1jSnUaI9rEBV5hlIoFa8XeYghWytKCVprBBX3ltZoDcootEr8RukooJOQ1joaDsqAVfDyC8+htGF50FAf3WZip3zsox9HKTCqZHG0pG0jnYsI2iiqScFy2XL95g1u377DxYsXuHLlabRW3Lh+nes3bjCbzTh3cZfJxLJYrFgsjnGuQ6s4f6tVTdfsYWyVlC6V+F0AdL8Pop4WKTA+tQw8Ms3/bDJJAtRHY+MR49FCMkAIkrQooSgs8/nGcFPp8gSh6xxt1zKpbLo/C8q8K3Gq/r/xLzL8rwRdGFCJCJShKCeU5YT51g7nn3oGxONcS71acLS/x/7eXQ7273F4cMDR8ZKmaaK12TjqZsVqtcJ3Hft7d1gcHSByHWU0tigoyinlpKKsNphvbLK1vcPO6R22N7fZ3NxgMplQFAWXLj3NxUuXCKHl8PCY5fGC69ev07Y1GxsbIHB2d4fTOxV1e4rloub4aMHx0QFHh3dZHh/SNg2r5ZLjoyOMUpRVYDbbZmNrk+l0SlVNMNailSIom7RPicIyCQdFgMRi0h5NDJBkRUjP8CNBCBA3Ui9mZG0RHvu48vQlQCNq8AYoUT3nDQwCINKn7p+3CVGQRKYTBYDzDnGe4CJzci5A6/Hp785B2wltm628Dtd5OudYJS3WOYdvWrqmoXMdXVfj/Arn2ijgQof3kXH6kCw/Hy08xBOkIwTfr0EIAXoLPik3Sdul37wdIh4zstBCpgEJONcR0FhbxI0uiTECOjFLAKWiwqF1Zp7pPa0TU41MkMw6tY1MVaUJT/cYQrpPBDxQ2HjODG+kV1EKgqDwvcUUvxCVlSgSszCJ1+xrmqjH4wH6kR/+GZxzTKczFosVm9tbWGt7AyHyPAZLmWEdg/O9NR9CiIpU8ATfslwec3i45PSpMyilonXvGyT4xF+jJe99tKCCJEs0hHiO0CJBEYLHeUfnAkprVssFVVklpWqg5bxG3kd69CFEFCFE5U4EfJe+430U9IRIUyHzgLg+USnXgI/GUFqrKHAVWmfkwhKwaA1GmfSZTt+33LlxwN2bYIxmY2rZfOYplFLcu3MDrRRGC+fP7VIUtkfTmqblzp19bGG5dOkC29vb3Lm7z5uveV5//Tqbs23Onz+H84579+6xWtZsbGyytbXZKyPZIs7rNqylQ
wgsFzWbWxPkXfSydxGSqtdeRWBvb5/Ll6+kzxKRx7nEe4+x8eZE6eGG0jR+byw4bnQRicxcotWa8Y9spKq4O0EZimJKUUzZ3N7l0jMvICFCpavlgqOjAw727nGwt8f1G1c5ODpClGLn3GkInrZtWS1a6rqhWe6xOoyMWBeawhqsLSjKkmKyxWy2xXzjNJubm8zmE+azgqK0FJM5z77wIkbAuZYbN65z+9Ytjg6PUUqxdWqbU6dn7J6d0zTnWC5XLI6POdq/x+L4MFocqwXNcp+9O3cxRcVsNmNja87GbMpkMsPYCJmAQlRIsxp6os1EEUIYCcBsgXSExLx1USbGl2loXTl5tB3wgx9f/toN2q6jdQ1t29K1HaEVnAt0XYd3jq5zyTrzeBehxc61BBc/d84TfBehRlqca1CoxChCFFgZgknKRJAIiQffke2ogRFGK80nphYFsEqaqaAlwvNCFxlqEuJxehWobPm7uG9UgRGVIK0Es+nIlJRSiClRahqZCEmTVxptdDxOC8ZEKCwiLgW
"text/plain": [
"<Figure size 576x432 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dls = biwi.dataloaders(path)\n",
"dls.show_batch(max_n=9, figsize=(8,6))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's looking good! As well as looking at the batch visually, it's a good idea to also look at the underlying tensors (especially as a student; it will help clarify your understanding of what your model is really seeing):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3, 240, 320]), torch.Size([64, 1, 2]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xb,yb = dls.one_batch()\n",
"xb.shape,yb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make sure that you understand *why* these are the shapes for our mini-batches: a batch of 64 images, each with 3 channels at the 240×320 size we requested; and for each image, a single point described by 2 coordinates."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's an example of one row from the dependent variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorPoint([[-0.3375, 0.2193]], device='cuda:6')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yb[0]"
]
},
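{
"cell_type": "markdown",
"metadata": {},
"source": [
"The coordinates are rescaled to `(-1,1)` relative to the image size. As a sketch (assuming the point is stored as `(x, y)` against our 320-pixel-wide, 240-pixel-tall inputs), we can undo that scaling to recover pixel coordinates:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Invert fastai's point scaling: pixel = (scaled + 1) / 2 * size\n",
"(yb[0].cpu() + 1) / 2 * tensor([320., 240.])"
]
},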
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, we haven't had to use a separate *image regression* application; all we've had to do is label the data, and tell fastai what kinds of data the independent and dependent variables represent."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's the same for creating our `Learner`. We will use the same function as before, with one new parameter, and we will be ready to train our model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As usual, we can use `cnn_learner` to create our `Learner`. Remember way back in <<chapter_intro>> how we used `y_range` to tell fastai the range of our targets? We'll do the same here (coordinates in fastai and PyTorch are always rescaled between -1 and +1):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = cnn_learner(dls, resnet18, y_range=(-1,1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`y_range` is implemented in fastai using `sigmoid_range`, which is defined as:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid_range(x, lo, hi): return torch.sigmoid(x) * (hi-lo) + lo"
]
},
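{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check (a minimal sketch) makes the behavior concrete: zero maps to the midpoint of the range, and even extreme activations are squashed to values just inside it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Large-magnitude activations saturate near the lo/hi bounds\n",
"sigmoid_range(tensor([-10., 0., 10.]), lo=-1, hi=1)"
]
},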
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is set as the final layer of the model, if `y_range` is defined. Take a moment to think about what this function does, and why it forces the model to output activations in the range `(lo,hi)`.\n",
"\n",
"Here's what it looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jhoward/anaconda3/lib/python3.7/site-packages/fastbook/__init__.py:55: UserWarning: Not providing a value for linspace's steps is deprecated and will throw a runtime error in a future release. This warning will appear only once per process. (Triggered internally at /pytorch/aten/src/ATen/native/RangeFactories.cpp:23.)\n",
" x = torch.linspace(min,max)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAD7CAYAAABwggP9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAoHklEQVR4nO3deXxV1bn/8c8DYQiEEAIhjGFGRpmCsxVr61yhYlsUEaRKRW21t7Xqrd62atX6a+3oRKvizNWKVetUZ4taMQwBQQjIEAYhA5CRJCR5fn8k8caYEAI72SfJ9/16nRc5a6+9eJKcc56svdbay9wdERGRmtqEHYCIiEQmJQgREamVEoSIiNRKCUJERGqlBCEiIrWKCjuAIPXo0cMHDhwYdhgiIs3KsmXLstw9oWZ5i0oQAwcOJCUlJewwRESaFTPbWlu5LjGJiEitAk0QZna1maWYWbGZLayn7o/NbJeZ5ZjZQ2bWodqxeDN7zswKzGyrmV0UZJwiIlK/oHsQO4HbgIcOVsnMzgBuAE4DBgKDgV9Vq3IPUAIkAjOB+8xsdMCxiojIQQSaINx9sbv/A8iup+ps4EF3X+Pue4FbgTkAZtYZmA7c7O757r4EeAGYFWSsIiJycGGNQYwGUqs9TwUSzaw7MBwoc/e0GsfVgxARaUJhJYgYIKfa86qvu9RyrOp4l9oaMrN5leMeKZmZmYEHKiLSWoWVIPKB2GrPq77Oq+VY1fG82hpy9wXunuzuyQkJX5nGKyIihymsdRBrgHHA05XPxwG73T3bzIqAKDMb5u4bqh1fE0KcIiIRpaS0nIy8InbnFpGRW1zxb14xPzhlCF2j2wX6fwWaIMwsqrLNtkBbM+sIlLp7aY2qjwILzewJ4HPgJmAhgLsXmNli4BYzuwwYD0wFTggyVhGRSJSz/wDb9hSyfe9+tu8tZMe+/Xy+r4idOfvZua+I7IJiam7j07aNMW1C38hOEFR80P+i2vOLgV+Z2UPAWmCUu6e7+6tmdhfwNhANPFvjvCupmCqbQcWMqPnurh6EiLQIeUUH+CyzgM1Z+WzOKmRzVgFbswtI31PIvsIDX6rbqX1b+sZF0zsumlG9Y+nVtSO9YjuS2LUjiV060jO2A/Gd2tOmjQUep7WkHeWSk5Ndt9oQkUiRX1zK+l25rN+VT9ruPDZk5LExI5/ducVf1Glj0LdbNAO7dyYpvhMDuneif7dO9I/vRL9u0XSNbodZ8B/+1ZnZMndPrlneou7FJCISlqz8Yj7ZkcOanbl8siOHtZ/nsjW78Ivjndq3ZVjPGE4amsDQnjEMSejM4IQY+sdH0yGqbYiR100JQkSkgYpLy/hkRy7Lt+5l5fZ9rEzfx459+784PrB7J0b3ieWCif0Y2TuWo3p1oW9cdKNcBmpMShAiIvXILy5l2da9fLQpm6Wb97BqRw4lpeUA9I2LZnxSHHNOGMjYfl0Z1SeW2I7BDhaHRQlCRKSGktJyVqTv5f2NWSzZmEXq9hzKyp22bYyxfbsy+/gBTBrQjYkDutGzS8eww200ShAiIsCunCLeWpfBO+sz+OCzbPKLS2ljcHS/OK44ZTDHDe7OxKRudO7Qej42W893KiJSjbuTtjuf19bs4o1Pd7Nqe8UdfvrGRXPe+D6cMjyB4wZ3D3xtQXOiBCEirYa78+nneby0eievrN7FpqwCzGBC/zh+duZRfGNkIsN6xjT6tNLmQglCRFq8bXsKeSF1J/9YsYMNGfm0bWMcNzieuScN4vTRiS16HOFIKEGISItUWFLKK6t38fdl2/lwU8UWNZMHduPWaWM4e0wvusd0qKcFUYIQkRZl7c5cnly6lX+s2El+cSkDunfiJ98czrcn9qVft05hh9esKEGISLN3oKycVz7ZxcL3N7M8fR/to9pw7tjezDgmickDu2lM4TApQYhIs7WvsIQnPkrn0Q+3sDu3mIHdO3HTOSO5YFI/4jq1Dzu8Zk8JQkSanR379vPgvzez6ON0CkvKOHlYD+44fyxThvdsdreziGRKECLSbGzNLuDetz/j2eXbAThvXB8u/9pgRvauuQmlBEEJQkQiXnp2IX94M43nV+6kbRvj4uMGcPnXBtM3Ljrs0Fo0JQgRiVi7c4v481sbWLR0G23bGHNOGMgPvjaYnrFat9AUgt5yNB54EDgdyAJudPcna6l3PxW7zVVpB5S4e5fK4+8AxwFVW5XucPejgoxVRCJXfnEp97/zGX9bsonSMufCY5K4+utDSVRiaFJB9yDuAUqARCr2kn7JzFJrbhfq7lcAV1Q9N7OFQHmNtq52978FHJ+IRLCycmfRx+n8/vU0svJLOG9cH356+lEkddf6hTAEliDMrDMwHRjj7vnAEjN7AZgF3HAI550bVCwi0vykbNnD/zy/hrWf5zJ5YDf+Nnsy4/vHhR1WqxZkD2I4UObuadXKUoFT6jlvOpAJvFej/A4zuxNYD/zc3d+p7WQzmwfMA0hKSjqMsEUkTJl5xdzx8qcsXrGD3l078peLJnDO2N5a3BYBgkwQMUBOjbIcoEs9580GHnV3r1Z2PbCWistVM4AXzWy8u39W82R3XwAsAEhOTvaax0UkMpWXO0+nbOP2lz+l6EA5V506hKtOHUqn9po7EymC/E3kAzUnI8cCeXWdYGb9qehhXF693N0/qvb0ETO7EDgb+HMwoYpImDZl5nPDs6tZumUPxw6K59ffHsvQnjFhhyU1BJkg0oAoMxvm7hsqy8YBaw5yziXAB+6+qZ62HVB/U6SZKyt3Hlqymd/+az0d27XlrulH853kfrqcFKECSxDuXmBmi4FbzOwyKmYxTQVOOMhplwC/qV5gZnHAscC7VExz/R7wNeDaoGIVkaa3OauAnzy9kuXp+/jGyERu//YYrWeIcEFf7LsSeAjIALKB+e6+xsySqBhTGOXu6QBmdjzQD3imRhvtgNuAEUAZsA6Y5u7rA45VRJqAu/PU0m3c+s+1tI9qw++/N45p4/uq19AMBJog3H0PMK2W8nQqBrGrl30IdK6lbiYwOci4RCQcewpKuP7ZVby+djcnDu3O774znl5d1WtoLjRdQEQaxUebsvnRohXsLTjATeeMZO6Jg3Sn1WZGCUJEAlVe7tz7zkbufj2NAd078+DsyYzp2zXssOQwKEGISGD2FZZwzaKVvJuWyXnj+nD7+WOJ6aCPmeZKvzkRCcQnO3K44vFlZOQW8+tvj+GiY5I0EN3MKUGIyBFbvHw7Ny5eTXzn9jx9xfG6h1ILoQQhIoetrNy569V1PPDeJo4bHM9fLppIj5gOYYclAVGCEJHDkld0gGsWreStdRnMOm4A//OtUbRr2ybssCRAShAi0mA79u1n7sMfszEzn1unjmbW8QPDDkkagRKEiDTIJztymLvwY/aXlPHIpcdw0rAeYYckjUQJQkQO2dvrMrjqyeV069Sex+Yfy1G96rubvzRnShAickj+vmw71z+7ihG9uvDwnMm60V4roAQhIvV64N3PuOOVdZw0tAf3z5qkxW+thH7LIlInd+fOVyqmsZ57dG9+9
91xdIhqG3ZY0kSUIESkVuXlzk3Pf8KTH6Uz67gB/PK80bTVzfZaFSUIEfmK0rJyrvv7Kp5bsYP5U4bwszOO0m0zWiElCBH5kpLScn701ApeXbOL6844iqtOHRp2SBKSQJc9mlm8mT1nZgVmttXMLqqj3hwzKzOz/GqPKQ1tR0SCVVJazlVPLufVNbu4+dxRSg6tXNA9iHuAEiCRij2pXzKzVHdfU0vdD939pADaEZEAFJeWcdUTy3nj0wx+dd5oZp8wMOyQJGSB9SDMrDMwHbjZ3fPdfQnwAjArjHZE5NCVlJZ/kRxumarkIBWCvMQ0HChz97RqZanA6DrqTzCzLDNLM7ObzayqN9OgdsxsnpmlmFlKZmbmkX4PIq3OgbJyfvhURXK4ddoYLtF9laRSkAkiBsipUZYD1LYW/z1gDNCTit7ChcB1h9EO7r7A3ZPdPTkhIeEwQxdpncrKnf96OpXX1uzmf84dxazjBoQdkkSQIBNEPhBboywWyKtZ0d03uftmdy9399XALcAFDW1HRA5
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(partial(sigmoid_range,lo=-1,hi=1), min=-4, max=4)"
]
},
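{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the range constraint really is wired into the network, we can peek at the model's head. This is a sketch that assumes fastai's usual body/head layout, with the constraint appended as the head's final layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The last layer of the head should be fastai's SigmoidRange(-1, 1)\n",
"learn.model[-1][-1]"
]
},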
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We didn't specify a loss function, which means we're getting whatever fastai chooses as the default. Let's see what it picked for us:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FlattenedLoss of MSELoss()"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.loss_func"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This makes sense, since when coordinates are used as the dependent variable, we're trying to predict something as close as possible to the target; that's basically what `MSELoss` (mean squared error loss) does. If you want to use a different loss function, you can pass it to `cnn_learner` using the `loss_func` parameter.\n",
"\n",
"Note also that we didn't specify any metrics. That's because the MSE is already a useful metric for this task (although it's probably more interpretable after we take the square root)."
]
},
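{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a sketch of swapping in L1 loss and reporting root mean squared error via fastai's built-in `rmse` metric. The name `learn_l1` is just illustrative, and we won't use this variant below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: penalize errors linearly (L1) instead of quadratically, and report\n",
"# RMSE so the metric is on the same scale as the coordinates\n",
"learn_l1 = cnn_learner(dls, resnet18, y_range=(-1,1),\n",
"                       loss_func=nn.L1Loss(), metrics=rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sticking with the default MSE, we can pick a good learning rate with the learning rate finder:"
]
},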
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.005754399299621582, lr_steep=0.033113110810518265)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEQCAYAAAB80zltAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAA320lEQVR4nO3deXyU5bnw8d+VhSRkhWzsIEvYDUgUBAUUBLFaF9T3qHXpaavVY9vT1ta259BaW+3y9u1p9agtrbVqLbVWqBu4sAviEheWQAj7ng1IyL5MrvePZ0KHYbKR2ZJc389nPmbuZ7uekeSae3nuW1QVY4wxxp8iQh2AMcaY7seSizHGGL+z5GKMMcbvLLkYY4zxO0suxhhj/M6SizHGGL+LCnUA4SAtLU2HDRsW6jCMMaZL+fjjj0tVNd3XNksuwLBhw8jNzQ11GMYY06WIyIGWtlmzmDHGGL+z5GKMMcbvLLkYY4zxO0suxhhj/M6SizHGGL+z5GKMMcbvLLkYY0wPdfhkNTX1roCc25KLMcb0UD9Yto3rntgYkHNbcjHGmB6o6FQtG3aVcMW4zICc35KLMcb0QK98doQmhesvGBiQ81tyMcaYHkZVefnjI0wanMKI9ISAXMOSizHG9DDbj51iZ1EFCwNUawFLLsYY0+Ms/eQI0ZHCNdkDAnYNSy7GGNODNLqaeOWzI8wZk0lK714Bu44lF2OM6UHe3VVKaWU9NwSwSQyCmFxEpK+ILBORKhE5ICK3trDfXSLiEpFKj9dsj+2VXi+XiDzu3jZMRNRr+6Lg3KExxoS/f3xymD69o5k9OiOg1wnmYmFPAPVAJjAJeENENqtqno99N6nqJb5OoqqnhzaISDxQBLzktVuKqjb6JWpjjOkmTlbV8872Im65cDC9ogJbtwhKzcWdBBYCi1S1UlU3AK8Ct3fy1DcCxcC7nTyPMcZ0ey99fIj6xiZunTo04NcKVrNYFuBS1QKPss3A+Bb2nywipSJSICKLRKSlGtadwHOqql7lB0TksIg8IyJpnYzdGGO6vKYm5YUPDnLhsD6M7pcY8OsFK7kkAOVeZeWArztcD0wAMnBqO7cA3/HeSUSGALOAZz2KS4ELgaHAFPf5X/AVkIjcLSK5IpJbUlLSoZsxxpiuZsPuUg4cr+YL0wJfa4HgJZdKIMmrLAmo8N5RVfeq6j5VbVLVrcDDOM1f3u4ANqjqPo9jK1U1V1UbVbUIuB+YJyLe10ZVF6tqjqrmpKend+LWjDEm/P3l/QOkxvfiygn9gnK9YCWXAiBKREZ5lGUDvjrzvSkgPsrv4MxaS0vH0sLxxhjTIxwrr2HljiJuyhlMTFRkUK4ZlOSiqlXAUuBhEYkXkRnAtcDz3vuKyAIRyXT/PAZYBLzitc90YCBeo8REZKqIjBaRCBFJBR4D1qqqd5OcMcb0GEs+PIQCt00dErRrBvMhyvuAOJzRXUuAe1U1T0SGuJ9Hab7rOcAWEakCluMkpUe9znUnsFRVvZvVhgNv4jS3bQPqcPpsjDGmR2pwNfG3Dw8yKyudwX17B+26QXvORVVPANf5KD+I0+Hf/P4B4IE2znVPC+VLcBKXMcYYYOPuUoor6vjpRcGrtYBN/2KMMd3aqh3FxEVHMjMruAOXLLkYY0w3paqs2lHEjJFpxEYHpyO/mSUXY4zppvILKzhaXsvcsYGdR8wXSy7GGNNNrdpRBMDlYyy5GGOM8ZOVO4rJHpRMRlJs0K9tycUYY7qhkoo6Nh8uY87YzJBc35KLMcZ0Q2t2FqMKc0LQ3wKWXIwxpltataOI/smxjOt/1tSKQWHJxRhjupnaBhfv7irl8jEZiIRmakVLLsYY0818sO8E1fUu5oaovwUsuRhjTLfz5rZjxPeK5OIRqSGLwZKLMcZ0Iw2uJlZsK+SKcZlBfyrfkyUXY4zpRjbuLqWsuoHPnT8gpHFYcjHGmG7k9S3HSIyNYmZWWkjjsORijDHdRF2ji7fyCpk3rl/QVpxsiSUXY4zpJt4tKKWitpGrs/uHOpTgJRcR6Ssiy0SkSkQOiMitLex3l4i43KtTNr9me2xfKyK1Htt2eh0/R0TyRaRaRNaIyNDA3pkxxoSH17ccJaV3NJeMDG2TGAS35vIEUA9kArcBT4nI+Bb23aSqCR6vtV7b7/fYNrq5UETScJZFXgT0BXKBF/19I8YYE25qG1y8s72IK8f3Izoy9I1SQYlAROKBhcAiVa1U1Q3Aq8Dtfr7UDUCeqr6kqrXAQ0C2iIzx83WMMSasrN1ZTFW9i6tDPEqsWbDSWxbgUtUCj7LNQEs1l8kiUioiBSKySESivLb/zL19o2eTmft8m5vfqGoVsKeV6xhjTLfw2pZjpMb3YtrwvqEOBQheckkAyr3KyoFEH/uuByYAGTi1nVuA73hsfxAYDgwEFgOviciIjl5HRO4WkVwRyS0pKenY3RhjTBipqmtk1Y4irprYn6gwaBKD4CWXSsB7as4koMJ7R1Xdq6r7VLVJVbcCDwM3emz/QFUrVLVOVZ8FNgJXncN1FqtqjqrmpKenn/ONGWNMqK3cUURtQxPXZIdHkxgEL7kUAFEiMsqjLBvIa8exCrQ2rafn9jz3eYHTfT0j2nkdY4zpkl7bfJT+ybHkDO0T6lBOC0pycfd9LAUeFpF4EZkBXAs8772viCwQkUz3z2NwRn694n6fIiLzRSRWRKJE5DZgJvCW+/BlwAQRWSgiscAPgS2qmh/oezTGmFAor25gXUEJV5/fn4iI0Eyv70swG+fuA+KAYmAJcK+q5onIEPfzKkPc+80BtohIFbAcJyk96t4WDfwUKAFKga8B16nqTgBVLcHpp3kEOAlMBf4tGDdnjDGh8FZeIQ0uDasmMQDvUVgBo6ongOt8lB/E6Yhvfv8A8EAL5ygBLmzjOisBG3psjOkRXt18lKGpvZk4MDnUoZwhPIYVGGOM6bCSijre21PKNecPCNmKky2x5GKMMV3Uim3HaFL4/KTwahIDSy7GGNNlvbb5KKMzE8nK9PXIYGhZcjHGmC6ouKKW3AMnuWpi6GdA9sWSizHGdEGrdhSjCvMnZIY6FJ8suRhjTBf0dl4hg/vGMToMm8TAkosxxnQ5lXWNbNxznHnj+oXdKLFmllyMMaaLWV9QQn1jE1eMC88mMbDkYowxXc4724tI6R0dVnOJebPkYowxXUiDq4nV+cXMGZMZNtPr+xK+kRljjDnLR/tOUF7TENZNYmDJxRhjupS3txcRExXBzKy0UIfSKksuxhjTRagq72wv4tJRafTuFbR5h8+JJRdjjOki8o6e4khZTdg3iYElF2OM6TJWbDtGZIRwxbh+oQ6lTZZcjDGmC1BVlm8t5OLhqfSN7xXqcNpkycUYY7qAnUUV7CutYsHE8K+1QBCTi4j0FZFlIlIlIgdE5NYW9rtLRFzupY+bX7Pd22JE5Gn38RUi8qmILPA4dpiIqNexi4Jzh8YYEzjLtxYSITCvCzSJQRCXOQaeAOqBTGAS8IaIbFbVPB/7blLVS3yURwGHgFnAQeAq4O8iMlFV9
3vsl6Kqjf4M3hhjQmnF1mNcdF5f0hNjQh1KuwSl5iIi8cBCYJGqVqrqBuBV4PaOnEdVq1T1IVXdr6pNqvo6sA+Y4v+ojTEmPOwqqmBXcWXYrt3iS7CaxbIAl6oWeJRtBsa3sP9kESkVkQIRWSQiPmtYIpLpPrd37eeAiBwWkWdExOeTRiJyt4jkikhuSUlJB2/HGGOCZ8W2QkRg/viu0SQGwUsuCUC5V1k54GshgvXABCADp7ZzC/Ad751EJBp4AXhWVfPdxaXAhcBQnNpMonufs6jqYlXNUdWc9PT0Dt+QMcYEy/Ktx8gZ2ofMpNhQh9JuwUoulUCSV1kSUOG9o6ruVdV97mavrcDDwI2e+4hIBPA8Th/O/R7HVqpqrqo2qmqRe9s8EfG+tjHGdAl7SyrJL6xgwYSu0yQGwUsuBUCUiIzyKMvm7OYsXxQ4vRqOOCvjPI0zMGChqja0cSyexxtjTFfyZl4hAFdO6DpNYhCk5KKqVcBS4GERiReRGcC1OLWPM4jIAndfCiIyBlgEvOK
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll try an LR of 1e-2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.049630</td>\n",
" <td>0.007602</td>\n",
" <td>00:42</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.008714</td>\n",
" <td>0.004291</td>\n",
" <td>00:53</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.003213</td>\n",
" <td>0.000715</td>\n",
" <td>00:53</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.001482</td>\n",
" <td>0.000036</td>\n",
" <td>00:53</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lr = 1e-2\n",
"learn.fine_tune(3, lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generally when we run this we get a loss of around 0.0001, which corresponds to an average coordinate prediction error of:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.01"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.sqrt(0.0001)"
]
},
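{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put that number in pixel terms: the coordinates span `(-1,1)`, a range of 2, so an average error of 0.01 covers 0.5% of each image dimension. A rough sketch for our 240×320 inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 0.01 on the (-1,1) scale is 0.01/2 of the image size in each dimension\n",
"0.01/2 * tensor([240., 320.])"
]
},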
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This sounds very accurate! But it's important to take a look at our results with `Learner.show_results`. The left side shows the actual (*ground truth*) coordinates and the right side shows our model's predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAHzCAYAAACDns4pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9ebBlSX7fh31+mefc+9Zauqq7qrun92X2wTbEAMROEoIhCgxKlBkhyVLQDjls2RF2QA6JDppB0yGHQNqWLdkUrTApiqQl0iAJYEhsg4XDwQxmBrNjZrqn9+7ppaq6q2t/2733nMyf//hlnnPufa9eVTe63x0A+eu+9e49ay7fzN+avxRVpVChQoUKHQ25ZRegUKFChf44UZl0CxUqVOgIqUy6hQoVKnSEVCbdQoUKFTpCKpNuoUKFCh0hlUm3UKFChY6QyqRb6I8VichfFxEVER0c+1Q69ql34PkP5ueLyF/6gz6v0B89KpPudyCJyLcHA/dmn7++7HJmGpT379/k/Pek839BRH78gLpsiciTIvJXRWT9iIsP8C3gC+nvbdFBk3eiaXrWF4A337kiFvqjQtWyC1DoQPoa8Hr6/h7g3vT997FBDfDaW32oiAjgVbX9gxbwLdKfBybAJ4A/MTj+IjYx3Q98APjPgO8H/tzNHiQiI1WdvZOFU9X/1Tv4rAvAD7xTzyv0R5BUtXy+gz/AXwc0fR5Mx/6vwJPANaABzgP/ALj7Jvf9NCbFtcB3AwL8NeANYAv4/wL/28X3pOf8a8AngRvAHibB/Uw69+DgnrnPQh2+Dvxy+v7jg+v+Ujrmgd8bHD+58Oz/FPg4sAv8l+mes8DfBc4BM+Bl4G8A48F7R8DfSu10BfivgP98sYzAp9KxTy3c+1dSO0+A68BngccH1y9+/tJCuf/S4HkfAn4RuJTK+xLwfwM2DioH8L8Gvp3651eAs8vGYvm8M58i6f7hpJ/GpN9XMW3lvcB/ALwfkxQX6ePY5HQ+/f6PgP9T+v468BPAv7l4k4j828A/wSbp17DJ5/uBfy4ifxGbhL4AfA82SV0CXlh4xoPAR4D/51uo36LK/p+ld78AtCJyCpukHwB2gKeA9wF/Gfgg8DPpvv8zNnmBTWD/DrB2m2X4BeDfSN/fwCbt7wfuwRjYo/QayBfS3wPNCSLyfuDzwEYq7/NYn/3vgB8UkR9R1Ti45U8CH8P6dwP4s8B/Afx7t1n2Qt/JtOxZv3wO/3CwpPsRwA2u+Q8H1zxywH1/Y3CtB15Jx78EjIEa+PQB73kx/f4fAEnH/k469tzgmd9Ox/7+AeX/WSAAd6XfPz54zwvY5HlucOxfpOseHBx7CjgxKP9fS8cvk6R74IcG1/8QNrnupd+/iDGOdeDpfN2gjJ9iIOkCPzp41t/GTDJgE+7ZxfZdqO+w3H8pHfsH6fcO8EA69r8cXPczC+UIwHenY7+Yjr2+bCyWzzvzKY60P5z0XcCXRGQ7OXL+zuDcPQdc/18Ovq8D96Xvv6SqU1VtgH82vEFE7gQeSj//XSCmd/2H6dijSeK8Ff154HOqevGAcw9jEt1xTHr8a5g0ukj/QFWvAahqSPcA3AGcT+X63cH1P4BJoivp9z9Vox3gV2+jzB8bfP+b6Z2o6nlVff0m9xxG2Y79WVV9OX3/R4PzH124/puq+vvpe3bu3fU23lvoO5CKeeEPGYnID2OSk2CS3rcwFfT96RK/eM/CRKE3+b7vVYPvLwEHTZr1Lcp6CpM6//JNLvmfqurfP+wZiRYnuly2bczmukjXDivWbbzv3aLbTel3bfA9Oz2XWe5C7yAVSfcPH32MfgB+WFW/H/iHt3uzqm5h5gWAPycitYjUwL+9cN1FzGwA8ATwI6r6A6r6A8BfBH5uMJnvpr+L4V5/DmMCH7/d8t0mfTEXE/ifDMr1E5hz6hcwu+kkXfcXxGgNs4ffir4w+P6fiIgDEJGzInImHc915jbC3L6U/v6QiDyQvv+7g/Nfvo0yFfojQmXS/cNH3xh8/6aIPAX8J2/xGX8z/f0BbGJ9Cfi+A67736e/PwNcEJGvicj5dM/PDq57Ov39t0TkKyLy36Xffx54QlXnnGvvAP0tzMm0CXxLRL4hIs8BV4F/itl/d4H/Ol3/FzD79LfpTSY3JVX9NBYxAOaIOyciT2ARElmjeHpwy5Mi8nsi8vBNHvk3MKl8PV375KBsn+P2TB6F/ohQmXT/kJGq/hamrp8HVrHB/x+9xcf8v4H/I+ZtPw58Bvi5wfm99K6fxyTDT2LRCe/HpMd/ikmUmf4q5hCbAd8LfDhJlT/JOy/loqqXMIbxdzGzx/uBY5hE+VewaAOA/wNW1xtYGNrHsbCx26G/kO5/CrMd349JpDkC5FcwW/plLIriY9wkMkJVnwJ+EPglLM76cYxp/BfAT+l85EKhP+KUPdKF/hiRiBwHVlT1jfTbA7+OTZIXgHv1DwgMEfk3Mc/7R1X1K3/AIhcq9EeGiiPtjyc9BHxBRL6ExdZ+FxbqBPBX/qATbqId4K+WCbdQoXkqku4fQxKRe4G/hy1qOInZG78M/D9U9deWWbZChf6oU5l0CxUqVOgIqTjSChUqVOgIqUy6hQoVKnSEVCbdQoUKFTpCKpNuoUKFCh0hlUm3UKFChY6QyqRbqFChQkdIZdItVKhQoSOkMukWKlSo0BFSmXQLFSpU6AipTLqFChUqdIRUJt1ChQoVOkIqk26hQoUKHSGVSbdQoUKFjpDKpFuoUKFCR0hl0i1UqFChI6Qy6RYqVKjQEVKZdAsVKlToCKlMuoUKFSp0hFQm3UKFChU6QiqTbqFChQodIZVJt1ChQoWOkMqkW6hQoUJHSGXSLVSoUKEjpDLpFipUqNARUpl0CxUqVOgIqUy6hQoVKnSEVCbdQoUKFTpCKpNuoUKFCh0hlUm3UKFChY6QyqRbqFChQkdIZdItVKhQoSOkMukWKlSo0BFSmXQLFSpU6AipTLqFChUqdIRUJt1ChQoVOkIqk26hQoUKHSGVSbdQoUKFjpDKpFuoUKFCR0hl0i1UqFChI6Qy6RYqVKjQEVKZdAsVKlToCKlMuoUKFSp0hFQm3UKFChU6QiqTbqFChQodIZVJt1ChQoWOkMqkW6hQoUJHSGXSLVSoUKEjpDLpFipUqNARUpl0CxUqVOgIqTrs5M/93N/Un/2Pf5ZRXYEAms/YD9H8HVQi6UB3jaidHh7tSGXuWRBur8QCqOveCxAl5oem4zL/E7X/VYkqaHSoRkKMaFBCGwgh0rYtIQSapqVt7dO00LYNTTsjxpY2XRtCIMZgzwmBGCNRA1EDGluiBiQoRIga7d1BUBSN0f6m78RI2za0YQfnPKqCRiDa9TFEQmzQGFC1SqlGu04VpUEj/bNVUWZz74oxoNGuDzEQgt0Tgv2ObaBpAxubx/non/gQtC2hjURxKKDRelFTn8f8nqhdObouclh5VPtP6oMYA1Zl5Vd/9RN8/Fd/re/II6SC7YLtZWH70EkXUaraE1RBrTFEBHGaMNIjzwrnEgYUokBUF
IiqxKhWmfRf66zAVtNceIhBcQp1V8++4jHa5apCjEJoIyFEmghRMXCEQIhKDMHeG1pCnHTAa0NLjG0qj9I2LTFENEaatqFpGmu8EO2vtoTQohrSX6uHNXC+JqbRF9KzA0pAQrA2SJ0jafCpgojQEuxHtPeDIuKIESCi0nTjy2kEtfYTBCWmAepQMmA1/2/tTwKrarpeUtu1RG2JIQOzgSbQtBFXg/OCqEtglzQXSP8OBOecPVdIz83vCXaNOCufhm6SUFUQkPxxS5lvjQq2C7aXhO1DJ90335zwzNPnmU5bQohMw4wQWuMkGgecRpGYypwARmszPgohRFQjMShRAyEElJaokRjtWaTGtXPgRBKHTR0QhRADIvMNhKpxW0gcKho3ztxKY2osrDU0lSXGxOXaxFkjMQ65ljW0Sxw1xNTgWPnQ/H43KE8CVIxAzL2R3g2hnaafHo0wbadUvqKuqjTiApA6VADXgy0jWkQQSe9RxYlDNU8WklHfvVPE4SQBOU0lzjmUmojVzzuHk8CoUlZ
"text/plain": [
"<Figure size 432x576 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.show_results(ds_idx=1, nrows=3, figsize=(6,8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's quite amazing that with just a few minutes of computation we've created such an accurate key points model, and without any special domain-specific application. This is the power of building on flexible APIs, and using transfer learning! It's particularly striking that we've been able to use transfer learning so effectively even between totally different tasks; our pretrained model was trained to do image classification, and we fine-tuned for image regression."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In problems that are at first glance completely different (single-label classification, multi-label classification, and regression), we end up using the same model with just different numbers of outputs. The loss function is the one thing that changes, which is why it's important to double-check that you are using the right loss function for your problem.\n",
"\n",
"fastai will automatically try to pick the right one from the data you built, but if you are using pure PyTorch to build your `DataLoader`s, make sure you think hard when you have to decide on your choice of loss function, and remember that you most probably want:\n",
2020-02-28 19:44:06 +00:00
"\n",
"- `nn.CrossEntropyLoss` for single-label classification\n",
"- `nn.BCEWithLogitsLoss` for multi-label classification\n",
"- `nn.MSELoss` for regression"
]
},
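{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make that mapping concrete, here is a minimal sketch in pure PyTorch, with made-up shapes and random placeholder targets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch, torch.nn as nn\n",
"\n",
"logits = torch.randn(4, 10)  # a batch of 4 examples, 10 outputs\n",
"loss_single = nn.CrossEntropyLoss()(logits, torch.randint(0, 10, (4,)))  # integer class ids\n",
"loss_multi  = nn.BCEWithLogitsLoss()(logits, torch.rand(4, 10).round())  # one-hot float targets\n",
"loss_regr   = nn.MSELoss()(logits, torch.randn(4, 10))                   # float targets\n",
"loss_single, loss_multi, loss_regr"
]
},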
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. How could multi-label classification improve the usability of the bear classifier?\n",
"1. How do we encode the dependent variable in a multi-label classification problem?\n",
"1. How do you access the rows and columns of a DataFrame as if it was a matrix?\n",
"1. How do you get a column by name from a DataFrame?\n",
"1. What is the difference between a `Dataset` and `DataLoader`?\n",
"1. What does a `Datasets` object normally contain?\n",
"1. What does a `DataLoaders` object normally contain?\n",
"1. What does `lambda` do in Python?\n",
"1. What are the methods to customize how the independent and dependent variables are created with the data block API?\n",
"1. Why is softmax not an appropriate output activation function when using a one-hot-encoded target?\n",
"1. Why is `nll_loss` not an appropriate loss function when using a one-hot-encoded target?\n",
"1. What is the difference between `nn.BCELoss` and `nn.BCEWithLogitsLoss`?\n",
"1. Why can't we use regular accuracy in a multi-label problem?\n",
"1. When is it okay to tune a hyperparameter on the validation set?\n",
"1. How is `y_range` implemented in fastai? (See if you can implement it yourself and test it without peeking!)\n",
"1. What is a regression problem? What loss function should you use for such a problem?\n",
"1. What do you need to do to make sure the fastai library applies the same data augmentation to your input images and your target point coordinates?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Read a tutorial about Pandas DataFrames and experiment with a few methods that look interesting to you. See the book's website for recommended tutorials.\n",
"1. Retrain the bear classifier using multi-label classification. See if you can make it work effectively with images that don't contain any bears, including showing that information in the web application. Try an image with two different kinds of bears. Check whether the accuracy on the single-label dataset is impacted using multi-label classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}