{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Application architectures deep dive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computer vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### cnn_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'cut': -2,\n",
       " 'split': <function fastai2.vision.learner._resnet_split(m)>,\n",
       " 'stats': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_meta[resnet50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): full: False\n",
       "  (2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=20, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "create_head(20,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### unet_learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Siamese network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai2.vision.all import *\n",
    "path = untar_data(URLs.PETS)\n",
    "files = get_image_files(path/\"images\")\n",
    "\n",
    "class SiameseImage(Tuple):\n",
    "    def show(self, ctx=None, **kwargs): \n",
    "        img1,img2,same_breed = self\n",
    "        if not isinstance(img1, Tensor):\n",
    "            if img2.size != img1.size: img2 = img2.resize(img1.size)\n",
    "            t1,t2 = tensor(img1),tensor(img2)\n",
    "            t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)\n",
    "        else: t1,t2 = img1,img2\n",
    "        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)\n",
    "        return show_image(torch.cat([t1,line,t2], dim=2), \n",
    "                          title=same_breed, ctx=ctx)\n",
    "    \n",
    "def label_func(fname):\n",
    "    return re.match(r'^(.*)_\\d+.jpg$', fname.name).groups()[0]\n",
    "\n",
    "class SiameseTransform(Transform):\n",
    "    def __init__(self, files, label_func, splits):\n",
    "        self.labels = files.map(label_func).unique()\n",
    "        self.lbl2files = {l: L(f for f in files if label_func(f) == l) for l in self.labels}\n",
    "        self.label_func = label_func\n",
    "        self.valid = {f: self._draw(f) for f in files[splits[1]]}\n",
    "        \n",
    "    def encodes(self, f):\n",
    "        f2,t = self.valid.get(f, self._draw(f))\n",
    "        img1,img2 = PILImage.create(f),PILImage.create(f2)\n",
    "        return SiameseImage(img1, img2, t)\n",
    "    \n",
    "    def _draw(self, f):\n",
    "        same = random.random() < 0.5\n",
    "        cls = self.label_func(f)\n",
    "        if not same: cls = random.choice(L(l for l in self.labels if l != cls)) \n",
    "        return random.choice(self.lbl2files[cls]),same\n",
    "    \n",
    "splits = RandomSplitter()(files)\n",
    "tfm = SiameseTransform(files, label_func, splits)\n",
    "tls = TfmdLists(files, tfm, splits=splits)\n",
    "dls = tls.dataloaders(after_item=[Resize(224), ToTensor], \n",
    "    after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SiameseModel(Module):\n",
    "    def __init__(self, encoder, head):\n",
    "        self.encoder,self.head = encoder,head\n",
    "    \n",
    "    def forward(self, x1, x2):\n",
    "        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)\n",
    "        return self.head(ftrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = create_body(resnet34, cut=-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "head = create_head(512*4, 2, ps=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SiameseModel(encoder, head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func(out, targ):\n",
    "    return nn.CrossEntropyLoss()(out, targ.long())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def siamese_splitter(model):\n",
    "    return [params(model.encoder), params(model.head)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, model, loss_func=loss_func, \n",
    "                splitter=siamese_splitter, metrics=accuracy)\n",
    "learn.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.367015</td>\n",
       "      <td>0.281242</td>\n",
       "      <td>0.885656</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.307688</td>\n",
       "      <td>0.214721</td>\n",
       "      <td>0.915426</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.275221</td>\n",
       "      <td>0.170615</td>\n",
       "      <td>0.936401</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.223771</td>\n",
       "      <td>0.159633</td>\n",
       "      <td>0.943843</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(4, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.212744</td>\n",
       "      <td>0.159033</td>\n",
       "      <td>0.944520</td>\n",
       "      <td>00:35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.201893</td>\n",
       "      <td>0.159615</td>\n",
       "      <td>0.942490</td>\n",
       "      <td>00:35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.204606</td>\n",
       "      <td>0.152338</td>\n",
       "      <td>0.945196</td>\n",
       "      <td>00:36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.213203</td>\n",
       "      <td>0.148346</td>\n",
       "      <td>0.947903</td>\n",
       "      <td>00:36</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(4, slice(1e-6,1e-4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Natural language processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wrapping up architectures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further research"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}