mirror of
https://github.com/fastai/fastbook.git
synced 2025-04-04 18:00:48 +00:00
1591 lines
41 KiB
Plaintext
1591 lines
41 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"from utils import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# A language model from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from fastai2.text.all import *\n",
|
|
"path = untar_data(URLs.HUMAN_NUMBERS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"Path.BASE_PATH = path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#2) [Path('train.txt'),Path('valid.txt')]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"path.ls()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"lines = L()\n",
|
|
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
|
|
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
|
|
"lines"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"text = ' . '.join([l.strip() for l in lines])\n",
|
|
"text[:100]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"tokens = text.split(' ')\n",
|
|
"tokens[:10]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"vocab = L(*tokens).unique()\n",
|
|
"vocab"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
|
|
"nums = L(word2idx[i] for i in tokens)\n",
|
|
"nums"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Our first language model from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
|
|
"seqs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"bs = 64\n",
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Our language model in PyTorch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel1(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
|
|
" h = h + self.i_h(x[:,1])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" h = h + self.i_h(x[:,2])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" return self.h_o(h)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.824297</td>\n",
|
|
" <td>1.970941</td>\n",
|
|
" <td>0.467554</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.386973</td>\n",
|
|
" <td>1.823242</td>\n",
|
|
" <td>0.467554</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.417556</td>\n",
|
|
" <td>1.654497</td>\n",
|
|
" <td>0.494414</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.376440</td>\n",
|
|
" <td>1.650849</td>\n",
|
|
" <td>0.494414</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, \n",
|
|
" metrics=accuracy)\n",
|
|
"learn.fit_one_cycle(4, 1e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor(29), 'thousand', 0.15165200855716662)"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"n,counts = 0,torch.zeros(len(vocab))\n",
|
|
"for x,y in dls.valid:\n",
|
|
" n += y.shape[0]\n",
|
|
" for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
|
|
"idx = torch.argmax(counts)\n",
|
|
"idx, vocab[idx.item()], counts[idx].item()/n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Our first recurrent neural network"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel2(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" h = 0\n",
|
|
" for i in range(3):\n",
|
|
" h = h + self.i_h(x[:,i])\n",
|
|
" h = F.relu(self.h_h(h))\n",
|
|
" return self.h_o(h)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.816274</td>\n",
|
|
" <td>1.964143</td>\n",
|
|
" <td>0.460185</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.423805</td>\n",
|
|
" <td>1.739964</td>\n",
|
|
" <td>0.473259</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.430327</td>\n",
|
|
" <td>1.685172</td>\n",
|
|
" <td>0.485382</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.388390</td>\n",
|
|
" <td>1.657033</td>\n",
|
|
" <td>0.470406</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, \n",
|
|
" metrics=accuracy)\n",
|
|
"learn.fit_one_cycle(4, 1e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Improving the RNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Maintaining the state of an RNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel3(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" self.h = 0\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" for i in range(3):\n",
|
|
" self.h = self.h + self.i_h(x[:,i])\n",
|
|
" self.h = F.relu(self.h_h(self.h))\n",
|
|
" out = self.h_o(self.h)\n",
|
|
" self.h = self.h.detach()\n",
|
|
" return out\n",
|
|
" \n",
|
|
" def reset(self): self.h = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(328, 64, 21031)"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"m = len(seqs)//bs\n",
|
|
"m,bs,len(seqs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def group_chunks(ds, bs):\n",
|
|
" m = len(ds) // bs\n",
|
|
" new_ds = L()\n",
|
|
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
|
|
" return new_ds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(\n",
|
|
" group_chunks(seqs[:cut], bs), \n",
|
|
" group_chunks(seqs[cut:], bs), \n",
|
|
" bs=bs, drop_last=True, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>1.677074</td>\n",
|
|
" <td>1.827367</td>\n",
|
|
" <td>0.467548</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.282722</td>\n",
|
|
" <td>1.870913</td>\n",
|
|
" <td>0.388942</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.090705</td>\n",
|
|
" <td>1.651793</td>\n",
|
|
" <td>0.462500</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.005092</td>\n",
|
|
" <td>1.613794</td>\n",
|
|
" <td>0.516587</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>0.965975</td>\n",
|
|
" <td>1.560775</td>\n",
|
|
" <td>0.551202</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0.916182</td>\n",
|
|
" <td>1.595857</td>\n",
|
|
" <td>0.560577</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.897657</td>\n",
|
|
" <td>1.539733</td>\n",
|
|
" <td>0.574279</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.836274</td>\n",
|
|
" <td>1.585141</td>\n",
|
|
" <td>0.583173</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.805877</td>\n",
|
|
" <td>1.629808</td>\n",
|
|
" <td>0.586779</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.795096</td>\n",
|
|
" <td>1.651267</td>\n",
|
|
" <td>0.588942</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n",
|
|
" metrics=accuracy, cbs=ModelReseter)\n",
|
|
"learn.fit_one_cycle(10, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Creating more signal"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"sl = 16\n",
|
|
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
|
|
" for i in range(0,len(nums)-sl-1,sl))\n",
|
|
"cut = int(len(seqs) * 0.8)\n",
|
|
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
|
|
" group_chunks(seqs[cut:], bs),\n",
|
|
" bs=bs, drop_last=True, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n",
|
|
" (#16) ['.','two','.','three','.','four','.','five','.','six'...]]"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"[L(vocab[o] for o in s) for s in seqs[0]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel4(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
|
|
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
|
|
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
|
|
" self.h = 0\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" outs = []\n",
|
|
" for i in range(sl):\n",
|
|
" self.h = self.h + self.i_h(x[:,i])\n",
|
|
" self.h = F.relu(self.h_h(self.h))\n",
|
|
" outs.append(self.h_o(self.h))\n",
|
|
" self.h = self.h.detach()\n",
|
|
" return torch.stack(outs, dim=1)\n",
|
|
" \n",
|
|
" def reset(self): self.h = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def loss_func(inp, targ):\n",
|
|
" return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.103298</td>\n",
|
|
" <td>2.874341</td>\n",
|
|
" <td>0.212565</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.231964</td>\n",
|
|
" <td>1.971280</td>\n",
|
|
" <td>0.462158</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.711358</td>\n",
|
|
" <td>1.813547</td>\n",
|
|
" <td>0.461182</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.448516</td>\n",
|
|
" <td>1.828176</td>\n",
|
|
" <td>0.483236</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.288630</td>\n",
|
|
" <td>1.659564</td>\n",
|
|
" <td>0.520671</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.161470</td>\n",
|
|
" <td>1.714023</td>\n",
|
|
" <td>0.554932</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.055568</td>\n",
|
|
" <td>1.660916</td>\n",
|
|
" <td>0.575033</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.960765</td>\n",
|
|
" <td>1.719624</td>\n",
|
|
" <td>0.591064</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.870153</td>\n",
|
|
" <td>1.839560</td>\n",
|
|
" <td>0.614665</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.808545</td>\n",
|
|
" <td>1.770278</td>\n",
|
|
" <td>0.624349</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.758084</td>\n",
|
|
" <td>1.842931</td>\n",
|
|
" <td>0.610758</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.719320</td>\n",
|
|
" <td>1.799527</td>\n",
|
|
" <td>0.646566</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.683439</td>\n",
|
|
" <td>1.917928</td>\n",
|
|
" <td>0.649821</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.660283</td>\n",
|
|
" <td>1.874712</td>\n",
|
|
" <td>0.628581</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.646154</td>\n",
|
|
" <td>1.877519</td>\n",
|
|
" <td>0.640055</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n",
|
|
" metrics=accuracy, cbs=ModelReseter)\n",
|
|
"learn.fit_one_cycle(15, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Multilayer RNNs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel5(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
|
" self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
|
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
|
" self.h = torch.zeros(n_layers, bs, n_hidden)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
|
" self.h = h.detach()\n",
|
|
" return self.h_o(res)\n",
|
|
" \n",
|
|
" def reset(self): self.h.zero_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.055853</td>\n",
|
|
" <td>2.591640</td>\n",
|
|
" <td>0.437907</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.162359</td>\n",
|
|
" <td>1.787310</td>\n",
|
|
" <td>0.471598</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.710663</td>\n",
|
|
" <td>1.941807</td>\n",
|
|
" <td>0.321777</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.520783</td>\n",
|
|
" <td>1.999726</td>\n",
|
|
" <td>0.312012</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.330846</td>\n",
|
|
" <td>2.012902</td>\n",
|
|
" <td>0.413249</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1.163297</td>\n",
|
|
" <td>1.896192</td>\n",
|
|
" <td>0.450684</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>1.033813</td>\n",
|
|
" <td>2.005209</td>\n",
|
|
" <td>0.434814</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.919090</td>\n",
|
|
" <td>2.047083</td>\n",
|
|
" <td>0.456706</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.822939</td>\n",
|
|
" <td>2.068031</td>\n",
|
|
" <td>0.468831</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.750180</td>\n",
|
|
" <td>2.136064</td>\n",
|
|
" <td>0.475098</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.695120</td>\n",
|
|
" <td>2.139140</td>\n",
|
|
" <td>0.485433</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.655752</td>\n",
|
|
" <td>2.155081</td>\n",
|
|
" <td>0.493652</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.629650</td>\n",
|
|
" <td>2.162583</td>\n",
|
|
" <td>0.498535</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.613583</td>\n",
|
|
" <td>2.171649</td>\n",
|
|
" <td>0.491048</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.604309</td>\n",
|
|
" <td>2.180355</td>\n",
|
|
" <td>0.487874</td>\n",
|
|
" <td>00:01</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel5(len(vocab), 64, 2), \n",
|
|
" loss_func=CrossEntropyLossFlat(), \n",
|
|
" metrics=accuracy, cbs=ModelReseter)\n",
|
|
"learn.fit_one_cycle(15, 3e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Exploding or disappearing activations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Building an LSTM from scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTMCell(Module):\n",
|
|
" def __init__(self, ni, nh):\n",
|
|
" self.forget_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.input_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.cell_gate = nn.Linear(ni + nh, nh)\n",
|
|
" self.output_gate = nn.Linear(ni + nh, nh)\n",
|
|
"\n",
|
|
" def forward(self, input, state):\n",
|
|
" h,c = state\n",
|
|
"        h = torch.cat([h, input], dim=1)\n",
|
|
" forget = torch.sigmoid(self.forget_gate(h))\n",
|
|
" c = c * forget\n",
|
|
" inp = torch.sigmoid(self.input_gate(h))\n",
|
|
" cell = torch.tanh(self.cell_gate(h))\n",
|
|
" c = c + inp * cell\n",
|
|
" out = torch.sigmoid(self.output_gate(h))\n",
|
|
"        h = out * torch.tanh(c)\n",
|
|
" return h, (h,c)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LSTMCell(Module):\n",
|
|
" def __init__(self, ni, nh):\n",
|
|
" self.ih = nn.Linear(ni,4*nh)\n",
|
|
" self.hh = nn.Linear(nh,4*nh)\n",
|
|
"\n",
|
|
" def forward(self, input, state):\n",
|
|
" h,c = state\n",
|
|
" #One big multiplication for all the gates is better than 4 smaller ones\n",
|
|
" gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
|
|
" ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
|
|
" cellgate = gates[3].tanh()\n",
|
|
"\n",
|
|
" c = (forgetgate*c) + (ingate*cellgate)\n",
|
|
" h = outgate * c.tanh()\n",
|
|
" return h, (h,c)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t = torch.arange(0,10); t"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))"
|
|
]
|
|
},
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"t.chunk(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training a language model using LSTMs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel6(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
|
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
|
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
|
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" res,h = self.rnn(self.i_h(x), self.h)\n",
|
|
" self.h = [h_.detach() for h_ in h]\n",
|
|
" return self.h_o(res)\n",
|
|
" \n",
|
|
" def reset(self): \n",
|
|
" for h in self.h: h.zero_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3.000821</td>\n",
|
|
" <td>2.663942</td>\n",
|
|
" <td>0.438314</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>2.139642</td>\n",
|
|
" <td>2.184780</td>\n",
|
|
" <td>0.240479</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1.607275</td>\n",
|
|
" <td>1.812682</td>\n",
|
|
" <td>0.439779</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1.347711</td>\n",
|
|
" <td>1.830982</td>\n",
|
|
" <td>0.497477</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1.123113</td>\n",
|
|
" <td>1.937766</td>\n",
|
|
" <td>0.594401</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0.852042</td>\n",
|
|
" <td>2.012127</td>\n",
|
|
" <td>0.631592</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.565494</td>\n",
|
|
" <td>1.312742</td>\n",
|
|
" <td>0.725749</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.347445</td>\n",
|
|
" <td>1.297934</td>\n",
|
|
" <td>0.711263</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.208191</td>\n",
|
|
" <td>1.441269</td>\n",
|
|
" <td>0.731201</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.126335</td>\n",
|
|
" <td>1.569952</td>\n",
|
|
" <td>0.737305</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.079761</td>\n",
|
|
" <td>1.427187</td>\n",
|
|
" <td>0.754150</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.052990</td>\n",
|
|
" <td>1.494990</td>\n",
|
|
" <td>0.745117</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.039008</td>\n",
|
|
" <td>1.393731</td>\n",
|
|
" <td>0.757894</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.031502</td>\n",
|
|
" <td>1.373210</td>\n",
|
|
" <td>0.758464</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.028068</td>\n",
|
|
" <td>1.368083</td>\n",
|
|
" <td>0.758464</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel6(len(vocab), 64, 2), \n",
|
|
" loss_func=CrossEntropyLossFlat(), \n",
|
|
" metrics=accuracy, cbs=ModelReseter)\n",
|
|
"learn.fit_one_cycle(15, 1e-2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Regularizing an LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Dropout"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Dropout(Module):\n",
|
|
" def __init__(self, p): self.p = p\n",
|
|
" def forward(self, x):\n",
|
|
" if not self.training: return x\n",
|
|
"        mask = x.new(*x.shape).bernoulli_(1-self.p)\n",
|
|
"        return x * mask.div_(1-self.p)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### AR and TAR regularization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training a weight-tied regularized LSTM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LMModel7(Module):\n",
|
|
" def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
|
|
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
|
|
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
|
|
" self.drop = nn.Dropout(p)\n",
|
|
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
|
|
" self.h_o.weight = self.i_h.weight\n",
|
|
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" raw,h = self.rnn(self.i_h(x), self.h)\n",
|
|
" out = self.drop(raw)\n",
|
|
" self.h = [h_.detach() for h_ in h]\n",
|
|
" return self.h_o(out),raw,out\n",
|
|
" \n",
|
|
" def reset(self): \n",
|
|
" for h in self.h: h.zero_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n",
|
|
" loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n",
|
|
" cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n",
|
|
" loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: left;\">\n",
|
|
" <th>epoch</th>\n",
|
|
" <th>train_loss</th>\n",
|
|
" <th>valid_loss</th>\n",
|
|
" <th>accuracy</th>\n",
|
|
" <th>time</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <td>0</td>\n",
|
|
" <td>2.693885</td>\n",
|
|
" <td>2.013484</td>\n",
|
|
" <td>0.466634</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1.685549</td>\n",
|
|
" <td>1.187310</td>\n",
|
|
" <td>0.629313</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>2</td>\n",
|
|
" <td>0.973307</td>\n",
|
|
" <td>0.791398</td>\n",
|
|
" <td>0.745605</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>3</td>\n",
|
|
" <td>0.555823</td>\n",
|
|
" <td>0.640412</td>\n",
|
|
" <td>0.794108</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>4</td>\n",
|
|
" <td>0.351802</td>\n",
|
|
" <td>0.557247</td>\n",
|
|
" <td>0.836100</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0.244986</td>\n",
|
|
" <td>0.594977</td>\n",
|
|
" <td>0.807292</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>6</td>\n",
|
|
" <td>0.192231</td>\n",
|
|
" <td>0.511690</td>\n",
|
|
" <td>0.846761</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>7</td>\n",
|
|
" <td>0.162456</td>\n",
|
|
" <td>0.520370</td>\n",
|
|
" <td>0.858073</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>8</td>\n",
|
|
" <td>0.142664</td>\n",
|
|
" <td>0.525918</td>\n",
|
|
" <td>0.842285</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>9</td>\n",
|
|
" <td>0.128493</td>\n",
|
|
" <td>0.495029</td>\n",
|
|
" <td>0.858073</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>10</td>\n",
|
|
" <td>0.117589</td>\n",
|
|
" <td>0.464236</td>\n",
|
|
" <td>0.867188</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>11</td>\n",
|
|
" <td>0.109808</td>\n",
|
|
" <td>0.466550</td>\n",
|
|
" <td>0.869303</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>12</td>\n",
|
|
" <td>0.104216</td>\n",
|
|
" <td>0.455151</td>\n",
|
|
" <td>0.871826</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>13</td>\n",
|
|
" <td>0.100271</td>\n",
|
|
" <td>0.452659</td>\n",
|
|
" <td>0.873617</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td>14</td>\n",
|
|
" <td>0.098121</td>\n",
|
|
" <td>0.458372</td>\n",
|
|
" <td>0.869385</td>\n",
|
|
" <td>00:02</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"learn.fit_one_cycle(15, 1e-2, wd=0.1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Conclusion"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Questionnaire"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Further research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"split_at_heading": true
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|