{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A language model from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.text.all import *\n",
"path = untar_data(URLs.HUMAN_NUMBERS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('train.txt'),Path('valid.txt')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lines = L()\n",
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
"lines"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = ' . '.join([l.strip() for l in lines])\n",
"text[:100]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens = text.split(' ')\n",
"tokens[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab = L(*tokens).unique()\n",
"vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
"nums = L(word2idx[i] for i in tokens)\n",
"nums"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our first language model from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
"seqs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs = 64\n",
"cut = int(len(seqs) * 0.8)\n",
"# use the `bs` constant instead of repeating the literal 64\n",
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our language model in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel1(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
" h = h + self.i_h(x[:,1])\n",
" h = F.relu(self.h_h(h))\n",
" h = h + self.i_h(x[:,2])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.824297 | \n",
" 1.970941 | \n",
" 0.467554 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.386973 | \n",
" 1.823242 | \n",
" 0.467554 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.417556 | \n",
" 1.654497 | \n",
" 0.494414 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.376440 | \n",
" 1.650849 | \n",
" 0.494414 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, \n",
" metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(29), 'thousand', 0.15165200855716662)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n,counts = 0,torch.zeros(len(vocab))\n",
"for x,y in dls.valid:\n",
" n += y.shape[0]\n",
" for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
"idx = torch.argmax(counts)\n",
"idx, vocab[idx.item()], counts[idx].item()/n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our first recurrent neural network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel2(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = 0\n",
" for i in range(3):\n",
" h = h + self.i_h(x[:,i])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.816274 | \n",
" 1.964143 | \n",
" 0.460185 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.423805 | \n",
" 1.739964 | \n",
" 0.473259 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.430327 | \n",
" 1.685172 | \n",
" 0.485382 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.388390 | \n",
" 1.657033 | \n",
" 0.470406 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, \n",
" metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Improving the RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Maintaining the state of an RNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel3(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" for i in range(3):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" out = self.h_o(self.h)\n",
" self.h = self.h.detach()\n",
" return out\n",
" \n",
" def reset(self): self.h = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(328, 64, 21031)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = len(seqs)//bs\n",
"m,bs,len(seqs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def group_chunks(ds, bs):\n",
" m = len(ds) // bs\n",
" new_ds = L()\n",
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
" return new_ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(\n",
" group_chunks(seqs[:cut], bs), \n",
" group_chunks(seqs[cut:], bs), \n",
" bs=bs, drop_last=True, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.677074 | \n",
" 1.827367 | \n",
" 0.467548 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.282722 | \n",
" 1.870913 | \n",
" 0.388942 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.090705 | \n",
" 1.651793 | \n",
" 0.462500 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.005092 | \n",
" 1.613794 | \n",
" 0.516587 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.965975 | \n",
" 1.560775 | \n",
" 0.551202 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.916182 | \n",
" 1.595857 | \n",
" 0.560577 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.897657 | \n",
" 1.539733 | \n",
" 0.574279 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.836274 | \n",
" 1.585141 | \n",
" 0.583173 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.805877 | \n",
" 1.629808 | \n",
" 0.586779 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.795096 | \n",
" 1.651267 | \n",
" 0.588942 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(10, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating more signal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sl = 16\n",
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
" for i in range(0,len(nums)-sl-1,sl))\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
" group_chunks(seqs[cut:], bs),\n",
" bs=bs, drop_last=True, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n",
" (#16) ['.','two','.','three','.','four','.','five','.','six'...]]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[L(vocab[o] for o in s) for s in seqs[0]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel4(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" outs = []\n",
" for i in range(sl):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" outs.append(self.h_o(self.h))\n",
" self.h = self.h.detach()\n",
" return torch.stack(outs, dim=1)\n",
" \n",
" def reset(self): self.h = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss_func(inp, targ):\n",
" return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.103298 | \n",
" 2.874341 | \n",
" 0.212565 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 1 | \n",
" 2.231964 | \n",
" 1.971280 | \n",
" 0.462158 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.711358 | \n",
" 1.813547 | \n",
" 0.461182 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.448516 | \n",
" 1.828176 | \n",
" 0.483236 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.288630 | \n",
" 1.659564 | \n",
" 0.520671 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.161470 | \n",
" 1.714023 | \n",
" 0.554932 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.055568 | \n",
" 1.660916 | \n",
" 0.575033 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.960765 | \n",
" 1.719624 | \n",
" 0.591064 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.870153 | \n",
" 1.839560 | \n",
" 0.614665 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.808545 | \n",
" 1.770278 | \n",
" 0.624349 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.758084 | \n",
" 1.842931 | \n",
" 0.610758 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.719320 | \n",
" 1.799527 | \n",
" 0.646566 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.683439 | \n",
" 1.917928 | \n",
" 0.649821 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.660283 | \n",
" 1.874712 | \n",
" 0.628581 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.646154 | \n",
" 1.877519 | \n",
" 0.640055 | \n",
" 00:01 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(15, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multilayer RNNs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel5(Module):\n",
" def __init__(self, vocab_sz, n_hidden, n_layers):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = torch.zeros(n_layers, bs, n_hidden)\n",
" \n",
" def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n",
" self.h = h.detach()\n",
" return self.h_o(res)\n",
" \n",
" def reset(self): self.h.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.055853 | \n",
" 2.591640 | \n",
" 0.437907 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 1 | \n",
" 2.162359 | \n",
" 1.787310 | \n",
" 0.471598 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.710663 | \n",
" 1.941807 | \n",
" 0.321777 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.520783 | \n",
" 1.999726 | \n",
" 0.312012 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.330846 | \n",
" 2.012902 | \n",
" 0.413249 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.163297 | \n",
" 1.896192 | \n",
" 0.450684 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.033813 | \n",
" 2.005209 | \n",
" 0.434814 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.919090 | \n",
" 2.047083 | \n",
" 0.456706 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.822939 | \n",
" 2.068031 | \n",
" 0.468831 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.750180 | \n",
" 2.136064 | \n",
" 0.475098 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.695120 | \n",
" 2.139140 | \n",
" 0.485433 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.655752 | \n",
" 2.155081 | \n",
" 0.493652 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.629650 | \n",
" 2.162583 | \n",
" 0.498535 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.613583 | \n",
" 2.171649 | \n",
" 0.491048 | \n",
" 00:01 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.604309 | \n",
" 2.180355 | \n",
" 0.487874 | \n",
" 00:01 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel5(len(vocab), 64, 2), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(15, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exploding or disappearing activations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LSTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building an LSTM from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LSTMCell(Module):\n",
"    \"\"\"A single LSTM cell built from four explicit gate layers.\"\"\"\n",
"    def __init__(self, ni, nh):\n",
"        self.forget_gate = nn.Linear(ni + nh, nh)\n",
"        self.input_gate  = nn.Linear(ni + nh, nh)\n",
"        self.cell_gate   = nn.Linear(ni + nh, nh)\n",
"        self.output_gate = nn.Linear(ni + nh, nh)\n",
"\n",
"    def forward(self, input, state):\n",
"        h,c = state\n",
"        # concatenate hidden state and input along the feature axis;\n",
"        # fixed: torch.stack would create a new dimension and break the\n",
"        # nn.Linear(ni + nh, nh) gate layers above\n",
"        h = torch.cat([h, input], dim=1)\n",
"        forget = torch.sigmoid(self.forget_gate(h))\n",
"        c = c * forget\n",
"        inp = torch.sigmoid(self.input_gate(h))\n",
"        cell = torch.tanh(self.cell_gate(h))\n",
"        c = c + inp * cell\n",
"        out = torch.sigmoid(self.output_gate(h))\n",
"        # fixed: was `outgate` (undefined NameError) — use `out` from above\n",
"        h = out * torch.tanh(c)\n",
"        return h, (h,c)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LSTMCell(Module):\n",
" def __init__(self, ni, nh):\n",
" self.ih = nn.Linear(ni,4*nh)\n",
" self.hh = nn.Linear(nh,4*nh)\n",
"\n",
" def forward(self, input, state):\n",
" h,c = state\n",
" #One big multiplication for all the gates is better than 4 smaller ones\n",
" gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
" ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
" cellgate = gates[3].tanh()\n",
"\n",
" c = (forgetgate*c) + (ingate*cellgate)\n",
" h = outgate * c.tanh()\n",
" return h, (h,c)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = torch.arange(0,10); t"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t.chunk(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a language model using LSTMs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel6(Module):\n",
"    \"\"\"Stateful multilayer LSTM language model.\"\"\"\n",
"    def __init__(self, vocab_sz, n_hidden, n_layers):\n",
"        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
"        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
"        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
"        # LSTM state is a pair (hidden, cell), each of shape\n",
"        # (n_layers, bs, n_hidden) — fixed: the dims were swapped and only\n",
"        # happened to work because n_layers == 2 in this chapter\n",
"        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
"        \n",
"    def forward(self, x):\n",
"        res,h = self.rnn(self.i_h(x), self.h)\n",
"        # detach the state so gradients do not flow across batches (TBPTT)\n",
"        self.h = [h_.detach() for h_ in h]\n",
"        return self.h_o(res)\n",
"        \n",
"    def reset(self): \n",
"        for h in self.h: h.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.000821 | \n",
" 2.663942 | \n",
" 0.438314 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 1 | \n",
" 2.139642 | \n",
" 2.184780 | \n",
" 0.240479 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.607275 | \n",
" 1.812682 | \n",
" 0.439779 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.347711 | \n",
" 1.830982 | \n",
" 0.497477 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.123113 | \n",
" 1.937766 | \n",
" 0.594401 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.852042 | \n",
" 2.012127 | \n",
" 0.631592 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.565494 | \n",
" 1.312742 | \n",
" 0.725749 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.347445 | \n",
" 1.297934 | \n",
" 0.711263 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.208191 | \n",
" 1.441269 | \n",
" 0.731201 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.126335 | \n",
" 1.569952 | \n",
" 0.737305 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.079761 | \n",
" 1.427187 | \n",
" 0.754150 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.052990 | \n",
" 1.494990 | \n",
" 0.745117 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.039008 | \n",
" 1.393731 | \n",
" 0.757894 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.031502 | \n",
" 1.373210 | \n",
" 0.758464 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.028068 | \n",
" 1.368083 | \n",
" 0.758464 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel6(len(vocab), 64, 2), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(15, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regularizing an LSTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dropout"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Dropout(Module):\n",
"    \"\"\"Dropout: zero each activation with probability p, rescale survivors by 1/(1-p).\"\"\"\n",
"    def __init__(self, p): self.p = p\n",
"    def forward(self, x):\n",
"        if not self.training: return x\n",
"        # fixed: was bare `p` (NameError at train time) — read stored self.p\n",
"        mask = x.new(*x.shape).bernoulli_(1-self.p)\n",
"        return x * mask.div_(1-self.p)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AR and TAR regularization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training a weight-tied regularized LSTM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel7(Module):\n",
"    \"\"\"Weight-tied LSTM with dropout; returns extra activations for AR/TAR.\"\"\"\n",
"    def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
"        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
"        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
"        self.drop = nn.Dropout(p)\n",
"        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
"        # weight tying: the output projection shares the embedding matrix\n",
"        self.h_o.weight = self.i_h.weight\n",
"        # LSTM state is a pair (hidden, cell), each of shape\n",
"        # (n_layers, bs, n_hidden) — fixed: the dims were swapped and only\n",
"        # happened to work because n_layers == 2 in this chapter\n",
"        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
"        \n",
"    def forward(self, x):\n",
"        raw,h = self.rnn(self.i_h(x), self.h)\n",
"        out = self.drop(raw)\n",
"        self.h = [h_.detach() for h_ in h]\n",
"        # also return raw and dropped activations for RNNRegularizer (AR/TAR)\n",
"        return self.h_o(out),raw,out\n",
"        \n",
"    def reset(self): \n",
"        for h in self.h: h.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n",
" loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n",
" cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n",
" loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2.693885 | \n",
" 2.013484 | \n",
" 0.466634 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.685549 | \n",
" 1.187310 | \n",
" 0.629313 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.973307 | \n",
" 0.791398 | \n",
" 0.745605 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.555823 | \n",
" 0.640412 | \n",
" 0.794108 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.351802 | \n",
" 0.557247 | \n",
" 0.836100 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.244986 | \n",
" 0.594977 | \n",
" 0.807292 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.192231 | \n",
" 0.511690 | \n",
" 0.846761 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.162456 | \n",
" 0.520370 | \n",
" 0.858073 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.142664 | \n",
" 0.525918 | \n",
" 0.842285 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.128493 | \n",
" 0.495029 | \n",
" 0.858073 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.117589 | \n",
" 0.464236 | \n",
" 0.867188 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.109808 | \n",
" 0.466550 | \n",
" 0.869303 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.104216 | \n",
" 0.455151 | \n",
" 0.871826 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.100271 | \n",
" 0.452659 | \n",
" 0.873617 | \n",
" 00:02 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.098121 | \n",
" 0.458372 | \n",
" 0.869385 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(15, 1e-2, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}