{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A language model from scratch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.text.all import *\n", "path = untar_data(URLs.HUMAN_NUMBERS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "Path.BASE_PATH = path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('train.txt'),Path('valid.txt')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lines = L()\n", "with open(path/'train.txt') as f: lines += L(*f.readlines())\n", "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n", "lines" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = ' . '.join([l.strip() for l in lines])\n", "text[:100]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens = text.split(' ')\n", "tokens[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab = L(*tokens).unique()\n", "vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#63095) [0,1,2,1,3,1,4,1,5,1...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word2idx = {w:i for i,w in enumerate(vocab)}\n", "nums = L(word2idx[i] for i in tokens)\n", "nums" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our first language model from scratch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#21031) [(tensor([0, 1, 
2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n", "seqs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs = 64\n", "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Our language model in PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel1(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " \n", " def forward(self, x):\n", " h = F.relu(self.h_h(self.i_h(x[:,0])))\n", " h = h + self.i_h(x[:,1])\n", " h = F.relu(self.h_h(h))\n", " h = h + self.i_h(x[:,2])\n", " h = F.relu(self.h_h(h))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      1.824297    1.970941    0.467554  00:02
1      1.386973    1.823242    0.467554  00:02
2      1.417556    1.654497    0.494414  00:02
3      1.376440    1.650849    0.494414  00:02
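To make the result above concrete, here is a minimal inference sketch (it assumes the `learn`, `vocab`, and `word2idx` objects defined earlier; the prompt is an arbitrary choice): feed three tokens through the trained LMModel1 and read off its guess for the fourth.

```python
# A sketch, not part of the original training loop.
prompt = ['eight', 'thousand', 'one']        # any three tokens from the vocab
x = tensor([[word2idx[w] for w in prompt]])  # shape (1, 3): a batch of one sequence
with torch.no_grad():
    preds = learn.model.cpu()(x)             # shape (1, len(vocab)): one score per token
print(vocab[preds.argmax(dim=-1).item()])    # the model's guess for the next token
```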
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, \n", " metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(29), 'thousand', 0.15165200855716662)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n,counts = 0,torch.zeros(len(vocab))\n", "for x,y in dls.valid:\n", " n += y.shape[0]\n", " for i in range_of(vocab): counts[i] += (y==i).long().sum()\n", "idx = torch.argmax(counts)\n", "idx, vocab[idx.item()], counts[idx].item()/n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Our first recurrent neural network" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel2(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " \n", " def forward(self, x):\n", " h = 0\n", " for i in range(3):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      1.816274    1.964143    0.460185  00:02
1      1.423805    1.739964    0.473259  00:02
2      1.430327    1.685172    0.485382  00:02
3      1.388390    1.657033    0.470406  00:02
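Since LMModel2 uses the same three layers, with the same names, as LMModel1, we can verify that the loop is a pure refactoring. A quick sanity check (a sketch reusing the two classes defined above):

```python
# With identical weights, the looped model computes exactly the same function
# as the unrolled one, so only the code changed, not the computation.
m1, m2 = LMModel1(len(vocab), 64), LMModel2(len(vocab), 64)
m2.load_state_dict(m1.state_dict())  # the parameter names (i_h, h_h, h_o) line up
x = tensor([[1, 2, 3]])              # any batch of three token ids
assert torch.allclose(m1(x), m2(x))
```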
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, \n", " metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Improving the RNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Maintaining the state of an RNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel3(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " self.h = 0\n", " \n", " def forward(self, x):\n", " for i in range(3):\n", " self.h = self.h + self.i_h(x[:,i])\n", " self.h = F.relu(self.h_h(self.h))\n", " out = self.h_o(self.h)\n", " self.h = self.h.detach()\n", " return out\n", " \n", " def reset(self): self.h = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(328, 64, 21031)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = len(seqs)//bs\n", "m,bs,len(seqs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def group_chunks(ds, bs):\n", " m = len(ds) // bs\n", " new_ds = L()\n", " for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n", " return new_ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(\n", " group_chunks(seqs[:cut], bs), \n", " group_chunks(seqs[cut:], bs), \n", " bs=bs, drop_last=True, shuffle=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      1.677074    1.827367    0.467548  00:02
1      1.282722    1.870913    0.388942  00:02
2      1.090705    1.651793    0.462500  00:02
3      1.005092    1.613794    0.516587  00:02
4      0.965975    1.560775    0.551202  00:02
5      0.916182    1.595857    0.560577  00:02
6      0.897657    1.539733    0.574279  00:02
7      0.836274    1.585141    0.583173  00:02
8      0.805877    1.629808    0.586779  00:02
9      0.795096    1.651267    0.588942  00:02
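The gain over LMModel2 depends on the batch layout as much as on the saved state. This short check (a sketch, assuming the `dls` built with `group_chunks` just above) shows that row 0 of consecutive training batches walks through the corpus in order, which is what keeps the carried-over `self.h` meaningful:

```python
# Compare the first row of the first two training batches.
it = iter(dls.train)
(x1,y1),(x2,y2) = next(it), next(it)
print(' '.join(vocab[o] for o in x1[0]))  # e.g. 'one . two'
print(' '.join(vocab[o] for o in x2[0]))  # the continuation of the same stream
```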
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n", " metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(10, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating more signal" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sl = 16\n", "seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n", " for i in range(0,len(nums)-sl-1,sl))\n", "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n", " group_chunks(seqs[cut:], bs),\n", " bs=bs, drop_last=True, shuffle=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n", " (#16) ['.','two','.','three','.','four','.','five','.','six'...]]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[L(vocab[o] for o in s) for s in seqs[0]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel4(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " self.h = 0\n", " \n", " def forward(self, x):\n", " outs = []\n", " for i in range(sl):\n", " self.h = self.h + self.i_h(x[:,i])\n", " self.h = F.relu(self.h_h(self.h))\n", " outs.append(self.h_o(self.h))\n", " self.h = self.h.detach()\n", " return torch.stack(outs, dim=1)\n", " \n", " def reset(self): self.h = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_func(inp, targ):\n", " return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      3.103298    2.874341    0.212565  00:01
1      2.231964    1.971280    0.462158  00:01
2      1.711358    1.813547    0.461182  00:01
3      1.448516    1.828176    0.483236  00:01
4      1.288630    1.659564    0.520671  00:01
5      1.161470    1.714023    0.554932  00:01
6      1.055568    1.660916    0.575033  00:01
7      0.960765    1.719624    0.591064  00:01
8      0.870153    1.839560    0.614665  00:01
9      0.808545    1.770278    0.624349  00:01
10     0.758084    1.842931    0.610758  00:01
11     0.719320    1.799527    0.646566  00:01
12     0.683439    1.917928    0.649821  00:01
13     0.660283    1.874712    0.628581  00:01
14     0.646154    1.877519    0.640055  00:01
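The custom `loss_func` above exists purely to reconcile shapes: the model now emits one prediction per input token. A small sketch with made-up activations (it reuses the `bs`, `sl`, `vocab`, and `loss_func` defined above) makes the flattening explicit:

```python
# F.cross_entropy wants (n_items, n_classes) scores against (n_items,) targets,
# so the (bs, sl, vocab) activations and (bs, sl) targets are both flattened.
inp  = torch.randn(bs, sl, len(vocab))         # 64 x 16 x 30 fake activations
targ = torch.randint(0, len(vocab), (bs, sl))  # 64 x 16 fake targets
print(loss_func(inp, targ))                    # one scalar over all 64*16 predictions
```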
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n", " metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(15, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multilayer RNNs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel5(Module):\n", " def __init__(self, vocab_sz, n_hidden, n_layers):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h = torch.zeros(n_layers, bs, n_hidden)\n", " \n", " def forward(self, x):\n", " res,h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(res)\n", " \n", " def reset(self): self.h.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      3.055853    2.591640    0.437907  00:01
1      2.162359    1.787310    0.471598  00:01
2      1.710663    1.941807    0.321777  00:01
3      1.520783    1.999726    0.312012  00:01
4      1.330846    2.012902    0.413249  00:01
5      1.163297    1.896192    0.450684  00:01
6      1.033813    2.005209    0.434814  00:01
7      0.919090    2.047083    0.456706  00:01
8      0.822939    2.068031    0.468831  00:01
9      0.750180    2.136064    0.475098  00:01
10     0.695120    2.139140    0.485433  00:01
11     0.655752    2.155081    0.493652  00:01
12     0.629650    2.162583    0.498535  00:01
13     0.613583    2.171649    0.491048  00:01
14     0.604309    2.180355    0.487874  00:01
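If the tensor shapes `nn.RNN` expects are unclear, this standalone sketch (using the same sizes as above: 64 hidden units, 2 stacked layers) shows what goes in and out:

```python
# With batch_first=True the input is (batch, sequence, features); the initial
# hidden state holds one (batch, hidden) slice per stacked layer.
rnn = nn.RNN(64, 64, 2, batch_first=True)
x  = torch.randn(bs, sl, 64)
h0 = torch.zeros(2, bs, 64)
res, h = rnn(x, h0)
print(res.shape, h.shape)  # torch.Size([64, 16, 64]) torch.Size([2, 64, 64])
```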
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel5(len(vocab), 64, 2), \n", " loss_func=CrossEntropyLossFlat(), \n", " metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(15, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exploding or disappearing activations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LSTM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Building an LSTM from scratch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LSTMCell(Module):\n", " def __init__(self, ni, nh):\n", " self.forget_gate = nn.Linear(ni + nh, nh)\n", " self.input_gate = nn.Linear(ni + nh, nh)\n", " self.cell_gate = nn.Linear(ni + nh, nh)\n", " self.output_gate = nn.Linear(ni + nh, nh)\n", "\n", " def forward(self, input, state):\n", " h,c = state\n", " h = torch.stack([h, input], dim=1)\n", " forget = torch.sigmoid(self.forget_gate(h))\n", " c = c * forget\n", " inp = torch.sigmoid(self.input_gate(h))\n", " cell = torch.tanh(self.cell_gate(h))\n", " c = c + inp * cell\n", " out = torch.sigmoid(self.output_gate(h))\n", " h = outgate * torch.tanh(c)\n", " return h, (h,c)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LSTMCell(Module):\n", " def __init__(self, ni, nh):\n", " self.ih = nn.Linear(ni,4*nh)\n", " self.hh = nn.Linear(nh,4*nh)\n", "\n", " def forward(self, input, state):\n", " h,c = state\n", " #One big multiplication for all the gates is better than 4 smaller ones\n", " gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n", " ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n", " cellgate = gates[3].tanh()\n", "\n", " c = (forgetgate*c) + (ingate*cellgate)\n", " h = outgate * c.tanh()\n", " return h, (h,c)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = torch.arange(0,10); t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.chunk(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training a language model using LSTMs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel6(Module):\n", " def __init__(self, vocab_sz, n_hidden, n_layers):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " \n", " def forward(self, x):\n", " res,h = self.rnn(self.i_h(x), self.h)\n", " self.h = [h_.detach() for h_ in h]\n", " return self.h_o(res)\n", " \n", " def reset(self): \n", " for h in self.h: h.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      3.000821    2.663942    0.438314  00:02
1      2.139642    2.184780    0.240479  00:02
2      1.607275    1.812682    0.439779  00:02
3      1.347711    1.830982    0.497477  00:02
4      1.123113    1.937766    0.594401  00:02
5      0.852042    2.012127    0.631592  00:02
6      0.565494    1.312742    0.725749  00:02
7      0.347445    1.297934    0.711263  00:02
8      0.208191    1.441269    0.731201  00:02
9      0.126335    1.569952    0.737305  00:02
10     0.079761    1.427187    0.754150  00:02
11     0.052990    1.494990    0.745117  00:02
12     0.039008    1.393731    0.757894  00:02
13     0.031502    1.373210    0.758464  00:02
14     0.028068    1.368083    0.758464  00:02
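As a quick smoke test of the refactored `LSTMCell` (a sketch with arbitrary sizes, reusing `bs` from above): one step consumes a batch of inputs plus an `(h, c)` state pair and returns tensors of matching shapes.

```python
# One LSTM step on dummy data.
ni, nh = 10, 20
cell = LSTMCell(ni, nh)
x = torch.randn(bs, ni)             # a batch of input activations
h = c = torch.zeros(bs, nh)         # zero initial hidden and cell states
out, (h, c) = cell(x, (h, c))
print(out.shape, h.shape, c.shape)  # each torch.Size([64, 20])
```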
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel6(len(vocab), 64, 2), \n", " loss_func=CrossEntropyLossFlat(), \n", " metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(15, 1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regularizing an LSTM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dropout" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Dropout(Module):\n", " def __init__(self, p): self.p = p\n", " def forward(self, x):\n", " if not self.training: return x\n", " mask = x.new(*x.shape).bernoulli_(1-p)\n", " return x * mask.div_(1-p)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### AR and TAR regularization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training a weight-tied regularized LSTM" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel7(Module):\n", " def __init__(self, vocab_sz, n_hidden, n_layers, p):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.drop = nn.Dropout(p)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h_o.weight = self.i_h.weight\n", " self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " \n", " def forward(self, x):\n", " raw,h = self.rnn(self.i_h(x), self.h)\n", " out = self.drop(raw)\n", " self.h = [h_.detach() for h_ in h]\n", " return self.h_o(out),raw,out\n", " \n", " def reset(self): \n", " for h in self.h: h.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n", " loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n", " cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n", " loss_func=CrossEntropyLossFlat(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      2.693885    2.013484    0.466634  00:02
1      1.685549    1.187310    0.629313  00:02
2      0.973307    0.791398    0.745605  00:02
3      0.555823    0.640412    0.794108  00:02
4      0.351802    0.557247    0.836100  00:02
5      0.244986    0.594977    0.807292  00:02
6      0.192231    0.511690    0.846761  00:02
7      0.162456    0.520370    0.858073  00:02
8      0.142664    0.525918    0.842285  00:02
9      0.128493    0.495029    0.858073  00:02
10     0.117589    0.464236    0.867188  00:02
11     0.109808    0.466550    0.869303  00:02
12     0.104216    0.455151    0.871826  00:02
13     0.100271    0.452659    0.873617  00:02
14     0.098121    0.458372    0.869385  00:02
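For reference, the two penalties that `RNNRegularizer(alpha=2, beta=1)` added during training can be written out by hand. This is a simplified sketch of the AR and TAR terms from the AWD-LSTM paper, not fastai's exact callback code; `raw` and `out` stand for the extra activations LMModel7 returns:

```python
def ar_tar(raw, out, alpha=2., beta=1.):
    ar  = alpha * out.pow(2).mean()                      # AR: keep the (dropped-out) activations small
    tar = beta * (raw[:,1:] - raw[:,:-1]).pow(2).mean()  # TAR: keep consecutive timesteps close
    return ar + tar

raw = out = torch.randn(bs, sl, 64)  # stand-ins for the model's returned activations
print(ar_tar(raw, out))
```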
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(15, 1e-2, wd=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Questionnaire" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Further research" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }