{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Language Model from Scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text.all import *\n",
    "path = untar_data(URLs.HUMAN_NUMBERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "Path.BASE_PATH = path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#2) [Path('train.txt'),Path('valid.txt')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lines = L()\n",
    "with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
    "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
    "lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = ' . '.join([l.strip() for l in lines])\n",
    "text[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens = text.split(' ')\n",
    "tokens[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = L(*tokens).unique()\n",
    "vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#63095) [0,1,2,1,3,1,4,1,5,1...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2idx = {w:i for i,w in enumerate(vocab)}\n",
    "nums = L(word2idx[i] for i in tokens)\n",
    "nums"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our First Language Model from Scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
    "seqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64\n",
    "cut = int(len(seqs) * 0.8)\n",
    "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our Language Model in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel1(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)  \n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)     \n",
    "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
    "        h = h + self.i_h(x[:,1])\n",
    "        h = F.relu(self.h_h(h))\n",
    "        h = h + self.i_h(x[:,2])\n",
    "        h = F.relu(self.h_h(h))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.824297</td>\n",
       "      <td>1.970941</td>\n",
       "      <td>0.467554</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.386973</td>\n",
       "      <td>1.823242</td>\n",
       "      <td>0.467554</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.417556</td>\n",
       "      <td>1.654497</td>\n",
       "      <td>0.494414</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.376440</td>\n",
       "      <td>1.650849</td>\n",
       "      <td>0.494414</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, \n",
    "                metrics=accuracy)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(29), 'thousand', 0.15165200855716662)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n,counts = 0,torch.zeros(len(vocab))\n",
    "for x,y in dls.valid:\n",
    "    n += y.shape[0]\n",
    "    for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
    "idx = torch.argmax(counts)\n",
    "idx, vocab[idx.item()], counts[idx].item()/n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our First Recurrent Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel2(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)  \n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)     \n",
    "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        h = 0\n",
    "        for i in range(3):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.816274</td>\n",
       "      <td>1.964143</td>\n",
       "      <td>0.460185</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.423805</td>\n",
       "      <td>1.739964</td>\n",
       "      <td>0.473259</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.430327</td>\n",
       "      <td>1.685172</td>\n",
       "      <td>0.485382</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.388390</td>\n",
       "      <td>1.657033</td>\n",
       "      <td>0.470406</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, \n",
    "                metrics=accuracy)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improving the RNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Maintaining the State of an RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel3(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)  \n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)     \n",
    "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
    "        self.h = 0\n",
    "        \n",
    "    def forward(self, x):\n",
    "        for i in range(3):\n",
    "            self.h = self.h + self.i_h(x[:,i])\n",
    "            self.h = F.relu(self.h_h(self.h))\n",
    "        out = self.h_o(self.h)\n",
    "        self.h = self.h.detach()\n",
    "        return out\n",
    "    \n",
    "    def reset(self): self.h = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(328, 64, 21031)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = len(seqs)//bs\n",
    "m,bs,len(seqs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_chunks(ds, bs):\n",
    "    m = len(ds) // bs\n",
    "    new_ds = L()\n",
    "    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
    "    return new_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cut = int(len(seqs) * 0.8)\n",
    "dls = DataLoaders.from_dsets(\n",
    "    group_chunks(seqs[:cut], bs), \n",
    "    group_chunks(seqs[cut:], bs), \n",
    "    bs=bs, drop_last=True, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.677074</td>\n",
       "      <td>1.827367</td>\n",
       "      <td>0.467548</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.282722</td>\n",
       "      <td>1.870913</td>\n",
       "      <td>0.388942</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.090705</td>\n",
       "      <td>1.651793</td>\n",
       "      <td>0.462500</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.005092</td>\n",
       "      <td>1.613794</td>\n",
       "      <td>0.516587</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.965975</td>\n",
       "      <td>1.560775</td>\n",
       "      <td>0.551202</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.916182</td>\n",
       "      <td>1.595857</td>\n",
       "      <td>0.560577</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.897657</td>\n",
       "      <td>1.539733</td>\n",
       "      <td>0.574279</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.836274</td>\n",
       "      <td>1.585141</td>\n",
       "      <td>0.583173</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.805877</td>\n",
       "      <td>1.629808</td>\n",
       "      <td>0.586779</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.795096</td>\n",
       "      <td>1.651267</td>\n",
       "      <td>0.588942</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,\n",
    "                metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(10, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating More Signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sl = 16\n",
    "seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
    "         for i in range(0,len(nums)-sl-1,sl))\n",
    "cut = int(len(seqs) * 0.8)\n",
    "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
    "                             group_chunks(seqs[cut:], bs),\n",
    "                             bs=bs, drop_last=True, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(#16) ['one','.','two','.','three','.','four','.','five','.'...],\n",
       " (#16) ['.','two','.','three','.','four','.','five','.','six'...]]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[L(vocab[o] for o in s) for s in seqs[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel4(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)  \n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)     \n",
    "        self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
    "        self.h = 0\n",
    "        \n",
    "    def forward(self, x):\n",
    "        outs = []\n",
    "        for i in range(sl):\n",
    "            self.h = self.h + self.i_h(x[:,i])\n",
    "            self.h = F.relu(self.h_h(self.h))\n",
    "            outs.append(self.h_o(self.h))\n",
    "        self.h = self.h.detach()\n",
    "        return torch.stack(outs, dim=1)\n",
    "    \n",
    "    def reset(self): self.h = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func(inp, targ):\n",
    "    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.103298</td>\n",
       "      <td>2.874341</td>\n",
       "      <td>0.212565</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.231964</td>\n",
       "      <td>1.971280</td>\n",
       "      <td>0.462158</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.711358</td>\n",
       "      <td>1.813547</td>\n",
       "      <td>0.461182</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.448516</td>\n",
       "      <td>1.828176</td>\n",
       "      <td>0.483236</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.288630</td>\n",
       "      <td>1.659564</td>\n",
       "      <td>0.520671</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.161470</td>\n",
       "      <td>1.714023</td>\n",
       "      <td>0.554932</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.055568</td>\n",
       "      <td>1.660916</td>\n",
       "      <td>0.575033</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.960765</td>\n",
       "      <td>1.719624</td>\n",
       "      <td>0.591064</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.870153</td>\n",
       "      <td>1.839560</td>\n",
       "      <td>0.614665</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.808545</td>\n",
       "      <td>1.770278</td>\n",
       "      <td>0.624349</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.758084</td>\n",
       "      <td>1.842931</td>\n",
       "      <td>0.610758</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.719320</td>\n",
       "      <td>1.799527</td>\n",
       "      <td>0.646566</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.683439</td>\n",
       "      <td>1.917928</td>\n",
       "      <td>0.649821</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.660283</td>\n",
       "      <td>1.874712</td>\n",
       "      <td>0.628581</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.646154</td>\n",
       "      <td>1.877519</td>\n",
       "      <td>0.640055</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,\n",
    "                metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(15, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multilayer RNNs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel5(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = torch.zeros(n_layers, bs, n_hidden)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(res)\n",
    "    \n",
    "    def reset(self): self.h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.055853</td>\n",
       "      <td>2.591640</td>\n",
       "      <td>0.437907</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.162359</td>\n",
       "      <td>1.787310</td>\n",
       "      <td>0.471598</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.710663</td>\n",
       "      <td>1.941807</td>\n",
       "      <td>0.321777</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.520783</td>\n",
       "      <td>1.999726</td>\n",
       "      <td>0.312012</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.330846</td>\n",
       "      <td>2.012902</td>\n",
       "      <td>0.413249</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.163297</td>\n",
       "      <td>1.896192</td>\n",
       "      <td>0.450684</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.033813</td>\n",
       "      <td>2.005209</td>\n",
       "      <td>0.434814</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.919090</td>\n",
       "      <td>2.047083</td>\n",
       "      <td>0.456706</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.822939</td>\n",
       "      <td>2.068031</td>\n",
       "      <td>0.468831</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.750180</td>\n",
       "      <td>2.136064</td>\n",
       "      <td>0.475098</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.695120</td>\n",
       "      <td>2.139140</td>\n",
       "      <td>0.485433</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.655752</td>\n",
       "      <td>2.155081</td>\n",
       "      <td>0.493652</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.629650</td>\n",
       "      <td>2.162583</td>\n",
       "      <td>0.498535</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.613583</td>\n",
       "      <td>2.171649</td>\n",
       "      <td>0.491048</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.604309</td>\n",
       "      <td>2.180355</td>\n",
       "      <td>0.487874</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel5(len(vocab), 64, 2), \n",
    "                loss_func=CrossEntropyLossFlat(), \n",
    "                metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(15, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exploding or Disappearing Activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building an LSTM from Scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        self.forget_gate = nn.Linear(ni + nh, nh)\n",
    "        self.input_gate  = nn.Linear(ni + nh, nh)\n",
    "        self.cell_gate   = nn.Linear(ni + nh, nh)\n",
    "        self.output_gate = nn.Linear(ni + nh, nh)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        h,c = state\n",
    "        h = torch.stack([h, input], dim=1)\n",
    "        forget = torch.sigmoid(self.forget_gate(h))\n",
    "        c = c * forget\n",
    "        inp = torch.sigmoid(self.input_gate(h))\n",
    "        cell = torch.tanh(self.cell_gate(h))\n",
    "        c = c + inp * cell\n",
    "        out = torch.sigmoid(self.output_gate(h))\n",
    "        h = outgate * torch.tanh(c)\n",
    "        return h, (h,c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        self.ih = nn.Linear(ni,4*nh)\n",
    "        self.hh = nn.Linear(nh,4*nh)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        h,c = state\n",
    "        # One big multiplication for all the gates is better than 4 smaller ones\n",
    "        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
    "        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
    "        cellgate = gates[3].tanh()\n",
    "\n",
    "        c = (forgetgate*c) + (ingate*cellgate)\n",
    "        h = outgate * c.tanh()\n",
    "        return h, (h,c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.arange(0,10); t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.chunk(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a Language Model Using LSTMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel6(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        self.h = [h_.detach() for h_ in h]\n",
    "        return self.h_o(res)\n",
    "    \n",
    "    def reset(self): \n",
    "        for h in self.h: h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.000821</td>\n",
       "      <td>2.663942</td>\n",
       "      <td>0.438314</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.139642</td>\n",
       "      <td>2.184780</td>\n",
       "      <td>0.240479</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.607275</td>\n",
       "      <td>1.812682</td>\n",
       "      <td>0.439779</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.347711</td>\n",
       "      <td>1.830982</td>\n",
       "      <td>0.497477</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.123113</td>\n",
       "      <td>1.937766</td>\n",
       "      <td>0.594401</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.852042</td>\n",
       "      <td>2.012127</td>\n",
       "      <td>0.631592</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.565494</td>\n",
       "      <td>1.312742</td>\n",
       "      <td>0.725749</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.347445</td>\n",
       "      <td>1.297934</td>\n",
       "      <td>0.711263</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.208191</td>\n",
       "      <td>1.441269</td>\n",
       "      <td>0.731201</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.126335</td>\n",
       "      <td>1.569952</td>\n",
       "      <td>0.737305</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.079761</td>\n",
       "      <td>1.427187</td>\n",
       "      <td>0.754150</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.052990</td>\n",
       "      <td>1.494990</td>\n",
       "      <td>0.745117</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.039008</td>\n",
       "      <td>1.393731</td>\n",
       "      <td>0.757894</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.031502</td>\n",
       "      <td>1.373210</td>\n",
       "      <td>0.758464</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.028068</td>\n",
       "      <td>1.368083</td>\n",
       "      <td>0.758464</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel6(len(vocab), 64, 2), \n",
    "                loss_func=CrossEntropyLossFlat(), \n",
    "                metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(15, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularizing an LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dropout(Module):\n",
    "    def __init__(self, p): self.p = p\n",
    "    def forward(self, x):\n",
    "        if not self.training: return x\n",
    "        mask = x.new(*x.shape).bernoulli_(1-p)\n",
    "        return x * mask.div_(1-p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Activation Regularization and Temporal Activation Regularization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a Weight-Tied Regularized LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel7(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.drop = nn.Dropout(p)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h_o.weight = self.i_h.weight\n",
    "        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]\n",
    "        \n",
    "    def forward(self, x):\n",
    "        raw,h = self.rnn(self.i_h(x), self.h)\n",
    "        out = self.drop(raw)\n",
    "        self.h = [h_.detach() for h_ in h]\n",
    "        return self.h_o(out),raw,out\n",
    "    \n",
    "    def reset(self): \n",
    "        for h in self.h: h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),\n",
    "                loss_func=CrossEntropyLossFlat(), metrics=accuracy,\n",
    "                cbs=[ModelResetter, RNNRegularizer(alpha=2, beta=1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),\n",
    "                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.693885</td>\n",
       "      <td>2.013484</td>\n",
       "      <td>0.466634</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.685549</td>\n",
       "      <td>1.187310</td>\n",
       "      <td>0.629313</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.973307</td>\n",
       "      <td>0.791398</td>\n",
       "      <td>0.745605</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.555823</td>\n",
       "      <td>0.640412</td>\n",
       "      <td>0.794108</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.351802</td>\n",
       "      <td>0.557247</td>\n",
       "      <td>0.836100</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.244986</td>\n",
       "      <td>0.594977</td>\n",
       "      <td>0.807292</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.192231</td>\n",
       "      <td>0.511690</td>\n",
       "      <td>0.846761</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.162456</td>\n",
       "      <td>0.520370</td>\n",
       "      <td>0.858073</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.142664</td>\n",
       "      <td>0.525918</td>\n",
       "      <td>0.842285</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.128493</td>\n",
       "      <td>0.495029</td>\n",
       "      <td>0.858073</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.117589</td>\n",
       "      <td>0.464236</td>\n",
       "      <td>0.867188</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.109808</td>\n",
       "      <td>0.466550</td>\n",
       "      <td>0.869303</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.104216</td>\n",
       "      <td>0.455151</td>\n",
       "      <td>0.871826</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.100271</td>\n",
       "      <td>0.452659</td>\n",
       "      <td>0.873617</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.098121</td>\n",
       "      <td>0.458372</td>\n",
       "      <td>0.869385</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(15, 1e-2, wd=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. If the dataset for your project is so big and complicated that working with it takes a significant amount of time, what should you do?\n",
    "1. Why do we concatenate the documents in our dataset before creating a language model?\n",
    "1. To use a standard fully connected network to predict the fourth word given the previous three words, what two tweaks do we need to make to ou model?\n",
    "1. How can we share a weight matrix across multiple layers in PyTorch?\n",
    "1. Write a module that predicts the third word given the previous two words of a sentence, without peeking.\n",
    "1. What is a recurrent neural network?\n",
    "1. What is \"hidden state\"?\n",
    "1. What is the equivalent of hidden state in ` LMModel1`?\n",
    "1. To maintain the state in an RNN, why is it important to pass the text to the model in order?\n",
    "1. What is an \"unrolled\" representation of an RNN?\n",
    "1. Why can maintaining the hidden state in an RNN lead to memory and performance problems? How do we fix this problem?\n",
    "1. What is \"BPTT\"?\n",
    "1. Write code to print out the first few batches of the validation set, including converting the token IDs back into English strings, as we showed for batches of IMDb data in <<chapter_nlp>>.\n",
    "1. What does the `ModelResetter` callback do? Why do we need it?\n",
    "1. What are the downsides of predicting just one output word for each three input words?\n",
    "1. Why do we need a custom loss function for `LMModel4`?\n",
    "1. Why is the training of `LMModel4` unstable?\n",
    "1. In the unrolled representation, we can see that a recurrent neural network actually has many layers. So why do we need to stack RNNs to get better results?\n",
    "1. Draw a representation of a stacked (multilayer) RNN.\n",
    "1. Why should we get better results in an RNN if we call `detach` less often? Why might this not happen in practice with a simple RNN?\n",
    "1. Why can a deep network result in very large or very small activations? Why does this matter?\n",
    "1. In a computer's floating-point representation of numbers, which numbers are the most precise?\n",
    "1. Why do vanishing gradients prevent training?\n",
    "1. Why does it help to have two hidden states in the LSTM architecture? What is the purpose of each one?\n",
    "1. What are these two states called in an LSTM?\n",
    "1. What is tanh, and how is it related to sigmoid?\n",
    "1. What is the purpose of this code in `LSTMCell`: `h = torch.stack([h, input], dim=1)`\n",
    "1. What does `chunk` do in PyTorch?\n",
    "1. Study the refactored version of `LSTMCell` carefully to ensure you understand how and why it does the same thing as the non-refactored version.\n",
    "1. Why can we use a higher learning rate for `LMModel6`?\n",
    "1. What are the three regularization techniques used in an AWD-LSTM model?\n",
    "1. What is \"dropout\"?\n",
    "1. Why do we scale the weights with dropout? Is this applied during training, inference, or both?\n",
    "1. What is the purpose of this line from `Dropout`: `if not self.training: return x`\n",
    "1. Experiment with `bernoulli_` to understand how it works.\n",
    "1. How do you set your model in training mode in PyTorch? In evaluation mode?\n",
    "1. Write the equation for activation regularization (in math or code, as you prefer). How is it different from weight decay?\n",
    "1. Write the equation for temporal activation regularization (in math or code, as you prefer). Why wouldn't we use this for computer vision problems?\n",
    "1. What is \"weight tying\" in a language model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further Research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. In ` LMModel2`, why can `forward` start with `h=0`? Why don't we need to say `h=torch.zeros(...)`?\n",
    "1. Write the code for an LSTM from scratch (you may refer to <<lstm>>).\n",
    "1. Search the internet for the GRU architecture and implement it from scratch, and try training a model. See if you can get results similar to those we saw in this chapter. Compare you results to the results of PyTorch's built in `GRU` module.\n",
    "1. Take a look at the source code for AWD-LSTM in fastai, and try to map each of the lines of code to the concepts shown in this chapter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}