mirror of
https://github.com/fastai/fastbook.git
synced 2025-04-05 02:10:48 +00:00
1634 lines
45 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"from utils import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# A language model from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from fastai.text.all import *\n",
|
|
"path = untar_data(URLs.HUMAN_NUMBERS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"Path.BASE_PATH = path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#2) [Path('train.txt'),Path('valid.txt')]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"path.ls()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"lines = L()\n",
|
|
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
|
|
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
|
|
"lines"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"text = ' . '.join([l.strip() for l in lines])\n",
|
|
"text[:100]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"tokens = text.split(' ')\n",
|
|
"tokens[:10]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"vocab = L(*tokens).unique()\n",
|
|
"vocab"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
|
|
"nums = L(word2idx[i] for i in tokens)\n",
|
|
"nums"
|
|
]
|
|
},
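{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, mapping the ids in `nums` back through `vocab` should recover the original tokens:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"L(vocab[i] for i in nums[:10])"
]
},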
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Our first language model from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
|
|
"seqs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"bs = 64\n",
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Our language model in PyTorch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel1(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
|
|
" h = h + self.i_h(x[:,1])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" h = h + self.i_h(x[:,2])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" return self.h_o(h)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.824297</td>\n",
|
|
" <td>1.970941</td>\n",
|
|
" <td>0.467554</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.386973</td>\n",
|
|
" <td>1.823242</td>\n",
|
|
" <td>0.467554</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.417556</td>\n",
|
|
" <td>1.654497</td>\n",
|
|
" <td>0.494414</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.376440</td>\n",
|
|
" <td>1.650849</td>\n",
|
|
" <td>0.494414</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, \n",
|
|
" metrics=accuracy)\n",
|
|
"learn.fit_one_cycle(4, 1e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(29), 'thousand', 0.15165200855716662)"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"n,counts = 0,torch.zeros(len(vocab))\n",
|
|
"for x,y in dls.valid:\n",
|
|
" n += y.shape[0]\n",
|
|
" for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
|
|
"idx = torch.argmax(counts)\n",
|
|
"idx, vocab[idx.item()], counts[idx].item()/n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Our first recurrent neural network"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel2(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" h = 0\n",
|
|
" for i in range(3):\n",
|
|
" h = h + self.i_h(x[:,i])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" return self.h_o(h)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.816274</td>\n",
|
|
" <td>1.964143</td>\n",
|
|
" <td>0.460185</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.423805</td>\n",
|
|
" <td>1.739964</td>\n",
|
|
" <td>0.473259</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.430327</td>\n",
|
|
" <td>1.685172</td>\n",
|
|
" <td>0.485382</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.388390</td>\n",
|
|
" <td>1.657033</td>\n",
|
|
" <td>0.470406</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, \n",
|
|
" metrics=accuracy)\n",
|
|
"learn.fit_one_cycle(4, 1e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Improving the RNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Maintaining the state of an RNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel3(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" self.h = 0\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" for i in range(3):\n",
|
|
" self.h = self.h + self.i_h(x[:,i])\n",
|
|
" self.h = F.relu(self.h_h(self.h))\n",
|
|
" out = self.h_o(self.h)\n",
|
|
" self.h = self.h.detach()\n",
|
|
" return out\n",
|
|
" \n",
|
|
" def reset(self): self.h = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(328, 64, 21031)"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"m = len(seqs)//bs\n",
|
|
"m,bs,len(seqs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def group_chunks(ds, bs):\n",
|
|
" m = len(ds) // bs\n",
|
|
" new_ds = L()\n",
|
|
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
|
|
" return new_ds"
|
|
]
|
|
},
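{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `group_chunks` does, here is a toy example: with `ds = L(range(10))` and `bs=2` it interleaves the two halves of the dataset, so that once the `DataLoader` cuts the result into batches of size 2, row `b` of each batch continues row `b` of the previous batch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"group_chunks(L(range(10)), 2)\n",
"# (#10) [0,5,1,6,2,7,3,8,4,9]"
]
},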
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(\n",
|
|
" group_chunks(seqs[:cut], bs), \n",
|
|
" group_chunks(seqs[cut:], bs), \n",
|
|
" bs=bs, drop_last=True, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.677074</td>\n",
|
|
" <td>1.827367</td>\n",
|
|
" <td>0.467548</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.282722</td>\n",
|
|
" <td>1.870913</td>\n",
|
|
" <td>0.388942</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.090705</td>\n",
|
|
" <td>1.651793</td>\n",
|
|
" <td>0.462500</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.005092</td>\n",
|
|
" <td>1.613794</td>\n",
|
|
" <td>0.516587</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>0.965975</td>\n",
|
|
" <td>1.560775</td>\n",
|
|
" <td>0.551202</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0.916182</td>\n",
|
|
" <td>1.595857</td>\n",
|
|
" <td>0.560577</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.897657</td>\n",
|
|
" <td>1.539733</td>\n",
|
|
" <td>0.574279</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.836274</td>\n",
|
|
" <td>1.585141</td>\n",
|
|
" <td>0.583173</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.805877</td>\n",
|
|
" <td>1.629808</td>\n",
|
|
" <td>0.586779</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.795096</td>\n",
|
|
" <td>1.651267</td>\n",
|
|
" <td>0.588942</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n",
|
|
" metrics=accuracy, cbs=ModelResetter)\n",
|
|
"learn.fit_one_cycle(10, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Creating more signal"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"sl = 16\n",
|
|
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
|
|
" for i in range(0,len(nums)-sl-1,sl))\n",
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
|
|
" group_chunks(seqs[cut:], bs),\n",
|
|
" bs=bs, drop_last=True, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n",
|
|
" (#16) ['.','two','.','three','.','four','.','five','.','six'...]]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"[L(vocab[o] for o in s) for s in seqs[0]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel4(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" self.h = 0\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" outs = []\n",
|
|
" for i in range(sl):\n",
|
|
" self.h = self.h + self.i_h(x[:,i])\n",
|
|
" self.h = F.relu(self.h_h(self.h))\n",
|
|
" outs.append(self.h_o(self.h))\n",
|
|
" self.h = self.h.detach()\n",
|
|
" return torch.stack(outs, dim=1)\n",
|
|
" \n",
|
|
" def reset(self): self.h = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def loss_func(inp, targ):\n",
|
|
" return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
|
|
]
|
|
},
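{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the flattening is needed: the model now returns predictions of shape `bs x sl x vocab_sz` while the targets have shape `bs x sl`, but `F.cross_entropy` expects a rank-2 input and rank-1 targets, so we merge the batch and sequence dimensions and treat every position as an independent classification. A quick shape check with random tensors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inp = torch.randn(bs, sl, len(vocab))\n",
"targ = torch.randint(0, len(vocab), (bs, sl))\n",
"loss_func(inp, targ)"
]
},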
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.103298</td>\n",
|
|
" <td>2.874341</td>\n",
|
|
" <td>0.212565</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.231964</td>\n",
|
|
" <td>1.971280</td>\n",
|
|
" <td>0.462158</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.711358</td>\n",
|
|
" <td>1.813547</td>\n",
|
|
" <td>0.461182</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.448516</td>\n",
|
|
" <td>1.828176</td>\n",
|
|
" <td>0.483236</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.288630</td>\n",
|
|
" <td>1.659564</td>\n",
|
|
" <td>0.520671</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.161470</td>\n",
|
|
" <td>1.714023</td>\n",
|
|
" <td>0.554932</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.055568</td>\n",
|
|
" <td>1.660916</td>\n",
|
|
" <td>0.575033</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.960765</td>\n",
|
|
" <td>1.719624</td>\n",
|
|
" <td>0.591064</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.870153</td>\n",
|
|
" <td>1.839560</td>\n",
|
|
" <td>0.614665</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.808545</td>\n",
|
|
" <td>1.770278</td>\n",
|
|
" <td>0.624349</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.758084</td>\n",
|
|
" <td>1.842931</td>\n",
|
|
" <td>0.610758</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.719320</td>\n",
|
|
" <td>1.799527</td>\n",
|
|
" <td>0.646566</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.683439</td>\n",
|
|
" <td>1.917928</td>\n",
|
|
" <td>0.649821</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.660283</td>\n",
|
|
" <td>1.874712</td>\n",
|
|
" <td>0.628581</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.646154</td>\n",
|
|
" <td>1.877519</td>\n",
|
|
" <td>0.640055</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n",
|
|
" metrics=accuracy, cbs=ModelResetter)\n",
|
|
"learn.fit_one_cycle(15, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Multilayer RNNs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### The model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel5(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
|
" self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
|
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
|
" self.h = torch.zeros(n_layers, bs, n_hidden)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
|
" self.h = h.detach()\n",
|
|
" return self.h_o(res)\n",
|
|
" \n",
|
|
" def reset(self): self.h.zero_()"
|
|
]
|
|
},
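{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `batch_first=True`, `nn.RNN` returns the activations of the last layer at every time step, plus the final hidden state of every layer (the part we keep in `self.h`). A small shape check, using the same sizes as our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rnn = nn.RNN(64, 64, 2, batch_first=True)\n",
"x = torch.zeros(bs, sl, 64)\n",
"res,h = rnn(x, torch.zeros(2, bs, 64))\n",
"res.shape, h.shape\n",
"# (torch.Size([64, 16, 64]), torch.Size([2, 64, 64]))"
]
},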
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.055853</td>\n",
|
|
" <td>2.591640</td>\n",
|
|
" <td>0.437907</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.162359</td>\n",
|
|
" <td>1.787310</td>\n",
|
|
" <td>0.471598</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.710663</td>\n",
|
|
" <td>1.941807</td>\n",
|
|
" <td>0.321777</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.520783</td>\n",
|
|
" <td>1.999726</td>\n",
|
|
" <td>0.312012</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.330846</td>\n",
|
|
" <td>2.012902</td>\n",
|
|
" <td>0.413249</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.163297</td>\n",
|
|
" <td>1.896192</td>\n",
|
|
" <td>0.450684</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.033813</td>\n",
|
|
" <td>2.005209</td>\n",
|
|
" <td>0.434814</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.919090</td>\n",
|
|
" <td>2.047083</td>\n",
|
|
" <td>0.456706</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.822939</td>\n",
|
|
" <td>2.068031</td>\n",
|
|
" <td>0.468831</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.750180</td>\n",
|
|
" <td>2.136064</td>\n",
|
|
" <td>0.475098</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.695120</td>\n",
|
|
" <td>2.139140</td>\n",
|
|
" <td>0.485433</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.655752</td>\n",
|
|
" <td>2.155081</td>\n",
|
|
" <td>0.493652</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.629650</td>\n",
|
|
" <td>2.162583</td>\n",
|
|
" <td>0.498535</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.613583</td>\n",
|
|
" <td>2.171649</td>\n",
|
|
" <td>0.491048</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.604309</td>\n",
|
|
" <td>2.180355</td>\n",
|
|
" <td>0.487874</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel5(len(vocab), 64, 2), \n",
|
|
" loss_func=CrossEntropyLossFlat(), \n",
|
|
" metrics=accuracy, cbs=ModelResetter)\n",
|
|
"learn.fit_one_cycle(15, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Exploding or disappearing activations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Building an LSTM from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTMCell(Module):\n",
|
|
" def __init__(self, ni, nh):\n",
|
|
" self.forget_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.input_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.cell_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.output_gate = nn.Linear(ni + nh, nh)\n",
|
|
"\n",
|
|
" def forward(self, input, state):\n",
|
|
" h,c = state\n",
|
|
" h = torch.cat([h, input], dim=1)\n",
|
|
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
|
" c = c * forget\n",
|
|
" inp = torch.sigmoid(self.input_gate(h))\n",
|
|
" cell = torch.tanh(self.cell_gate(h))\n",
|
|
" c = c + inp * cell\n",
|
|
" out = torch.sigmoid(self.output_gate(h))\n",
|
|
" h = out * torch.tanh(c)\n",
|
|
" return h, (h,c)"
|
|
]
|
|
},
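{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each gate looks at the concatenation of the previous hidden state and the new input, joined along the feature axis: for `h` of shape `bs x nh` and `input` of shape `bs x ni`, the gates receive a `bs x (nh+ni)` tensor (here with illustrative sizes `nh=64`, `ni=32`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"h,inp = torch.zeros(bs, 64), torch.zeros(bs, 32)\n",
"torch.cat([h, inp], dim=1).shape\n",
"# torch.Size([64, 96])"
]
},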
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTMCell(Module):\n",
|
|
" def __init__(self, ni, nh):\n",
|
|
" self.ih = nn.Linear(ni,4*nh)\n",
|
|
" self.hh = nn.Linear(nh,4*nh)\n",
|
|
"\n",
|
|
" def forward(self, input, state):\n",
|
|
" h,c = state\n",
|
|
" #One big multiplication for all the gates is better than 4 smaller ones\n",
|
|
" gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
|
|
" ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
|
|
" cellgate = gates[3].tanh()\n",
|
|
"\n",
|
|
" c = (forgetgate*c) + (ingate*cellgate)\n",
|
|
" h = outgate * c.tanh()\n",
|
|
" return h, (h,c)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t = torch.arange(0,10); t"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t.chunk(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training a language model using LSTMs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel6(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
|
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
|
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
|
" self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
|
" self.h = [h_.detach() for h_ in h]\n",
|
|
" return self.h_o(res)\n",
|
|
" \n",
|
|
" def reset(self): \n",
|
|
" for h in self.h: h.zero_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.000821</td>\n",
|
|
" <td>2.663942</td>\n",
|
|
" <td>0.438314</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.139642</td>\n",
|
|
" <td>2.184780</td>\n",
|
|
" <td>0.240479</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.607275</td>\n",
|
|
" <td>1.812682</td>\n",
|
|
" <td>0.439779</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.347711</td>\n",
|
|
" <td>1.830982</td>\n",
|
|
" <td>0.497477</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.123113</td>\n",
|
|
" <td>1.937766</td>\n",
|
|
" <td>0.594401</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0.852042</td>\n",
|
|
" <td>2.012127</td>\n",
|
|
" <td>0.631592</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.565494</td>\n",
|
|
" <td>1.312742</td>\n",
|
|
" <td>0.725749</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.347445</td>\n",
|
|
" <td>1.297934</td>\n",
|
|
" <td>0.711263</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.208191</td>\n",
|
|
" <td>1.441269</td>\n",
|
|
" <td>0.731201</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.126335</td>\n",
|
|
" <td>1.569952</td>\n",
|
|
" <td>0.737305</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.079761</td>\n",
|
|
" <td>1.427187</td>\n",
|
|
" <td>0.754150</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.052990</td>\n",
|
|
" <td>1.494990</td>\n",
|
|
" <td>0.745117</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.039008</td>\n",
|
|
" <td>1.393731</td>\n",
|
|
" <td>0.757894</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.031502</td>\n",
|
|
" <td>1.373210</td>\n",
|
|
" <td>0.758464</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.028068</td>\n",
|
|
" <td>1.368083</td>\n",
|
|
" <td>0.758464</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel6(len(vocab), 64, 2), \n",
|
|
" loss_func=CrossEntropyLossFlat(), \n",
|
|
" metrics=accuracy, cbs=ModelResetter)\n",
|
|
"learn.fit_one_cycle(15, 1e-2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Regularizing an LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Dropout"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Dropout(Module):\n",
|
|
" def __init__(self, p): self.p = p\n",
|
|
" def forward(self, x):\n",
|
|
" if not self.training: return x\n",
|
|
" mask = x.new(*x.shape).bernoulli_(1-self.p)\n",
|
|
" return x * mask.div_(1-self.p)"
|
|
]
|
|
},
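{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of why we divide the kept activations by `1-p`: the rescaling keeps the mean of the output close to the mean of the input, so nothing has to be adjusted at inference time when dropout is turned off:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.ones(1000, 1000)\n",
"mask = x.new(*x.shape).bernoulli_(0.5)\n",
"(x * mask.div_(0.5)).mean()"
]
},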
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### AR and TAR regularization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training a weight-tied regularized LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"class LMModel7(Module):\n",
"    def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
"        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
"        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
"        self.drop = nn.Dropout(p)\n",
"        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
"        self.h_o.weight = self.i_h.weight\n",
"        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
"    \n",
"    def forward(self, x):\n",
"        raw,h = self.rnn(self.i_h(x), self.h)\n",
"        out = self.drop(raw)\n",
"        self.h = [h_.detach() for h_ in h]\n",
"        return self.h_o(out),raw,out\n",
"    \n",
"    def reset(self):\n",
"        for h in self.h: h.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n",
"                loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n",
"                cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n",
"                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.693885</td>\n",
"      <td>2.013484</td>\n",
"      <td>0.466634</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>1.685549</td>\n",
"      <td>1.187310</td>\n",
"      <td>0.629313</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>0.973307</td>\n",
"      <td>0.791398</td>\n",
"      <td>0.745605</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td>\n",
"      <td>0.555823</td>\n",
"      <td>0.640412</td>\n",
"      <td>0.794108</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>4</td>\n",
"      <td>0.351802</td>\n",
"      <td>0.557247</td>\n",
"      <td>0.836100</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>5</td>\n",
"      <td>0.244986</td>\n",
"      <td>0.594977</td>\n",
"      <td>0.807292</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>6</td>\n",
"      <td>0.192231</td>\n",
"      <td>0.511690</td>\n",
"      <td>0.846761</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>7</td>\n",
"      <td>0.162456</td>\n",
"      <td>0.520370</td>\n",
"      <td>0.858073</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>8</td>\n",
"      <td>0.142664</td>\n",
"      <td>0.525918</td>\n",
"      <td>0.842285</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>9</td>\n",
"      <td>0.128493</td>\n",
"      <td>0.495029</td>\n",
"      <td>0.858073</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>10</td>\n",
"      <td>0.117589</td>\n",
"      <td>0.464236</td>\n",
"      <td>0.867188</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>11</td>\n",
"      <td>0.109808</td>\n",
"      <td>0.466550</td>\n",
"      <td>0.869303</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>12</td>\n",
"      <td>0.104216</td>\n",
"      <td>0.455151</td>\n",
"      <td>0.871826</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>13</td>\n",
"      <td>0.100271</td>\n",
"      <td>0.452659</td>\n",
"      <td>0.873617</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>14</td>\n",
"      <td>0.098121</td>\n",
"      <td>0.458372</td>\n",
"      <td>0.869385</td>\n",
"      <td>00:02</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(15, 1e-2, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. If the dataset for your project is so big and complicated that working with it takes a significant amount of time, what should you do?\n",
"1. Why do we concatenate the documents in our dataset before creating a language model?\n",
"1. To use a standard fully connected network to predict the fourth word given the previous three words, what two tweaks do we need to make?\n",
"1. How can we share a weight matrix across multiple layers in PyTorch?\n",
"1. Write a module which predicts the third word given the previous two words of a sentence, without peeking.\n",
"1. What is a recurrent neural network?\n",
"1. What is hidden state?\n",
"1. What is the equivalent of hidden state in `LMModel1`?\n",
"1. To maintain the state in an RNN, why is it important to pass the text to the model in order?\n",
"1. What is an unrolled representation of an RNN?\n",
"1. Why can maintaining the hidden state in an RNN lead to memory and performance problems? How do we fix this problem?\n",
"1. What is BPTT?\n",
"1. Write code to print out the first few batches of the validation set, including converting the token IDs back into English strings, as we showed for batches of IMDb data in <<chapter_nlp>>.\n",
"1. What does the `ModelReseter` callback do? Why do we need it?\n",
"1. What are the downsides of predicting just one output word for each three input words?\n",
"1. Why do we need a custom loss function for `LMModel4`?\n",
"1. Why is the training of `LMModel4` unstable?\n",
"1. In the unrolled representation, we can see that a recurrent neural network actually has many layers. So why do we need to stack RNNs to get better results?\n",
"1. Draw a representation of a stacked (multilayer) RNN.\n",
"1. Why should we get better results in an RNN if we call `detach` less often? Why might this not happen in practice with a simple RNN?\n",
"1. Why can a deep network result in very large or very small activations? Why does this matter?\n",
"1. In a computer's floating point representation of numbers, which numbers are the most precise?\n",
"1. Why do vanishing gradients prevent training?\n",
"1. Why does it help to have two hidden states in the LSTM architecture? What is the purpose of each one?\n",
"1. What are these two states called in an LSTM?\n",
"1. What is tanh, and how is it related to sigmoid?\n",
"1. What is the purpose of this code in `LSTMCell`?: `h = torch.stack([h, input], dim=1)`\n",
"1. What does `chunk` do in PyTorch?\n",
"1. Study the refactored version of `LSTMCell` carefully to ensure you understand how and why it does the same thing as the non-refactored version.\n",
"1. Why can we use a higher learning rate for `LMModel6`?\n",
"1. What are the three regularization techniques used in an AWD-LSTM model?\n",
"1. What is dropout?\n",
"1. Why do we scale the weights with dropout? Is this applied during training, inference, or both?\n",
"1. What is the purpose of this line from `Dropout`?: `if not self.training: return x`\n",
"1. Experiment with `bernoulli_` to understand how it works.\n",
"1. How do you set your model in training mode in PyTorch? In evaluation mode?\n",
"1. Write the equation for activation regularization (in maths or code, as you prefer). How is it different from weight decay?\n",
"1. Write the equation for temporal activation regularization (in maths or code, as you prefer). Why wouldn't we use this for computer vision problems?\n",
"1. What is \"weight tying\" in a language model?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. In `LMModel2` why can `forward` start with `h=0`? Why don't we need to say `h=torch.zeros(…)`?\n",
"1. Write the code for an LSTM from scratch (but you may refer to <<lstm>>).\n",
"1. Search the Internet for the GRU architecture and implement it from scratch, and try training a model. See if you can get similar results to those we saw in this chapter. Compare your results to those of PyTorch's built-in GRU module.\n",
"1. Have a look at the source code for AWD-LSTM in fastai, and try to map each of the lines of code to the concepts shown in this chapter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}