fastbook/11_nlp_dive.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"[[chapter_nlp_dive]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A language model from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're now ready to go deep... deep into deep learning! You already learned how to train a basic neural network, but how do you go from there to creating state of the art models? In this part of the book we're going to uncover all of the mysteries, starting with language models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whenever we start working on a new problem, we always first try to think of the simplest dataset we can which would allow us to try out methods quickly and easily, and interpret the results. When we started working on language modelling a few years ago, we didn't find any datasets that would allow for quick prototyping, so we made one. We call it *human numbers*, and it simply contains the first 10,000 words written out in English."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> j: One of the most common practical mistakes I see even amongst highly experienced practitioners is failing to use appropriate datasets at appropriate times during the analysis process. In particular, most people tend to start with datasets which are too big and too complicated."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can download, extract, and take a look at our dataset in the usual way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.text.all import *\n",
"path = untar_data(URLs.HUMAN_NUMBERS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('train.txt'),Path('valid.txt')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's open those two files and see what's inside. At first we'll join all of those texts together and ignore the split train/valid given by the dataset, we will come back to it later on:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lines = L()\n",
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
"lines"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We take all those lines and concatenate them in one big stream. To mark when we go from one number to the next, we use a '.' as separation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = ' . '.join([l.strip() for l in lines])\n",
"text[:100]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use word tokenization for this dataset, by splitting on spaces:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens = text.split(' ')\n",
"tokens[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To numericalize, we have to create a list of all the unique tokens (our *vocab*):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab = L(*tokens).unique()\n",
"vocab"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can convert our tokens into numbers by looking up the index of each in the vocab:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
"nums = L(word2idx[i] for i in tokens)\n",
"nums"
]
},
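{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a purely illustrative check, we can map back from numbers to tokens with the vocab, confirming that numericalization is reversible:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# look each index back up in the vocab to recover the original tokens\n",
"' '.join(vocab[i] for i in nums[:10])"
]
},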
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our first language model from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One simple way to turn this into a neural network would be to specify that we are going to predict each word based on the previous three words. Therefore, we could create a list of every sequence of three words as independent variables, and the next word after each sequence as the dependent variable. \n",
"\n",
"We can do that with plain Python. Let us do it first with tokens just to confirm what it looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will do it with tensors of the numericalized values, which is what the model will actually use:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
"seqs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can batch those easily using the `DataLoader` class. For now we will split randomly the sequences."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs = 64\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create a neural network architecture that takes three words as input, and returns a prediction of the probability of each possible next word in the vocab. We will use three standard linear layers, but with two tweaks.\n",
"\n",
"The first tweak is that the first linear layer will use only the first word's embedding as activations, the second layer will use the second word's embedding plus the first layer's output activations, and the third layer will use the third word's embedding plus the second layer's output activations. The key effect of this is that every word is interpreted in the information context of any words preceding it. \n",
"\n",
"The second tweak is that each of these three layers will use the same weight matrix. The way that one word impacts the activations from previous words should not change depending on the position of a word. In other words, activation values will change as data moves through the layers, but the layer weights themselves will not change from layer to layer. So a layer does not learn one sequence position; it must learn to handle all positions.\n",
"\n",
"Since layer weights do not change, you might think of the sequential layers as the \"same layer\" repeated. In fact PyTorch makes this concrete; we can just create one layer, and use it multiple times."
]
},
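{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before writing the model, here is a minimal sketch of what \"using the same layer multiple times\" means in PyTorch (the class `TinyShared` is just for illustration, not part of our model): one `nn.Linear` called twice contributes only a single layer's worth of parameters, and the gradients from both uses accumulate into that one weight matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TinyShared(Module):\n",
"    # one linear layer, applied twice in forward, so its weights are shared\n",
"    def __init__(self, n_hidden): self.lin = nn.Linear(n_hidden, n_hidden)\n",
"    def forward(self, x): return self.lin(F.relu(self.lin(x)))\n",
"\n",
"tiny = TinyShared(64)\n",
"sum(p.numel() for p in tiny.parameters())  # 64*64 weights + 64 biases = 4160"
]
},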
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our language model in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create the language model module that we described earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel1(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
" h = h + self.i_h(x[:,1])\n",
" h = F.relu(self.h_h(h))\n",
" h = h + self.i_h(x[:,2])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you see, we have created three layers:\n",
"\n",
"- The embedding layer (`i_h` for *input* to *hidden*)\n",
"- The linear layer to create the activations for the next word (`h_h` for *hidden* to *hidden*)\n",
"- A final linear layer to predict the fourth word (`h_o` for *hidden* to *output*)\n",
"\n",
"This might be easier to represent in pictorial form. Let's define a simple pictorial representation of basic neural networks. Here's how we're going to represent a neural net with one hidden layer:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img alt=\"Pictorial representation of simple neural network\" width=\"400\" src=\"images/att_00020.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each shape represents activations: rectangle for input, circle for hidden (inner) layer activations, and triangle for output activations:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img alt=\"Shapes used in our pictorial representations\" width=\"200\" src=\"images/att_00021.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An arrow represents the actual layer computation—i.e. the linear layer followed by the activation layers. Using this notation, here's what our simple language model looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img alt=\"Representation of our basic language model\" width=\"500\" caption=\"Representation of our basic language model\" id=\"lm_rep\" src=\"images/att_00022.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To simplify things, we've removed the details of the layer computation from each arrow. We've also color-coded the arrows, such that all arrows with the same color have the same weight matrix. For instance, all the input layers use the same embedding matrix, so they all have the same color (green).\n",
"\n",
"Let's try training this model and see how it goes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.824297</td>\n",
" <td>1.970941</td>\n",
" <td>0.467554</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.386973</td>\n",
" <td>1.823242</td>\n",
" <td>0.467554</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.417556</td>\n",
" <td>1.654497</td>\n",
" <td>0.494414</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.376440</td>\n",
" <td>1.650849</td>\n",
" <td>0.494414</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see if this is any good, let's check what would a very simple model give us. In this case we could always predict the most common token, so let's find out which token is the most often the target in our validation set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(29), 'thousand', 0.15165200855716662)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n,counts = 0,torch.zeros(len(vocab))\n",
"for x,y in dls.valid:\n",
" n += y.shape[0]\n",
" for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
"idx = torch.argmax(counts)\n",
"idx, vocab[idx.item()], counts[idx].item()/n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The most common token has the index 29, which corresponds to the token 'thousand'. Always predicting this token would give us an accuracy of roughly 15\\%, so we are faring way better!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> A: My first guess was that the separator would be the most common token, since there is one for every number. But looking at `tokens` reminded me that large numbers are written with many words, so on the way to 10,000 you write \"thousand\" a lot: five thousand, five thousand and one, five thousand and two, etc.. Oops! Looking at your data is great for noticing subtle featues and also embarrassingly obvious ones."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our first recurrent neural network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the code for our module, we could simplify it by replacing the duplicated code that calls the layers with a for loop. As well as making our code simpler, this will also have the benefit that we could apply our module equally well to token sequences of different lengths; we would not be restricted to token lists of length three."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel2(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = 0\n",
" for i in range(3):\n",
" h = h + self.i_h(x[:,i])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that we get the same results using this refactoring:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.816274</td>\n",
" <td>1.964143</td>\n",
" <td>0.460185</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.423805</td>\n",
" <td>1.739964</td>\n",
" <td>0.473259</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.430327</td>\n",
" <td>1.685172</td>\n",
" <td>0.485382</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.388390</td>\n",
" <td>1.657033</td>\n",
" <td>0.470406</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also refactor our pictorial representation in exactly the same way (we're also removing the details of activation sizes here, and using the same arrow colors as the previous diagram):"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img alt=\"Basic recurrent neural network\" width=\"400\" caption=\"Basic recurrent neural network\" id=\"basic_rnn\" src=\"images/att_00070.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You will see that there is a set of activations which are being updated each time through the loop, and are stored in the variable `h` — this is called the *hidden state*."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Jargon: hidden state: the activations that are updated at each step of a recurrent neural network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A neural network which is defined using a loop like this is called a *recurrent neural network*, also known as an RNN. It is important to realise that an RNN is not a complicated new architecture, but is simply a refactoring of a multilayer neural network using a for loop.\n",
"\n",
"> A: My true opinion: if they were called \"looping neural networks\", or LNNs, they would seem 50% less daunting!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Improving the RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Maintaining the state of an RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the code for our RNN, one thing that seems problematic is that we are initialising our hidden state to zero for every new input sequence. Why is that a problem? We made our sample sequences short so they would fit easily into batches. But if we order those samples correctly, those sample sequences will be read in order by the model, exposing the model to long stretches of the original sequence. \n",
"\n",
"But because we initialize the model's hidden state to zero for each new sample, we are throwing away all the information we have about the sentences we have seen so far, which means that our model doesn't actually know where we are up to in the overall counting sequence. This is easily fixed; we can simply move the initialisation of the hidden state to `__init__`.\n",
"\n",
"But this fix will create its own subtle, but important, problem. It effectively makes our neural network as deep as the entire number of tokens in our document. For instance, if there were 10,000 tokens in our dataset, we would be creating a 10,000 layer neural network.\n",
"\n",
"To see this, consider the original pictorial representation of our recurrent neural network, before refactoring it with a for loop. You can see each layer corresponds with one token input. When we talk about the representation of a recurrent neural network before refactoring with the for loop, we call this the *unrolled representation*. It is often helpful to consider the unrolled representation when trying to understand an RNN.\n",
"\n",
"The problem with a 10,000 layer neural network is that if and when you get to the 10,000th word of the dataset, you will still need to calculate the derivatives all the way back to the first layer. This is going to be very slow indeed, and very memory intensive. It is unlikely that you could store even one mini batch on your GPU.\n",
"\n",
"The solution to this is to tell PyTorch that we do not want to back propagate the derivatives through the entire implicit neural network. Instead, we will just keep the last three layers of gradients. To remove all of the gradient history in PyTorch, we use the `detach` method.\n",
"\n",
"Here is the new version of our RNN. It is now stateful, because it remembers its activations between different calls to `forward`, which represent its use for different samples in the batch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel3(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" for i in range(3):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" out = self.h_o(self.h)\n",
" self.h = self.h.detach()\n",
" return out\n",
" \n",
" def reset(self): self.h = 0"
]
},
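{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, here is a tiny, self-contained sketch of what `detach` does (the tensor `x` below is made up purely for illustration): gradients flow back through the path we keep, but not through the detached path, because `detach` removes a tensor from the history of the computation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor(3.0, requires_grad=True)\n",
"kept = x * 2             # still attached to the computation graph\n",
"cut  = (x * 2).detach()  # detached: treated as a constant from here on\n",
"(kept + cut).backward()\n",
"x.grad                   # tensor(2.): only the non-detached path contributed"
]
},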
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you think about it, this model will have the same activations whatever the sequence length we pick, because the hidden state will remember the last activation from the previous batch. The only thing that will be different are the gradients computed at each step: they will only be calculated on sequence length tokens in the past, instead of the whole stream. That is why this sequence length is often called *bptt* for back-propagation through time."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* jargon: Back propagation through time (BPTT): Treating a neural net with effectively one layer per time step (usually refactored using a loop) as one big model, and calculating gradients on it in the usual way. To avoid running out of memory and time, we usually use _truncated_ BPTT, which \"detaches\" the history of computation steps in the hidden state every few time steps."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use `LMModel3`, we need to make sure the samples are going to be seen in a certain order. As we saw in the previous chapter, if the first line of the first batch is our `dset[0]` then the second batch should have `dset[1]` as the first line, so that the model sees the text flowing.\n",
"\n",
"`LMDataLoader` was doing this for us in the previous chapter. This time we're going to do it ourselves.\n",
"\n",
"To do this, we are going to rearrange our dataset. First we divide the samples into `m = len(dset) // bs` groups (this is the equivalent of splitting the whole concatenated dataset into, for instance, 64 equally sized pieces, since we're using `bs=64` here). `m` is the length of each of these pieces. For instance, if we're using our whole dataset (although we'll actually split it into train vs valid in a moment), that will be:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(328, 64, 21031)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = len(seqs)//bs\n",
"m,bs,len(seqs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first batch will be composed of the samples:\n",
"\n",
" (0, m, 2*m, ..., (bs-1)*m)\n",
"\n",
"then the second batch of the samples: \n",
"\n",
" (1, m+1, 2*m+1, ..., (bs-1)*m+1)\n",
"\n",
"and so forth. This way, at each epoch, the model will see a chunk of contiguous text of size `3*m` (since each text is of size 3) on each line of the batch.\n",
"\n",
"The following function does that reindexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def group_chunks(ds, bs):\n",
" m = len(ds) // bs\n",
" new_ds = L()\n",
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
" return new_ds"
]
},
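{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, purely illustrative sanity check, the first `bs` items of the reordered dataset (i.e. the first batch) should be items `0, m, 2*m, ...` of the original one:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reordered = group_chunks(seqs, bs)\n",
"# item j of the first batch is item j*m of the original dataset\n",
"all(torch.equal(reordered[j][0], seqs[j*m][0]) for j in range(bs))"
]
},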
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we just pass `drop_last=True` when building our `DataLoaders` to drop the last batch that has not a shape of `bs`, we also pass `shuffle=False` to make sure the texts are read in order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last thing we add is a little tweak of the training loop via a `Callback`. We will talk more about callbacks in <<chapter_callbacks>>; this one will call the `reset` method of our model at the beginning of each epoch and before each validation phase. Since we implemented that method to zero the hidden state of the model, this will make sure we start we a clean state before reading those continuous chunks of text. We can also start training a bit longer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.677074</td>\n",
" <td>1.827367</td>\n",
" <td>0.467548</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.282722</td>\n",
" <td>1.870913</td>\n",
" <td>0.388942</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.090705</td>\n",
" <td>1.651793</td>\n",
" <td>0.462500</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.005092</td>\n",
" <td>1.613794</td>\n",
" <td>0.516587</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.965975</td>\n",
" <td>1.560775</td>\n",
" <td>0.551202</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.916182</td>\n",
" <td>1.595857</td>\n",
" <td>0.560577</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.897657</td>\n",
" <td>1.539733</td>\n",
" <td>0.574279</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.836274</td>\n",
" <td>1.585141</td>\n",
" <td>0.583173</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.805877</td>\n",
" <td>1.629808</td>\n",
" <td>0.586779</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.795096</td>\n",
" <td>1.651267</td>\n",
" <td>0.588942</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(10, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating more signal"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another problem with our current approach is that we only predict one output word for each three input words. That means that the amount of signal that we are feeding back to update weights with is not as large as it could be. It would be better if we predicted the next word after every single word, rather than every three words. Here's the pictorial version:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img alt=\"RNN predicting after every token\" width=\"400\" caption=\"RNN predicting after every token\" id=\"stateful_rep\" src=\"images/att_00024.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is easy enough to add. We need to first change our data so that the dependent variable has each of the three next words after each of our three input words. Instead of 3, we use an attribute, `sl` (for sequence length) and make it a bit bigger:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sl = 16\n",
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
" for i in range(0,len(nums)-sl-1,sl))\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
" group_chunks(seqs[cut:], bs),\n",
" bs=bs, drop_last=True, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the first element of `seqs`, we can see that it contains two lists of the same size. The second list is the same as the first, but offset by one element:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n",
" (#16) ['.','two','.','three','.','four','.','five','.','six'...]]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[L(vocab[o] for o in s) for s in seqs[0]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we need to modify our model so that it outputs a prediction after every word, rather than just at the end of a three word sequence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel4(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" outs = []\n",
" for i in range(sl):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" outs.append(self.h_o(self.h))\n",
" self.h = self.h.detach()\n",
" return torch.stack(outs, dim=1)\n",
" \n",
" def reset(self): self.h = 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model will return outputs of shape `bs x sl x vocab_sz` (since we stacked on `dim=1`). Our targets are of shape `bs x sl`, so we need to flatten those before using them in `F.cross_entropy`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss_func(inp, targ):\n",
" return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
]
},
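{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just to check that the shapes line up, we can call this loss function on made-up tensors of the expected shapes (these fake tensors are only for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x_fake = torch.randn(bs, sl, len(vocab))         # fake predictions: bs x sl x vocab_sz\n",
"y_fake = torch.randint(0, len(vocab), (bs, sl))  # fake targets: bs x sl\n",
"loss_func(x_fake, y_fake)                        # returns a scalar loss"
]
},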
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now use this loss function to train the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>3.103298</td>\n",
" <td>2.874341</td>\n",
" <td>0.212565</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.231964</td>\n",
" <td>1.971280</td>\n",
" <td>0.462158</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.711358</td>\n",
" <td>1.813547</td>\n",
" <td>0.461182</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.448516</td>\n",
" <td>1.828176</td>\n",
" <td>0.483236</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.288630</td>\n",
" <td>1.659564</td>\n",
" <td>0.520671</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1.161470</td>\n",
" <td>1.714023</td>\n",
" <td>0.554932</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1.055568</td>\n",
" <td>1.660916</td>\n",
" <td>0.575033</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.960765</td>\n",
" <td>1.719624</td>\n",
" <td>0.591064</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.870153</td>\n",
" <td>1.839560</td>\n",
" <td>0.614665</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.808545</td>\n",
" <td>1.770278</td>\n",
" <td>0.624349</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.758084</td>\n",
" <td>1.842931</td>\n",
" <td>0.610758</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>0.719320</td>\n",
" <td>1.799527</td>\n",
" <td>0.646566</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>0.683439</td>\n",
" <td>1.917928</td>\n",
" <td>0.649821</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>0.660283</td>\n",
" <td>1.874712</td>\n",
" <td>0.628581</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>0.646154</td>\n",
" <td>1.877519</td>\n",
" <td>0.640055</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(15, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to train for longer, since the task has changed a bit and is more complicated now. But we end up with a good result... At least, sometimes. If you run it a few times, you'll see that you can get quite different results on different runs. That's because effectively we have a very deep network here, which can result in very large or very small gradients. We'll see in the next chapter how to resolve this, by using the `LSTM` architecture.\n",
"\n",
"We can also see that `valid_loss` is getting worse, so it may help to add some additional regularization. That will be provided by the `AWD` variant of `LSTM`, which we'll also see in the next chapter.\n",
"\n",
"By combining these techniques, we'll see how to get around 85% accuracy on this dataset!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. If the dataset for your project is so big and complicated that working with it takes a significant amount of time, what should you do?\n",
"1. Why do we concatenating the documents in our dataset before creating a language model?\n",
"1. To use a standard fully connected network to predict the fourth word given the previous three words, what two tweaks do we need to make?\n",
"1. How can we share a weight matrix across multiple layers in PyTorch?\n",
"1. Write a module which predicts the third word given the previous two words of a sentence, without peeking.\n",
"1. What is a recurrent neural network?\n",
"1. What is hidden state?\n",
"1. What is the equivalent of hidden state in ` LMModel1`?\n",
"1. To maintain the state in an RNN why is it important to pass the text to the model in order?\n",
"1. What is an unrolled representation of an RNN?\n",
"1. Why can maintaining the hidden state in an RNN lead to memory and performance problems? How do we fix this problem?\n",
"1. What is BPTT?\n",
"1. Write code to print out the first few batches of the validation set, including converting the token IDs back into English strings, as we showed for batches of IMDb data in the previous chapter.\n",
"1. What does the `ModelReseter` callback do? Why do we need it?\n",
"1. What are the downsides of predicting just one output word for each three input words?\n",
"1. Why do we need a custom loss function for `LMModel4`?\n",
"1. Why is the training of `LMModel4` unstable?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. In ` LMModel2` why can `forward` start with `h=0`? Why don't we need to say `h=torch.zeros(…)`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}