{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#hide\n",
|
||
"from utils import *"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"hide_input": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#hide\n",
|
||
"from fastai2.text.all import *\n",
|
||
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
|
||
"lines = L()\n",
|
||
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
|
||
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
|
||
"text = ' . '.join([l.strip() for l in lines])\n",
|
||
"tokens = text.split(' ')\n",
|
||
"vocab = L(*tokens).unique()\n",
|
||
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
|
||
"nums = L(word2idx[i] for i in tokens)\n",
|
||
"\n",
|
||
"def group_chunks(ds, bs):\n",
|
||
" m = len(ds) // bs\n",
|
||
" new_ds = L()\n",
|
||
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
|
||
" return new_ds"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "raw",
|
||
"metadata": {},
|
||
"source": [
|
||
"[[chapter_better_rnn]]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Making our RNN state of the art"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We saw in the last chapter how to build a basic RNN from scratch. Now we will see how to make it better up until the AWD LSTM architecture we used in <<chapter_nlp>> on this text classification problem.\n",
|
||
"\n",
|
||
"We won't go other the whole data preparation process again. To make the comparison fair against our last example, we use the same batch size and sequence length:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"sl,bs = 16,64\n",
|
||
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
|
||
" for i in range(0,len(nums)-sl-1,sl))\n",
|
||
"cut = int(len(seqs) * 0.8)\n",
|
||
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
|
||
" group_chunks(seqs[cut:], bs),\n",
|
||
" bs=bs, drop_last=True, shuffle=False)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"The obvious way to get a better model is to go deeper: we only have one linear layer between the hidden state and the output activations in our basic RNN, so maybe we would get better results with more."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Multilayer RNNs"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### The model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In a multilayer RNN, we pass the activations from our recurrent neural network into a second recurrent neural network, like so:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"<img alt=\"2-layer RNN\" width=\"550\" caption=\"2-layer RNN\" id=\"stacked_rnn_rep\" src=\"images/att_00025.png\">"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"…or in an unrolled representation:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"<img alt=\"2-layer unrolled RNN\" width=\"500\" caption=\"2-layer unrolled RNN\" id=\"unrolled_stack_rep\" src=\"images/att_00026.png\">"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Let's save some time by using PyTorch's RNN class, which implements exactly what we have created above, but also gives us the option to stack multiple RNNs, as we have discussed:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LMModel5(Module):\n",
|
||
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
||
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
||
" self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||
" self.h = torch.zeros(n_layers, bs, n_hidden)\n",
|
||
" \n",
|
||
" def forward(self, x):\n",
|
||
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
||
" self.h = h.detach()\n",
|
||
" return self.h_o(res)\n",
|
||
" \n",
|
||
" def reset(self): self.h.zero_()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: left;\">\n",
|
||
" <th>epoch</th>\n",
|
||
" <th>train_loss</th>\n",
|
||
" <th>valid_loss</th>\n",
|
||
" <th>accuracy</th>\n",
|
||
" <th>time</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.055853</td>\n",
|
||
" <td>2.591640</td>\n",
|
||
" <td>0.437907</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>2.162359</td>\n",
|
||
" <td>1.787310</td>\n",
|
||
" <td>0.471598</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>1.710663</td>\n",
|
||
" <td>1.941807</td>\n",
|
||
" <td>0.321777</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1.520783</td>\n",
|
||
" <td>1.999726</td>\n",
|
||
" <td>0.312012</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1.330846</td>\n",
|
||
" <td>2.012902</td>\n",
|
||
" <td>0.413249</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>5</td>\n",
|
||
" <td>1.163297</td>\n",
|
||
" <td>1.896192</td>\n",
|
||
" <td>0.450684</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>6</td>\n",
|
||
" <td>1.033813</td>\n",
|
||
" <td>2.005209</td>\n",
|
||
" <td>0.434814</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>7</td>\n",
|
||
" <td>0.919090</td>\n",
|
||
" <td>2.047083</td>\n",
|
||
" <td>0.456706</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>8</td>\n",
|
||
" <td>0.822939</td>\n",
|
||
" <td>2.068031</td>\n",
|
||
" <td>0.468831</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>9</td>\n",
|
||
" <td>0.750180</td>\n",
|
||
" <td>2.136064</td>\n",
|
||
" <td>0.475098</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0.695120</td>\n",
|
||
" <td>2.139140</td>\n",
|
||
" <td>0.485433</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>11</td>\n",
|
||
" <td>0.655752</td>\n",
|
||
" <td>2.155081</td>\n",
|
||
" <td>0.493652</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>12</td>\n",
|
||
" <td>0.629650</td>\n",
|
||
" <td>2.162583</td>\n",
|
||
" <td>0.498535</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>13</td>\n",
|
||
" <td>0.613583</td>\n",
|
||
" <td>2.171649</td>\n",
|
||
" <td>0.491048</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>14</td>\n",
|
||
" <td>0.604309</td>\n",
|
||
" <td>2.180355</td>\n",
|
||
" <td>0.487874</td>\n",
|
||
" <td>00:01</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>"
|
||
],
|
||
"text/plain": [
|
||
"<IPython.core.display.HTML object>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"learn = Learner(dls, LMModel5(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)\n",
|
||
"learn.fit_one_cycle(15, 3e-3)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now that's disappointing... we are doing more poorly than the single-layer RNN from the end of last chapter. The reason is that we have a deeper model, leading to exploding or disappearing activations."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Exploding or disappearing activations"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In practice, creating accurate models from this kind of RNN is difficult. We will get better results if we call `detach` less often, and have more layers — this gives our RNN a longer time horizon to learn from, and richer features to create. But it also means we have a deeper model to train. The key challenge in the development of deep learning has been figuring out how to train these kinds of models.\n",
|
||
"\n",
|
||
"The reason this is challenging is because of what happens when you multiply by a matrix many times. Think about what happens when you multiply by a number many times. For example, if you multiply by two, starting at one, you get the sequence 1, 2, 4, 8,… after 32 steps you are already at 4,294,967,296. A similar issue happens if we multiply by 0.5: we get 0.5, 0.25, 0.125… and after 32 steps it's 0.00000000023. As you can see, a number even slightly higher or lower than one results in an explosion or disappearance of our number, after just a few repeated multiplications.\n",
|
||
"\n",
|
||
"Because matrix multiplication is just multiplying numbers and adding them up, exactly the same thing happens with repeated matrix multiplications. And a deep neural network is just repeated matrix multiplications--each extra layer is another matrix multiplication. This means that it is very easy for a deep neural network to end up with extremely large, or extremely small numbers.\n",
|
||
"\n",
|
||
"This is a problem, because the way computers store numbers (known as \"floating point\") means that they become less and less accurate the further away the numbers get from zero. This diagram, from the excellent article [What you never wanted to know about floating point but will be forced to find out](http://www.volkerschatz.com/science/float.html), shows how the precision of floating point numbers varies over the number line:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"<img alt=\"Precision of floating point numbers\" width=\"1000\" caption=\"Precision of floating point numbers\" id=\"float_prec\" src=\"images/fltscale.svg\">"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"This inaccuracy means that often the gradients calculated for updating the weights end up as zero or infinity for deep networks. This is commonly refered to as *vanishing gradients* or *exploding gradients*. That means that in SGD, the weights are updated either not at all, or jump to infinity. Either way, they won't improve with training.\n",
|
||
"\n",
|
||
"Researchers have developed a number of ways to tackle this problem, which we will be discussing later in the book. One way to tackle the problem is to change the definition of a layer in a way that makes it less likely to have exploding activations. We'll look at the details of how this is done in <<chapter_deep_conv>>, when we discuss *batch normalization*, and <<chapter_resnet>>, when we discuss *ResNets*, although these details don't generally matter in practice (unless you are a researcher that is creating new approaches to solving this problem). Another way to deal with this is by being careful about *initialization*, which is a topic we'll investigate in <<chapter_foundations>>.\n",
|
||
"\n",
|
||
"For RNNs, there are two types of layers frequently used to avoid exploding activations, and they are: *gated recurrent units* (GRU), and *Long Short-Term Memory* (LSTM). Both of these are available in PyTorch, and are drop-in replacements for the RNN layer. We will only cover LSTMs in this book, there are plenty of good tutorials online explaining GRUs, which are a minor variant on the LSTM design."
|
||
]
|
||
},
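{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we move on, here is a small experiment of our own that makes the repeated-multiplication issue above concrete (it isn't needed for the rest of the chapter): multiplying a vector by the same random matrix over and over makes its values explode toward infinity, while multiplying by a slightly scaled-down matrix makes them vanish toward zero, in just a few dozen steps:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x1 = x2 = torch.randn(64)\n",
"m1 = torch.randn(64,64)        # repeated multiplication by this matrix explodes\n",
"m2 = torch.randn(64,64) * 0.1  # ...and by this scaled-down matrix vanishes\n",
"for i in range(50): x1,x2 = m1 @ x1, m2 @ x2\n",
"x1.abs().mean(), x2.abs().mean()"
]
},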
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## LSTM"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"LSTM (for long short-term memory) is an architecture that was introduced back in 1997 by Jurgen Schmidhuber and Sepp Hochreiter. In this architecture, there are not one but two hidden states. In our base RNN, the hidden state is the output of the RNN at the previous time step. That hidden state is then responsible for doing two things at a time:\n",
|
||
"\n",
|
||
"- having the right information for the output layer to predict the correct next token\n",
|
||
"- retaining memory of everything that happened in the sentence\n",
|
||
"\n",
|
||
"Consider, for example, the sentences \"Henry has a dog and he likes his dog very much\" and \"Sophie has a dog and she likes her dog very much\". It's very clear that the RNN needs to remember the name at the beginning of the sentence to be able to predict *he/she* or *his/her*. \n",
|
||
"\n",
|
||
"In practice, RNNs are really bad at retaining memory of what happened much earlier in the sentence, which is the motivation to have another hidden state (called cell state) in the LSTM. The cell state will be responsible for keeping *long short-term memory*, while the hidden state will focus on the next token to predict. Let's have a closer look and how this is achieved and build one LSTM from scratch."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Building an LSTM from scratch"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"The schematic of an LSTM is given like so:\n",
|
||
"\n",
|
||
"<img src=\"images/LSTM.png\" id=\"lstm\" caption=\"Architecture of an LSTM\" alt=\"A graph showing the inner architecture of an LSTM\" width=\"700\">"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In this picture, our input $x_{t}$ enters on the bottom with the previous hidden state ($h_{t-1}$) and cell state ($x_{t-1}$). The four orange boxes represent four layers with the activation being either sigmoid (for $\\sigma$) or tanh. tanh is just a sigmoid rescaled to the range -1 to 1. Its mathematical expression can be written like this:\n",
|
||
"\n",
|
||
"$$\\tanh(x) = \\frac{e^{x} + e^{-x}}{e^{x}-e^{-x}} = 2 \\sigma(2x) - 1$$\n",
|
||
"\n",
|
||
"where $\\sigma$ is the sigmoid function. The green boxes are elementwise operations. What goes out is the new hidden state ($h_{t}$) and new cell state ($c_{t}$) on the left, ready for our next input. The new hidden state is also use as output, which is why the arrow splits to go up.\n",
|
||
"\n",
|
||
"Let's go over the four neural nets (called *gates*) one by one and explain the diagram, but before this, notice how very little the cell state (on the top) is changed. It doesn't even go directly through a neural net! This is exactly why it will carry on a longer-term state.\n",
|
||
"\n",
|
||
"First, the arrows for input and old hidden state are joined together. In the RNN we wrote in the past chapter, we were adding them together. In the LSTM, we stack them in one big tensor. This means the dimension of our embeddings (which is the dimension of $x_{t}$) can be different than the dimension of our hidden state. If we call those `n_in` and `n_hid`, the arrow at the bottom is of size `n_in + n_hid`, thus all the neural nets (orange boxes) are linear layers with `n_in + n_hid` inputs and `n_hid` outputs.\n",
|
||
"\n",
|
||
"The first gate (looking from the left to right) is called the *forget gate*. Since it's a linear layer followed by a sigmoid, its output will have scalars between 0 and 1. We multiply this result y the cell gate, so for all the values close to 0, we will forget what was inside that cell state (and for the values close to 1 it doesn't do anything). This gives the ability to the LSTM to forget things about its longterm state. For instance, when crossing a period or an `xxbos` token, we would expect to it to (have learned to) reset its cell state.\n",
|
||
"\n",
|
||
"The second gate works is called the *input gate*. It works with the third gate (which doesn't really have a name but is sometimes called the *cell gate*) to update the cell state. For instance we may see a new gender pronoum, so we must replace the information about gender that the forget gate removed by the new one. Like the forget gate, the input gate ends up on a product, so it jsut decides which element of the cell state to update (valeus close to 1) or not (values close to 0). The third gate will then fill those values with things between -1 and 1 (thanks to the tanh). The result is then added to the cell state.\n",
|
||
"\n",
|
||
"The last gate is the *output gate*. It will decides which information take in the cell state to generate the output. The cell state goes through a tanh before this and the output gate combined with the sigmoid decides which values to take inside it.\n",
|
||
"\n",
|
||
"In terms of code, we can write the same steps like this:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LSTMCell(Module):\n",
|
||
" def __init__(self, ni, nh):\n",
|
||
" self.forget_gate = nn.Linear(ni + nh, nh)\n",
|
||
" self.input_gate = nn.Linear(ni + nh, nh)\n",
|
||
" self.cell_gate = nn.Linear(ni + nh, nh)\n",
|
||
" self.output_gate = nn.Linear(ni + nh, nh)\n",
|
||
"\n",
|
||
" def forward(self, input, state):\n",
|
||
" h,c = state\n",
|
||
" h = torch.stack([h, input], dim=1)\n",
|
||
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
||
" c = c * forget\n",
|
||
" inp = torch.sigmoid(self.input_gate(h))\n",
|
||
" cell = torch.tanh(self.cell_gate(h))\n",
|
||
" c = c + inp * cell\n",
|
||
" out = torch.sigmoid(self.output_gate(h))\n",
|
||
" h = outgate * torch.tanh(c)\n",
|
||
" return h, (h,c)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In practice, we can then refactor the code. Also, in terms of performance, it's better to do one big matrix multiplication than four smaller ones (that's because we only launch the special fast kernel on GPU once, and it gives the GPU more work to do in parallel). The stacking takes a bit of time (since we have to move one of the tensors around on the GPU to have it all in a contiguous array), so we use two separate layers for the input and the hidden state. The optimized and refactored code then looks like that:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LSTMCell(Module):\n",
|
||
" def __init__(self, ni, nh):\n",
|
||
" self.ih = nn.Linear(ni,4*nh)\n",
|
||
" self.hh = nn.Linear(nh,4*nh)\n",
|
||
"\n",
|
||
" def forward(self, input, state):\n",
|
||
" h,c = state\n",
|
||
" #One big multiplication for all the gates is better than 4 smaller ones\n",
|
||
" gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
|
||
" ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
|
||
" cellgate = gates[3].tanh()\n",
|
||
"\n",
|
||
" c = (forgetgate*c) + (ingate*cellgate)\n",
|
||
" h = outgate * c.tanh()\n",
|
||
" return h, (h,c)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Here we use the PyTorch `chunk` method to split our tensor into 4 pieces, e.g.:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
|
||
]
|
||
},
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"t = torch.arange(0,10); t"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))"
|
||
]
|
||
},
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"t.chunk(2)"
|
||
]
|
||
},
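{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `LSTMCell` we split along the second dimension into four chunks, one per gate. Here is the same idea on a small 2D tensor (an extra example of our own, not from the original text):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t2 = torch.arange(24).view(2, 12)\n",
"# split dimension 1 (of size 12) into 4 chunks of size 3, like the gates tensor in LSTMCell\n",
"t2.chunk(4, 1)"
]
},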
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Training a language model using LSTMs"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Here is the same network as before, using a two-layer LSTM. We can train it at a higher learning rate, for a shorter time, and get better accuracy:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LMModel6(Module):\n",
|
||
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
||
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
||
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||
" \n",
|
||
" def forward(self, x):\n",
|
||
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
||
" self.h = [h_.detach() for h_ in h]\n",
|
||
" return self.h_o(res)\n",
|
||
" \n",
|
||
" def reset(self): \n",
|
||
" for h in self.h: h.zero_()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: left;\">\n",
|
||
" <th>epoch</th>\n",
|
||
" <th>train_loss</th>\n",
|
||
" <th>valid_loss</th>\n",
|
||
" <th>accuracy</th>\n",
|
||
" <th>time</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3.000821</td>\n",
|
||
" <td>2.663942</td>\n",
|
||
" <td>0.438314</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>2.139642</td>\n",
|
||
" <td>2.184780</td>\n",
|
||
" <td>0.240479</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>1.607275</td>\n",
|
||
" <td>1.812682</td>\n",
|
||
" <td>0.439779</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1.347711</td>\n",
|
||
" <td>1.830982</td>\n",
|
||
" <td>0.497477</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1.123113</td>\n",
|
||
" <td>1.937766</td>\n",
|
||
" <td>0.594401</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>5</td>\n",
|
||
" <td>0.852042</td>\n",
|
||
" <td>2.012127</td>\n",
|
||
" <td>0.631592</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>6</td>\n",
|
||
" <td>0.565494</td>\n",
|
||
" <td>1.312742</td>\n",
|
||
" <td>0.725749</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>7</td>\n",
|
||
" <td>0.347445</td>\n",
|
||
" <td>1.297934</td>\n",
|
||
" <td>0.711263</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>8</td>\n",
|
||
" <td>0.208191</td>\n",
|
||
" <td>1.441269</td>\n",
|
||
" <td>0.731201</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>9</td>\n",
|
||
" <td>0.126335</td>\n",
|
||
" <td>1.569952</td>\n",
|
||
" <td>0.737305</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0.079761</td>\n",
|
||
" <td>1.427187</td>\n",
|
||
" <td>0.754150</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>11</td>\n",
|
||
" <td>0.052990</td>\n",
|
||
" <td>1.494990</td>\n",
|
||
" <td>0.745117</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>12</td>\n",
|
||
" <td>0.039008</td>\n",
|
||
" <td>1.393731</td>\n",
|
||
" <td>0.757894</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>13</td>\n",
|
||
" <td>0.031502</td>\n",
|
||
" <td>1.373210</td>\n",
|
||
" <td>0.758464</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>14</td>\n",
|
||
" <td>0.028068</td>\n",
|
||
" <td>1.368083</td>\n",
|
||
" <td>0.758464</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>"
|
||
],
|
||
"text/plain": [
|
||
"<IPython.core.display.HTML object>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"learn = Learner(dls, LMModel6(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)\n",
|
||
"learn.fit_one_cycle(15, 1e-2)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now that's better than a multilayer RNN! We can still see there is a bit of overfitting, which is a sign that a bit of regularization might help."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Regularizing an LSTM"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Recurrent neural networks, in general, are hard to train. Using LSTMs (or GRUs) cell make training easier than vanilla RNNs, but there are still very prone to overfitting. Data augmentation, while it exists for text data, is less often used because in most cases, it requires another model to generate random augmentation (by translating in another language and back to the language used for instance). Overall, data augmentation for text data is currently not a well explored space.\n",
|
||
"\n",
|
||
"However, there are other regularization techniques we can use instead, which were thoroughly studied for use with LSTMs in the paper [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182). This paper showed how effective use of *dropout*, *activation regularization*, and *temporal activation regularization* could allow an LSTM to beat state of the art results that previously required much more complicated models. They called an LSTM using these techniques an *AWD LSTM*. We'll look at each of these techniques in turn."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Dropout"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Dropout is a regularization technique that was introduce by Geoffrey Hinton et al. in [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580). The basic idea is to randomly change some activations to zero at training time. This makes sure all neurons actively work toward the output as seen in this figure from the original paper:\n",
|
||
"\n",
|
||
"<img src=\"images/Dropout1.png\" alt=\"A figure from the article showing how neurons go off with dropout\" width=\"800\">\n",
|
||
"\n",
|
||
"Hinton used a nice metaphor when he explained, in an interview, the inspiration for dropout:\n",
|
||
"\n",
|
||
"> : \"I went to my bank. The tellers kept changing and I asked one of them why. He said he didn’t know but they got moved around a lot. I figured it must be because it would require cooperation between employees to successfully defraud the bank. This made me realize that randomly removing a different subset of neurons on each example would prevent conspiracies and thus reduce overfitting\"\n",
|
||
"\n",
|
||
"In the same interview, he also explained that neuroscience provided additional inspiration:\n",
|
||
"\n",
|
||
"> : \"We don't really know why neurons spike. One theory is that they want to be noisy so as to regularize, because we have many more parameters than we have data points. The idea of dropout is that if you have noisy activations, you can afford to use a much bigger model.\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can see there that if we just zero those activations without doing anything else, our model will have problems to train: if we go from the sum of 5 activations (that are all positive numbers since we apply a ReLU) to just 2, this won't have the same scale. Therefore if we dropout with a probability `p`, we rescale all activation by dividing them by `1-p` (on average `p` will be zeroed, so it leaves `1-p`), as shown in this diagram from the original paper:\n",
|
||
"\n",
|
||
"<img src=\"images/Dropout.png\" alt=\"A figure from the article introducing dropout showing how a neuron is on/off\" width=\"600\">\n",
|
||
"\n",
|
||
"This is a full implementation of the dropout layer in PyTorch (although PyTorch's native layer is actually written in C, not Python):"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class Dropout(Module):\n",
|
||
" def __init__(self, p): self.p = p\n",
|
||
" def forward(self, x):\n",
|
||
" if not self.training: return x\n",
|
||
" mask = x.new(*x.shape).bernoulli_(1-p)\n",
|
||
" return x * mask.div_(1-p)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"The `bernoulli_` method is creating a tensor with random zeros (with probability p) and ones (with probability 1-p), which is then multiplied with our input before dividing by `1-p`. Note the use of the `training` attribute, which is available in any PyTorch `nn.Module`, and tells us if we are doing training or inference.\n",
|
||
"\n",
|
||
"> note: In previous chapters of the book we'd be adding a code example for `bernoulli_` here, so you can see exactly how it works. But now that you know enough to do this yourself, we're going to be doing fewer and fewer examples for you, and instead expecting you to do your own experiments to see how things work. In this case, you'll see in the end-of-chapter questionnaire that we're asking you to experiment with `bernoulli_`--but don't wait for us to ask you to experiment to develop your understanding of the code we're studying, go ahead and do it anyway!\n",
|
||
"\n",
|
||
"Using dropout before passing the output of our LSTM to the final layer will help reduce overfitting. Dropout is also used in many other models, including the default CNN head used in `fastai.vision`, and is also available in `fastai.tabular` by passing the `ps` parameter (where each \"p\" is passed to each added `Dropout` layer), as we'll see in <<chapter_arch_details>>."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Dropout has a different behavior in training and validation mode, which we achieved using the `training` attribute in `Dropout` above. Calling the `train()` method on a `Module` sets `training` to `True` (both for the module you call the method on, and for every module it recursively contains), and `eval()` sets it to `False`. This is done automatically when calling the methods of `Learner`, but if you are not using that class, remember to switch from one to the other as needed."
|
||
]
|
||
},
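{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of our own (in the spirit of the experiments we're asking you to run yourself), we can flip our `Dropout` module between the two modes and look at the result: in training mode some activations are zeroed and the survivors are scaled up by `1/(1-p)`, while in evaluation mode the input passes through unchanged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dp = Dropout(0.5)\n",
"x = torch.ones(8)\n",
"dp.train()   # training mode: some values zeroed, the survivors scaled to 2.0\n",
"print(dp(x))\n",
"dp.eval()    # evaluation mode: x is returned unchanged\n",
"print(dp(x))"
]
},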
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### AR and TAR regularization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"AR (for *activation regularization*) and TAR (for *temporal activation regularization*) are two regularization methods very similar to weight decay. When applying weight decay, we add a small penalty to the loss that aims at making the weights as small as possible. For the activation regularization, it's the final activations produced by the LSTM that we will try to make as small as possible, instead of the weights.\n",
|
||
"\n",
|
||
"To regularize the final activations, we have to store those somewhere, then add the means of the squares of them to the loss (along with a multiplier `alpha`, which is just like `wd` for weight decay):\n",
|
||
"\n",
|
||
"``` python\n",
|
||
"loss += alpha * activations.pow(2).mean()\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Temporal activation regularization is linked to the fact we are predicting tokens in a sentence. That means it's likely that the outputs of our LSTMs should somewhat make sense when we read them in order. TAR is there to encourage that behavior by adding a penalty to the loss to make the difference between two consecutive activations as small as possible: our activations tensor has a shape `bs x sl x n_hid`, and we read consecutive activation on the sequence length axis (so the dimension in the middle). With this, TAR can be expressed as:\n",
|
||
"\n",
|
||
"``` python\n",
|
||
"loss += beta * (activations[:,1:] - activations[:,:-1]).pow(2).mean()\n",
|
||
"```\n",
|
||
"\n",
|
||
"`alpha` and `beta` are then two hyper-parameters to tune. To make this work, we need our model with dropout to return three things: the proper output, the activations of the LSTM pre-dropout and the activations of the LSTM post-dropout. AR is often applied on the dropped out activations (to not penalize the activations we turned in 0s afterward) while TAR is applied on the non-dropped out activations (because those 0s create big differences between two consecutive timesteps). There is then a callback called `RNNRegularizer` that will apply this regularization for us."
|
||
]
|
||
},
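{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting the two penalties together, here is a minimal sketch (our own illustration) of the extra terms that a callback like `RNNRegularizer` adds to the loss; the actual fastai implementation differs in its details:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def ar_tar_penalty(raw_out, dropped_out, alpha=2., beta=1.):\n",
"    \"Sketch (ours) of the AR and TAR terms added to the loss\"\n",
"    # AR: push the dropped-out activations toward zero\n",
"    ar  = alpha * dropped_out.pow(2).mean()\n",
"    # TAR: push consecutive raw activations to be close along the sequence dimension\n",
"    tar = beta  * (raw_out[:,1:] - raw_out[:,:-1]).pow(2).mean()\n",
"    return ar + tar"
]
},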
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Training a weight-tied regularized LSTM"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can combine dropout (applied before we go in our output layer) with the AR and TAR regularization to train our previous LSTM. We just need to return three things instead of one: the normal output of our LSTM, the dropped-out activations and the activations from our LSTMs. Those last two will be picked up by the callback `RNNRegularization` for the contributions it has to make to the loss.\n",
|
||
"\n",
|
||
"Another useful trick we can add from the AWD LSTM paper is *weight tying*. In a language model, the input embeddings represent a mapping from English words to activations, and the output hidden layer represents a mapping from activations to English words. We might expect, intuitively, that these mappings could be the same. We can represent this in PyTorch by assigning the same weight matrix to each of these layers:\n",
|
||
"\n",
|
||
" self.h_o.weight = self.i_h.weight\n",
|
||
"\n",
|
||
"In `LMMModel7`, we include these final tweaks:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 48,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class LMModel7(Module):\n",
|
||
" def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
|
||
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
||
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
||
" self.drop = nn.Dropout(p)\n",
|
||
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
||
" self.h_o.weight = self.i_h.weight\n",
|
||
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
||
" \n",
|
||
" def forward(self, x):\n",
|
||
" raw,h = self.rnn(self.i_h(x), self.h)\n",
|
||
" out = self.drop(raw)\n",
|
||
" self.h = [h_.detach() for h_ in h]\n",
|
||
" return self.h_o(out),raw,out\n",
|
||
" \n",
|
||
" def reset(self): \n",
|
||
" for h in self.h: h.zero_()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can create a regularized `Learner` using the `RNNRegularizer` callback:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 55,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n",
|
||
" loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n",
|
||
" cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"A `TextLearner` automatically adds those two callbacks for us (with default for `alpha` and `beta` as above) so we can simplify the line above to:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n",
|
||
" loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can the train the model, and add additional regularization by increasing the weight decay to `0.1`:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 54,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: left;\">\n",
|
||
" <th>epoch</th>\n",
|
||
" <th>train_loss</th>\n",
|
||
" <th>valid_loss</th>\n",
|
||
" <th>accuracy</th>\n",
|
||
" <th>time</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>2.693885</td>\n",
|
||
" <td>2.013484</td>\n",
|
||
" <td>0.466634</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1.685549</td>\n",
|
||
" <td>1.187310</td>\n",
|
||
" <td>0.629313</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.973307</td>\n",
|
||
" <td>0.791398</td>\n",
|
||
" <td>0.745605</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>0.555823</td>\n",
|
||
" <td>0.640412</td>\n",
|
||
" <td>0.794108</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>0.351802</td>\n",
|
||
" <td>0.557247</td>\n",
|
||
" <td>0.836100</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>5</td>\n",
|
||
" <td>0.244986</td>\n",
|
||
" <td>0.594977</td>\n",
|
||
" <td>0.807292</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>6</td>\n",
|
||
" <td>0.192231</td>\n",
|
||
" <td>0.511690</td>\n",
|
||
" <td>0.846761</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>7</td>\n",
|
||
" <td>0.162456</td>\n",
|
||
" <td>0.520370</td>\n",
|
||
" <td>0.858073</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>8</td>\n",
|
||
" <td>0.142664</td>\n",
|
||
" <td>0.525918</td>\n",
|
||
" <td>0.842285</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>9</td>\n",
|
||
" <td>0.128493</td>\n",
|
||
" <td>0.495029</td>\n",
|
||
" <td>0.858073</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0.117589</td>\n",
|
||
" <td>0.464236</td>\n",
|
||
" <td>0.867188</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>11</td>\n",
|
||
" <td>0.109808</td>\n",
|
||
" <td>0.466550</td>\n",
|
||
" <td>0.869303</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>12</td>\n",
|
||
" <td>0.104216</td>\n",
|
||
" <td>0.455151</td>\n",
|
||
" <td>0.871826</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>13</td>\n",
|
||
" <td>0.100271</td>\n",
|
||
" <td>0.452659</td>\n",
|
||
" <td>0.873617</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>14</td>\n",
|
||
" <td>0.098121</td>\n",
|
||
" <td>0.458372</td>\n",
|
||
" <td>0.869385</td>\n",
|
||
" <td>00:02</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>"
|
||
],
|
||
"text/plain": [
|
||
"<IPython.core.display.HTML object>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"learn.fit_one_cycle(15, 1e-2, wd=0.1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now this is far better than our previous model!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Conclusion"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"You have now seen everything that is inside the AWD-LSTM architecture we used in text classification in <<chapter_nlp>>. It uses dropouts in a lot more places:\n",
|
||
"\n",
|
||
"- embedding dropout (just after the embedding layer)\n",
|
||
"- input dropout (after the embedding layer)\n",
|
||
"- weight dropout (applied to the weights of the LSTM at each training step)\n",
|
||
"- hidden dropout (applied to the hidden state between two layers)\n",
|
||
"\n",
|
||
"which makes it even more regularized. Since fine-tuning those five dropout values (adding the dropout before the output layer) is complicated, so we have determined good defaults, and allow the magnitude of dropout to be tuned overall with the `drop_mult` parameter you saw (which is multiplied by each dropout).\n",
|
||
"\n",
|
||
"Another architecture that is very powerful, especially in \"sequence to sequence\" problems (that is, problems where the dependent variable is itself a variable length sequence, such as language translation), is the Transformers architecture. You can find it in an online bonus chapter on the book website."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Questionnaire"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"1. In the unrolled representation, we can see that a recurrent neural network actually has many layers. So why do we need to stack RNNs to get better results?\n",
|
||
"1. Draw a representation of a stacked (multilayer) RNN.\n",
|
||
"1. Why should we get better results in an RNN if we call `detach` less often? Why might this not happen in practice with a simple RNN?\n",
|
||
"1. Why can a deep network result in very large or very small activations? Why does this matter?\n",
|
||
"1. In a computer's floating point representation of numbers, which numbers are the most precise?\n",
|
||
"1. Why do vanishing gradients prevent training?\n",
|
||
"1. Why does it help to have two hidden states in the LSTM architecture? What is the purpose of each one?\n",
|
||
"1. What are these two states called in an LSTM?\n",
|
||
"1. What is tanh, and how is it related to sigmoid?\n",
|
||
"1. What is the purpose of this code in `LSTMCell`?: `h = torch.stack([h, input], dim=1)`\n",
|
||
"1. What does `chunk` to in PyTorch?\n",
|
||
"1. Study the refactored version of `LSTMCell` carefully to ensure you understand how and why it does the same thing as the non-refactored version.\n",
|
||
"1. Why can we use a higher learning rate for `LMModel6`?\n",
|
||
"1. What are the three regularisation techniques used in an AWD-LSTM model?\n",
|
||
"1. What is dropout?\n",
|
||
"1. Why do we scale the weights with dropout? Is this applied during training, inference, or both?\n",
|
||
"1. What is the purpose of this line from `Dropout`?: `if not self.training: return x`\n",
|
||
"1. Experiment with `bernoulli_` to understand how it works.\n",
|
||
"1. How do you set your model in training mode in PyTorch? In evaluation mode?\n",
|
||
"1. Write the equation for activation regularization (in maths or code, as you prefer). How is it different to weight decay?\n",
|
||
"1. Write the equation for temporal activation regularization (in maths or code, as you prefer). Why wouldn't we use this for computer vision problems?\n",
|
||
"1. What is \"weight tying\" in a language model?"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Further research"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"1. Write the code for an LSTM from scratch (but you may refer to <<lstm>>).\n",
|
||
"1. Search on the Internet for the GRU architecture and implement it from scratch, and try training a model. See if you can get the similar results as we saw in this chapter. Compare it to the results of PyTorch's built in GRU module.\n",
|
||
"1. Have a look at the source code for AWD-LSTM in fastai, and try to map each of the lines of code to the concepts shown in this chapter."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"jupytext": {
|
||
"split_at_heading": true
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.5"
|
||
},
|
||
"toc": {
|
||
"base_numbering": 1,
|
||
"nav_menu": {
|
||
"height": "245px",
|
||
"width": "258px"
|
||
},
|
||
"number_sections": false,
|
||
"sideBar": true,
|
||
"skip_h1_title": true,
|
||
"title_cell": "Table of Contents",
|
||
"title_sidebar": "Contents",
|
||
"toc_cell": false,
|
||
"toc_position": {},
|
||
"toc_section_display": true,
|
||
"toc_window_display": false
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|