fastbook/20_CAM.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"[[chapter_cam]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN interpretation with CAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we know how to build up pretty much anything from scratch, let's use that knowledge to create entirely new (and very useful!) functionality: the *class activation map*. In the process, we'll learn about one handy feature of PyTorch we haven't seen before, the *hook*, and we'll apply many of the concepts classes we've learned in the rest of the book. If you want to really test out your understanding of the material in this book, after you've finished this chapter, try putting the book aside, and recreate the ideas here yourself from scratch (no peaking!)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CAM and hooks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Class Activation Mapping (or CAM) was introduced by Zhou et al. in [Learning Deep Features for Discriminative Localization](https://arxiv.org/abs/1512.04150). It uses the output of the last convolutional layer (just before our average pooling) together with the predictions to give us some heatmap visulaization of why the model made its decision.\n",
"\n",
"More precisely, at each position of our final convolutional layer we have has many filters as the last linear layer. We can then compute the dot product of those activations by the final weights to have, for each location on our feature map, the score of the feature that was used to make a decision.\n",
"\n",
"We're going to need a way to get access to the activations inside the model while it's training. In PyTorch this can be done with a *hook*. Hooks are PyTorch's equivalent of fastai's *callbacks*. However rather than allowing you to inject code to the training loop like a fastai Learner callback, hooks allow you to inject code into the forward and backward calculations themselves. We can attach a hook to any layer of the model, and it will be executed when we compute the outputs (forward hook) or during backpropagation (backward hook). A forward hook has to be a function that takes three things: a module, its input and its output, and it can perform any behavior you want. (fastai also provides a handy `HookCallback` that we won't cover here, so take a look at the fastai docs; it makes working with hooks a little easier.)\n",
"\n",
"We'll use the same cats and dogs model we trained in <<chapter_intro>>:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.181760</td>\n",
" <td>0.032238</td>\n",
" <td>0.009472</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.059119</td>\n",
" <td>0.014090</td>\n",
" <td>0.002706</td>\n",
" <td>00:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"path = untar_data(URLs.PETS)/'images'\n",
"def is_cat(x): return x[0].isupper()\n",
"dls = ImageDataLoaders.from_name_func(\n",
" path, get_image_files(path), valid_pct=0.2, seed=42,\n",
" label_func=is_cat, item_tfms=Resize(224))\n",
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we'll grab a cat picture and a batch of data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = PILImage.create('images/chapter1_cat_example.jpg')\n",
"x, = first(dls.test_dl([img]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For CAM we want to store the activations of the last convolutional layer. We put our hook function in a class so it has a state that we can access later, and just store a copy of the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hook():\n",
" def hook_func(self, m, i, o): self.stored = o.detach().clone()"
]
},
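{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the three-argument signature feels abstract, here's a minimal self-contained sketch (using a throwaway `nn.Linear` rather than our pets model) that shows a forward hook receiving the module, its inputs, and its output each time the module is called:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal forward-hook sketch on a throwaway layer (independent of the model above)\n",
"import torch\n",
"from torch import nn\n",
"\n",
"def print_shapes(module, inputs, outputs):\n",
"    # `inputs` is a tuple of positional inputs; `outputs` is the layer's output tensor\n",
"    print(type(module).__name__, [i.shape for i in inputs], outputs.shape)\n",
"\n",
"layer = nn.Linear(4, 2)\n",
"handle = layer.register_forward_hook(print_shapes)\n",
"_ = layer(torch.randn(3, 4))  # the hook fires during this forward pass\n",
"handle.remove()               # detach the hook when you're done with it"
]
},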
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then instantiate a `Hook` and attach it to the layer we want, which is the last layer of the CNN body."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook_output = Hook()\n",
"hook = learn.model[0].register_forward_hook(hook_output.hook_func)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then grab a batch and feed it through our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): output = learn.model.eval()(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we can access out stored activations!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act = hook_output.stored[0]"
]
},
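{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of what we just stored: with a resnet34 body and 224-pixel inputs, the stored activations should be 512 channels on a 7 by 7 spatial grid (the exact numbers depend on the architecture and input size you're using):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act.shape"
]
},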
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's also double-check our predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[2.7374e-09, 1.0000e+00]], device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.softmax(output, dim=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We know 0 is dog (for False) because the classes are automatically sorted in fastai. We can still double check by looking at `dls.vocab`: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [False,True]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.vocab"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So our model is very confident this was a picture of a cat."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To do the dot product of our weight matrix (2 by number of activations) with the activations (batch size by activations by rows by cols) we use a custom einsum:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 224, 224])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 7, 7])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)\n",
"cam_map.shape"
]
},
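{
"cell_type": "markdown",
"metadata": {},
"source": [
"In other words, for each class $c$ the class activation map at location $(i,j)$ is $\\sum_k w_{c,k}\\, a_{k,i,j}$, where $w$ is the weight matrix of the final linear layer and $a$ are the activations we hooked. The `einsum` above is just a compact way of writing that weighted sum. As a sanity check (a sketch reusing the `act` and `cam_map` tensors from the previous cells), it should match an ordinary matrix multiplication over the flattened feature map:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the einsum above is a matrix multiply between the class weights\n",
"# and the flattened (channels x locations) feature map\n",
"wgt = learn.model[1][-1].weight               # shape: (n_classes, n_channels)\n",
"n_ch, h, w_ = act.shape                       # shape: (n_channels, 7, 7)\n",
"cam_manual = (wgt @ act.view(n_ch, -1)).view(len(wgt), h, w_)\n",
"assert torch.allclose(cam_manual, cam_map, atol=1e-5)"
]
},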
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each image in our batch, and for each class, we get a 7 by 7 feature map that tells us where the activations were higher vs lower. This will let us see which area of the pictures made the model take its decision.\n",
"\n",
"For instance, the model decided this animal was a cat based on those areas (note that we need to `decode` the input `x` since it's been normalized by the `DataLoader`, and we need to cast to `TensorImage` since at the time this book is written PyTorch does not maintain types when indexing--this may be fixed by the time you are reading this):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9S5YcSbIldkVUzdwjAFRWPfYhu/twQvY53BaHXAznHHAjXAL30q9fd1clgHA3UxXh4IqIWiAzgXr5BqxBehUSiHB3+6jJ58qVj4q744/XH68/Xv94L/3/+wL+eP3x+uP1668/lPOP1x+vf9DXH8r5x+uP1z/o6w/l/OP1x+sf9PWHcv7x+uP1D/rq33vz//0//3cnmysQWb9XVagqkul1ONzXHzgAEagIVIW/ByDC4zgQv3MIBC4CAb8jIvysKkQEZgbYgI8Tc05Mm3AzuBnMALhBwO/X+SFw4THQlMdSgUiDtAZXAKKAGPJyXRVQhYtyWaThnIL/4X/8D/h3//4/4v7yZ0AaIA1bu2PMiXMYXBtk29D2G9q2QbTB4DAzzHlgzhPwAbcJsxNmA+IGUcCV1ywqaCoQXgIAw3E88eXrZ5hNOAzudnkGDjOeQxrgc8Ld4OaAOwQONQdsAmaAgTdpDnFA4BAo13Hyj7hDZn7eAPN4f0LMoRCI++U9HsttwofBh8W5wafhgA2DGyAm8aQBGG/CJo/hHtcGBybieTjUFe78XF6/z/jbAV6q8W+AZ1UBRBDCADUpmRCJ9+LnOfl3S/kIwTQ3mFGuRQRNW31P4IAA4gDE+RwFJXOiSrmKVQC0ZFrdATfMaXymznv/T58//t//6f/5v/6Pf7VyQhoAp1DEoiEOek3ApGLVb3ONJBQOXCykcKVC1O/qkfK7yO9zUcwMNifmGDAz3pgZldgAlzQSl5Nfrk1igSgzDrjA1yMFdC2sIY1Dx23b8eHjn3C/v8JdIKIAlIrgRqMEh7tRWvIhi4TSNajz2vNh5JXZtFqDJspzxjKFbJeA0OA5LIQs10ZVKRziMAfvSRzil0WEQGJ9QHuUT4zHljS8/KzFw6USyzsjlmr9XgBivVWp4KmGHsYRgKbAhgy5AZIGOw/jF/lxlJFP405rG8oMh8a1az77q5JfjnH9IyGLKRMhDTQ6dR38IwJoyI2Icl2Rxj+NIE9MJ+OXU0vImMf6SckF/2g8K3z39QPlVKjkCan51zXgReUj8lIsIK/jVzwqbzEEPBca9Gz5xVo8WjJ6Sg9vaSEcaXUpjB4CJu8vAO9WgLdAZdZlDPIzvp4rBIKXlw94eXlF0w3DFKKtPA48BEQpJIBj2oQMwEXgblBdisrrBr8PXru5xXnDYMR6rOcosdb5t18UDvXdeiYhFPVTHEjS3BtCyHKN/KLEcd3g9RJcaCia8NlLCnx6DwquqDA+8llIBhbPVAC1UE4DKEGWvhtpVtdV824snusS6sv7sQ4qvOallN9YNpflD0JW+Nzju9D67jpEKto6LhUb5RTMDRKGMIUmZcDdkEbMvdwsrLW6x3pWP9DO7yonPUW+lsXw3/p8fTIvOhVxfWN9W9JHLQH65kX4YRC3yzfXsbTWL9QpjYOmsOFyzXmGsjb1HLno9DwGLYX98PoBt/1OmCMNghbXOinkomi9Ab3DtYXXXfdrYUTS+noYN16ilnLWBeVyABESKGEl0ovlPXh5hoUkrusWxhKhYG6ApSXn98Qv5y3NkLLydT7xQGdCAxmeNJVaQDhe6zwDkcR3JeFmepJ89uJEI1jPLpV0pufMm0lx8qUMeckS8r+cZjoJCc9HGS6kJZRr8fV8akkvDiSdA8xhYlwvd7hPhhNwaIYly/0DWGFcHV7iXGFsLe5Ufiny717f95wXV0iXng8ssHUI27o2L7jHn99DTGB5RBW+T+t6cblIC8TjmRnEjLFNYHX4+j6/Et4mYMT74wXscFrrtGR5rgvQqYeHgHr3lxf03nFOg/YVN6koRKwUT1sDWodBYB5+QYA5J+9VG9Adc0zYoACqal3KdYUStqoqWlPYXDCsBC2Rh0V89K7Kq8QWqgJ1hFKFYtpCHQvO5ZoIIahKOVWIEZ4j3KUZ5QAZ66dAL0FcNjfW1sukgMhAao3rCwEK+Ay8JDucNr2uLeUpHJvOv44fTzWMW4tQo9BLvAxkQ1WVIYFnmJOXf/FsvhxDKlw+o1L6q3sJBJJXJnG+jHcdgMtvO7l8/UA58+Yp4Aq6b43bWyIeF55KhfWGx8UCF5gU70m9t6DIWmBCGzMDxoSNyTjNF4m0LlPWcb6BC2lMCzLmBV6sKxqJIHWBicLRoNKw9R7kgwE9iBYo7dJEQHMKLaEdFVxFoSKYTjJGW4OYw0xzlUKAEtjJZUF4cSlcHrKdCppIwRGQ/+KdczkFvBYJRSChROPq4SGWupQWUmE0xIq4+2KQBC4O1Y6EpfBcA7t4NFlahmVs/HKuPF3C5NLRX4oePRct3kVpLvhp4fv6WUTiGTSufaAWSIRHoTruBmkNakQelgYo4fA7qxmXH89FJRxB0zL8Xvd+QR755YS/WAp6RYG/9vqucnY/USwgfLG0cNiwEoSKg7wRrgAUDhWYJFOrZEQzfrooE4mssGqSljVh2wxWT+tW3EmAiOgFliTrS6wfxjasFH/fBNiVD2SaY44T0joaAFMJi+hQFezbDjPH8zxgusF8QFywbzdAGmxMWkgVwjCbSGJRuqKJwsSAFiADDjQBNsGchnOe2LYdMKOZc0BjbVwErkpmWYKkEY+1Xa7CpkE02NW00sJnBCOUmsPgJoAvw5BrreXdwri4kNEl5gPCeIk7bIahoaryudp7PkGRMaotLwoPGTCYBhx2Gi43WwLs4d0mn60KgpEOLy+JbkJeECRexNgp7uIC1QbVBpMOFw294ndNgsNwB7QtpyDA1WSViMY6VXZBlQwugGFGBKRp9Ky+n+DNA02WJ37HmfwblLPMh4bb1qCdL0pSTFRYe7+cNJWwoEKSOGFJrMxjHXHFjXmEIoTscpNX/BQ22TLWiZSISkErCTZWCxILpAkQwj4CfkrboboBraNvG1rrFHQJMqi1YB4D0iAY00bDQ1bZ4IcBW+PShdW2iJtV6KWnjXpACX7y3iVQinLpMa/PxJfQ8J5SeoDlV2Jt3qUqpM7jSoJGQWPk0xfMBsOVUlpPdlgJjd95yPfQzN1D+OL6UkoviMoqRRPK6UqFuopdxdJxppBDSeN1FepQ9mR1K1bMz7mFnC0yh3+D99Jq9SONFUb6ch56W64BZTw8vmgYgsCSGf9f/pgthvd6bZVB+M7r+8p5feDJzKW7FiD8BFIxC8KYYYYyad/eH1IuOC1joVJaX8GAAWYrrxlJzaVwVxIpvh9gDxnjSOSZcvGT+Avsiaa0rNMdFgIiraHvN+y3O1rfINLgEDRRNGmEwU7vOswxzdDMAY1YVB1zTpznhCqNwHv0QS82RRYcLM+wYCmNmiLp/4UQUjlq5d+ZenEJ8iKFc4FBAWi4LGB1hS3xfng6/
i/CllJIrq3oUk4Dva0n2TNRqRiF1CVnWEFYbQXLr0bZ/fJveB3ncpdIfUph51+/LvAkyQbR0zdrlxfE9WnMdlyMo6dM5XHyar/RpYK4kfPXMNAl6xcj8i6d8yMmKF4/8JxLMesk8R9akN84SXzeDFBHpEkiTrT3sSlFhBYOMEJQ4yOa48QcJ2BGDxELqMA3C4miyAtlA5H/jE+FNxCwsMGcRE5rjCvNHRMCbYrbywtuL69o2w60DpkUNgUJloxp5jz4XSFkltaxdcZkx3nCLaO/tKi8GlVF3zb4ZFY9FSwpJ7ksowqhpcX7eRwAzFmmci//iwyduFKhVGlAy4jFZ2wx1WQtLJ6FhwzT08iFeErwV/zrO7STPjBjXE+Aw3dF47LimTsAS+W7KM/lTOkd89mKMDQgQea4nr3kNIoJMsar+78Y9kx35c+JFoC0R2Fomr67nvUMaICyaOHdNQPhWZdi/vIT33/9UDnf3TCW1eJKXdhDXB9LKKII4yLXyPnx+3MmmYyozogHk1hcUigmzCY
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_dec = TensorImage(dls.train.decode((x,))[0][0])\n",
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map[0].detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So the eye and the right ear were the two main areas that made the model decide it was a picture of a cat.\n",
"\n",
"Once you're done with your hook, you should remove it otherwise it might leak some memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook.remove()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's why it's usually a good idea to have the `Hook` class be a *context manager*, registering the hook when you enter it and removing it when you exit. A \"context manager\" is a Python construct that calls `__enter__` when the object is created in a `with` clause, and `__exit__` at the end of the `with` clause. For instance, this is how Python handles the `with open(...) as f:` construct that you'll often see for opening files in Python, and not requiring an explicit `close(f)` at the end."
]
},
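{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the protocol on its own, here's a tiny self-contained sketch (nothing to do with hooks) of a context manager that times the body of a `with` block:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal context-manager sketch: __enter__ runs at the start of the `with` block,\n",
"# __exit__ runs at the end (even if an exception is raised inside the block)\n",
"import time\n",
"\n",
"class Timer:\n",
"    def __enter__(self):\n",
"        self.start = time.time()\n",
"        return self\n",
"    def __exit__(self, *args):\n",
"        print(f'elapsed: {time.time() - self.start:.3f}s')\n",
"\n",
"with Timer():\n",
"    _ = sum(i*i for i in range(100_000))"
]
},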
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hook():\n",
" def __init__(self, m):\n",
" self.hook = m.register_forward_hook(self.hook_func) \n",
" def hook_func(self, m, i, o): self.stored = o.detach().clone()\n",
" def __enter__(self, *args): return self\n",
" def __exit__(self, *args): self.hook.remove()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That way, you can safely use it this way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with Hook(learn.model[0]) as hook:\n",
" with torch.no_grad(): output = learn.model.eval()(x.cuda())\n",
" act = hook.stored"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastai provides this `Hook` class for you, as well as some other handy classes to make working with hooks easier."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient CAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The method we just saw only lets us compute a heatmap with the last activations, since once we have our features, we have to multiply them by the last weight matrix. This won't work for inner layers in the network. A variant introduced in the paper [Grad-CAM: Why did you say that? Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/abs/1611.07450) in 2016 uses the gradients of the final activation for the desired class: if you remember a little bit about the backward pass, the gradients of the output of the last layer with respect to the input of that layer is equal to the layer weights, since it is a linear layer.\n",
"\n",
"With deeper layers, we still want the gradients, but they won't just be equal to the weights any more. We have to calculate them. The gradients of every layer are calculated for us by PyTorch during the backward pass, but they're not stored (except for tensors where `requires_grad` is `True`). We can, however, register a hook on the *backward* pass, which PyTorch will give the gradients to as a parameter, so we can store them there. We'll use a `HookBwd` class that will work like `Hook`, but intercepts and stores gradients, instead of activations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class HookBwd():\n",
" def __init__(self, m):\n",
" self.hook = m.register_backward_hook(self.hook_func) \n",
" def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()\n",
" def __enter__(self, *args): return self\n",
" def __exit__(self, *args): self.hook.remove()"
]
},
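{
"cell_type": "markdown",
"metadata": {},
"source": [
"Depending on your PyTorch version, `register_backward_hook` may emit a deprecation warning: recent releases provide `register_full_backward_hook`, which takes a hook with the same `(module, grad_input, grad_output)` signature. If your version has it, only one line of the class changes; here is a sketch assuming such a version:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Variant sketch, assuming a PyTorch version that provides register_full_backward_hook\n",
"class HookBwdFull():\n",
"    def __init__(self, m):\n",
"        self.hook = m.register_full_backward_hook(self.hook_func)\n",
"    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()"
]
},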
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then for the class index 1 (for `True`, which is 'cat') we intercept the features of the last convolutional layer\n",
", as before, and compute the gradients of the output activation of our class. We can't just call `output.backward()`, because gradients only make sense with respect to a *scalar* (which is normally our *loss*), but `output` is a rank-2 tensor. But if we pick a single image (we'll use 0), and a single class (we'll use 1), then we *can* calculate the gradients of any weight or activation we like, with respect to that single value, using `output[0,cls].backward()`. Our hook intercepts the gradients that we'll use as weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cls = 1\n",
"with HookBwd(learn.model[0]) as hookg:\n",
" with Hook(learn.model[0]) as hook:\n",
" output = learn.model.eval()(x.cuda())\n",
" act = hook.stored\n",
" output[0,cls].backward()\n",
" grad = hookg.stored"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weights for our grad cam are given by the average of our gradients accross the feature map, then it's exactly the same as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
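{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out as a formula (just a summary of what the two lines above compute): if $a_{k,i,j}$ are the hooked activations and $y_c$ is the output score for class $c$, the weights are $w_k = \\frac{1}{Z}\\sum_{i,j} \\frac{\\partial y_c}{\\partial a_{k,i,j}}$ (with $Z$ the number of spatial locations), and the map is $\\sum_k w_k\\, a_{k,i,j}$. The original Grad-CAM paper additionally passes the result through a ReLU, keeping only the regions that contribute positively to the class; we skip that step here and just display the raw weighted sum."
]
},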
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9W5IkSZIkxiKqZu4RmVnVvTsDWsIfgAPgAiBcZ++E++AE+MX/0mIfMztVFQ83UxF8MIuqRXZ1Ru8MiDAf5d1ZmRnpbm6mKg8WlodaZuKP1x+vP17/+l7+//cN/PH64/XH6/dffyjnH68/Xv9KX38o5x+vP17/Sl9/KOcfrz9e/0pffyjnH68/Xv9KX/1H//i//a//eyITMJs/MxjMDa05MhMJAAkkAhGJYn/NDGaGZg6kAUgYHGYGpCETQCbMfF7XrQEGIPlZt4aMRI4Ddg7EGYgR/FkEciQs+RFLBzKBBCwNiZz3wO8weP0dfJ5EwAJI8DH5aqDNchyn4e/+/t/h7/7+3+F+/xkGB9Kx+w3HeeI8AwnAW8e23eC9w8wQSEQExnhgjAPIgYgDESciDlgmzIF0rZcbmhvMAS5H4PF4x8vLL4gIBAYyAzAtD4CMQETAGhAZXI+I2iSuM8Dr6/kTgJkjHTDnMw5uHtdCi5DgokYOZGgdm68vt9D1ku+OxMgBAHDXHsMoD1x5wLXu+jeuecmKw2x9twFw5/0isgRMT6C/ZiAyEXl9Zkd6vSfgabyv5DPw/nmvEYHIQGsNXv8GYAT/DUjKubuek9fkHgzAEqZn5j0bYD5ljT/hcwOGhpx7NsaYz/5//d/v/8d//G//57//Pf37oXLCGm9Cu1LKqGekQJsWJg1umA9i2geTRGXqjbX5aRKcXJtmXOWlVDYFL8ZADP6OkNBlwkKfnfdWm2DXB4FrkbKsyfwnCk29O/Ues47bbceXLz/hfn9GpkmInIqSCWQgwXs8x4lmgLUGGIXPvVGItBmZay1jULEzU+um26k1nRJXwljCwec1M7hTYTwTAUNKIee6axnMLtc0wJyfT6Oepd7PJc0lbHAZC6x9196vJTb9X7Kia8uMUzym0HJ9y4jWq8yz/pXvy4sc1bPk/A+fX3c1lWcuFObaQuueuulaCn5ecmPGdZexmTboIovz8oi1CJfvzMvnKGc00kj/IM8pJQYCn71+qJzTq1ktViAzEbk2yhK0epa4rngpJxdoKSIfzJAZtc5677KstSi1SfQKOb1DpjY48rJhtFi1mYak8Fy8/mXv1jJ+/Of5HnPD8/MXPD09o/mGczjMOgxABD2ZuaHZUooR9D5pfD4vJUggoqy/zb0tq08ZMi1PLsNmEkvdWN26X26cymRTKOZD6NkN5TWWsvKtUjfH2hMDIrmmbkRIc22M3x8XA11GwrwUbHkc+UvuibdSU8pPff0S5Q97cjUQFOpSlvK2OeXDQeeatZD17NPbXvc2l2Er762fZyGIuj4+XGbKcmYgYsAQcIvljMpIyGCnFLTW2Vpbxg2QzFwdyF++PvGcXs4My3PmtIlLXC4f0VJYPXReBKasFPS0tQjTOn282czEGENwqywUr1ufWyu/nC83c4FX3ntdfe70/GxZ2ZgWvQGgct72OxVL1p8wmvCGAty58OaIBGIKUG1mKRs9bm2lm0s5Lz7pcl/lGYWcpse9epbMFOxdhk7Ls7yE+VSWVXCSF8G5bJzCCT5rGUpIoMv7lKu1+VzmJQ25vhwmr7tCifr68ro2w53rvvNNEYT+SFsKdFHo719LGnMafDefRimkWFYCYolIv6xZIZt5h7h6RRpQobgYVD8vS177lhMZfHCu+n6GPFf//OPXp8rJ35fV/7Bx18Uq6y6Bo5EyINt6z8UrcoMogHZZknplMp6MCNgIYIwJcW3+b152KaIVjOCCXJWT1y1vdLFqE5RcnhHA0/0JvXccZ8CbT+DS3GER8jCAtwb3hkhgSMnNgDEqDmtoPTDOgTgrpnJ9/cdnZ6zNf2/NMcZVQVx3cFHMWNb7u83jdziWoGgfqRMXt2JaFzMADWZB1FTrYwZr8vAY0zzDyoPZFGCGKVkW8sM+ZS45Wt6QXMDUKKx9ywDCLjBSSKMM1XrSjypb8aW7wxVfZwRGVIxI7+YJWJM3DwAIGYy6lQX1pXPT89Ue5od70lpO57UcznQAmJH67+zZx9ePYW0uXFzewmoNL+7/O1Tyl8atYhBzKRJJIj5gk0UvKLAeKSSAOQYwBnKklJMW0bXRmQvmXMwYb+UDLNEzTa9tgp7OuC0ApCPhcG/ovSOTcBWNVtJBuJcArWhLuOUkPCwCTQZiZABB5UUEImxuTBmpBRLtctsU2IoppZ0LgmvtI0lqXD9anmGSX2YT8TAOKmGj8csP60ZlXvG+/HxdCwAcKBh9hYUhY1x7c/3cFIoSltLMeY16X36UIbvEi1kScjWqevu83sWYGAlFb9zftFC4IcchFFe8AJCIcO7lXI8P4qtb4r74/LnLaNh6o9lCIXWt6U0vEPoTBf2hcnYcyBAjlkBzp7UAWVP7bgMS8i5WHsBxaDO4QC6kIIgIQ8InnDStgAt20KKRMURW1ofXiIQYWm5RiAnm9cSYFqzWgnVLbJZo3hAjMc4BeMK8cQlD3++OfdsRkXg/Hgg0RJ5wA27bDsuOxIlsjjTDSMahZ/A7vXOdwsimzttsBmyGMQLHOLBtOxD0RJGAC8KlGdId1hqylMRScV9OGYgIuC/mb7KF7oLZjjhjWvAQmogchH3NltCZzTU2kR4yqVTiQUPtWuMEnzmkXlQLoaCCtg6kLT8RLi9jLqMbKAKvjGyK8XcTGVgGQoim4mKXsaeGFkGXgrNN698RF6vF548Lw7uY2ARj7kvgMI2hFwrLhLmj+Q4AOGLIOy9jPdXNKh6mx689sIz5jJ/Vtf9QOSko+ibTXZaHmO8wTGJAC1xWInAhI+ZNFv1XGL/iwQUnyi7jAt8s8hIiXb+fv0dUcNYmrNH6837d4aLIC0oXFBtSELcObx2wht43tN7pvdJh3uCtwxU70kjJiraKR8nMjkegb23FPYhJz7uM1IgTFUeVR5gxc23l1ZliOR5MoTE4XApQ66IryGjN/Rc5RAjpXA95Zu7WhVHNJaLzvuxCoFxvKBesnHE7CpwU/F2QltDuFMG3EAQmGuPrCtc/vufjGlj9txDGhbtYCpDy7kqhxJCjHXD4fFYqWcnnUrSZEiokgiLNnKkYdxJ8FhdPX86Bz3mF7Vb7Ydfd/cvXj2POUgKrBVo5LK8FvyxebURkIpU+8K2LDCnrqM/DEUUnZywPh6SHDSpcxImMIebZlree/8MH0olJBcFeYEKm2jcvpteA5sxLjlHxDWFO327Y73e0tsGsIZP52m4+c2LmhhgXw9GkLG44x8BxDHhTPHXJQbrTQw0jo0urb3PjJurQGvn3cjvXfHm2+cH6u95bilL7h6loAMyp2NPAAlc2tPxJQcj6u5stUk/rTtUmnMm4CGFOYHMxtyvUWK/88Gcqe
XnZi+KhYDD/3b+D0B+umInM84P3nWkQKSt5DTmQ75xOwCSXF1swF4+fh5dOkPuFu4xwfebCddRnLvL42esT5QQmiVM3V3da3rBuJKGEK5bnzERLfMh9IUCriZjxK68Sc49SGz3OA+M8gRhoaTC06U3cFlVPuDFB2AVm61avAm2EaAERBq0jkRgIQc+G29Mzbvdn9G0HvMEip7Lwnul1Ho935GAc17HBWke3jkTgOE4ggNEdJja1OBh3R986Y+nymvKgix7D9Lw+47/8KNkEI/O9tfFkYgtm5mVd+J7lZ+uSCVgsCFjaBBrNnEp62eMSUiyHhsvaSwyosPWjXPdBb/8dmbI+OX/+0Qxf3hGJcCnofJ5LYUQmIugtZ365DJ+ELpKIysymks48d91/5kIY5f11h3WdimO5LMtgEY0AEKv7AQUt9fmrr8/Z2mmx+J9hXHB+qc9/W7kriltZmREx4V/B3mDOgV7Bmm5Uiq6FzBxIeU6fUEYkjpk8iqC0XwRw7RQXuMz3FTLpHnifiWAABm8
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The novelty with gradCAM is that we can use it on any layer, here the output of the second to last ResNet group:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with HookBwd(learn.model[0][-2]) as hookg:\n",
" with Hook(learn.model[0][-2]) as hook:\n",
" output = learn.model.eval()(x.cuda())\n",
" act = hook.stored\n",
" output[0,cls].backward()\n",
" grad = hookg.stored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and we can now view the activation map for this layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9y3YkSZKm94momrkDiIjsmuaweYbccLbccMNn4nPwNbjkk3HJ09Mz01UZGQF3UxXhQkTUDDlVWc2oBWuRfg4CCMAvZqpy+eWXi4q78/vj98fvj7+/h/7/fQG/P35//P7484/flfP3x++Pv9PH78r5++P3x9/p43fl/P3x++Pv9PG7cv7++P3xd/rov/XH/+1/+V/d3REBEOpfUaWpIrrR9c4mdzZ/YbOdzTZ2OrfWuHXlrkoX2NXZmnBrsIvTMA4DQxgGwwUDpgvTPb4jjGkcx5PjefA8BscYHHMwbGJzYm55tY6ZYe44xmSCgAsgnhdO/m5y8xfuvrOjbKLsCrem7E24aWNXQc35D//0T/xP/8M/8eX1E7dm3JpzazvPw3g/jDkF2g79Bq3jNA53jukc4+A5BtMmwyZjHAwbmIMDQ5zp4CKgAgIGGMb3x4Off/mZMSfTjGGG4zhg7kx3bE5QwXxiZkyzeGMBR+L57iCKE78TUZo2VDsqsSCSgqDEZagIGtuNYKgIW+uI5BLm31QAJz7D4xNUlabxvvF7QBTN30l+phP3cBEt/PKv5gVY3ROg+YPUfbmv/RcFEY1rzPfrHmsl7iASa+JgboxpuDvSGoLiEtc7zZlmxLIJIi1e6/GZIpXdMMQdl7z/vAhH8woVJy9GBHUHM6bFXrnFbv7r//3v/8//6//5P/73/8/KiTSEVE4nhSM2wc81zQWLhY8vUM3Ny/8jH5+PO7V3S/FFwFkbKLnBPg2bxhyDOQczldLdUijOBRKpX8VFS/1SwMVzkeP5Tl2foBKCK0jKU+e2d94+feF2f813aojU58Z3RzGfMEdsipYQOKoNbc60yXmzAm5Mc6w5ZsRi5f56rnWq0ilpeBqeeGgusktdMYjkvsiHlc57l8ta1J7E7yQVyL0+I3a3SQinpNFo56XltV42Na9DVGrjEa/9VFR17Ugpbcl5Keh5p6XUeZ8S8iLIWsdYG0dT0SVlar1Z7jN+KrKvlYr3NU65vgjmuqaSaURjbd1j3+v68ru7pdyEY2AZgdhbR2JdluWQ2Gd+tU+/evymcoroZfEdN1vXH1aZtYgiceu1N8vCpkCdShGvtatypnKYh9BILT6kN0xFcMMtLB75vZTyusAiggK2hMfXdX8wMuvaz0XK9UcQXl7eeH15o2mPrdLYLPeJe3gUVwVRJmA2cQ8UMI0U2BAIsxL53CAv60lYdi/BqftIQ5d7WIJWxqSUcO0Bl81elvPcCM3XcMpHaeXFAAjmji5vrpz6l69P5U1VTD3U9MIWcpD3Xdeg0vLzl+VMBbua5uu9nwJfclAKWptYeyakEpyCedlvOa/9YnR0KUnI4dL5FFDxq+tJZ1Se2iyMMU6TfH18VMqoLNRilvsjYLR0HBJr+294/BXPqefCmC33Lgsj/urp1IKnjcoF8T9zLeFh/bROXhbyfI7nYrgtrV4CEoZTcuHOy1neYb3Z6enJK0usW7+IzwGmGU0E1wYIb69v7PsNT+9WqiYhWWjCHtOARodJGp24oDJEHz0gITSiOCOUUk5BrZtQlfA2dhq6JajLYjtmp6FbxkfyjkVQFC/HvPSiINpH7LP2BUHS6EgaL9X0DdNqA1kGJJ+XF7WQT7xreR9ZntI/3I8vA+Cc++nrf/lWFmvkCYWvqhNP8cs9JnpbaCiMY8Hb+LuHE3AJ5fZT9z/Id0JnJ1yhuQXUxlFlOZ+PBozlQ0mFNfUPv7+q/196/BXlPK2TpwIt+JoL8FGb0kNiYTU03+OinfU6QTKu0BLfD29V1srM8muGZ8qFWR5l+aZ8g3S9WnECnN5Vfv0ZgFy3Ov4vCcvuLy/01hlzsulGLL6F0sjMpwu9NUQ6bgIWggDKMeM52hrdHbfJDKNL04bY/KCza4UEVJTWFEZdlqLiGB4eN1GEzfzduVPrfVQUUYm/uwf6cVuW/eQTLvdOeDnVi1KpxLX4VWVy/36FOgJal/cspFB/D8VTZNni5Rvl+s6hCH55YzdQLoZhQdj/VsiXQZHgRlQU84jJy8uaC+JO0/g+zRc7autdTwN7vZZ6qEqEJb/y5GsvvAwBp4fm49dvPX4b1q5AtyDrAmacVvKCkOo2yiz4xUnlJkfgHtBC84+WUPLyDgtmmBlzTMYIYqRIn6tj/ACOLpClYFqQASHU/OraVYSmSlehi9A1YZgqW99wwKZh26kEoiexgFvA+Lo/yPhKOSzcnjZFachUjIGVt1C5uIZT0sJzlfeSdVtrd8XX5xdNVDGPpHEKbxaeoxa3wghLYmO9smJV14/EygfvF7+7+LN1ybb0JDz1Kail4Lnm67uk18odLxj8K3E9w5a4cMMTPi88cu45l/dOmK3S0NbCSLmFz7LYRynZbS0MgRuGnc7DS0lPcYqt0hTnXGMtl1vXUCghjNeH9fLznv4tZbO/qZzdn4skwD2skCqOM4chLQJlzXVXVxoNlU6TRtcG6rgKNMVVsDQoJuAohoTwG0yCILkG8nPOIE9ccQ+4GTHfSRB4KauwLN1AkvktwCOoOHvixDYNGwOIjVwCiNBU2LcdM+PxPNibcPigO9x7A+m4WFKbGuwyk+eEYWAtGD4Tw9vFUva4+TI2bdthGuaXjfVEKaJoa6FMOCZFN9ipFGKozPSMsgQzwhHFXLDhGAnGJYRjETGaHykSaiWpnClIFDSEBWcrXgu9Kk9U65tMpSWCUcKQpPK72vpgM2Ne4kqReJ3ZxdhkXF4aIqoLCYnm54mXJadII5VG04Zpx+UkombuiZnF+6QHdkoeZTmXcirisc3TC8oqoh0Hhlkoa5MPIcbSUSTXyE9Hl3wJ/wYF/U3l9KulDDcDGoupcFr2slbaUGso8fuydFIQcxnBeM+Aq+nyyytAmtSySCcRhBdaX1dYfjYXPL2kOkgtpuApcUrQ8tqERqN5XNeYRkOQ3mmto63Ttz1+TlKkaaM1WcSHqqKEctIajiLp2Y9paD8RQnl8cDStuttJbi2vmGsuCfFOMoYTxl2clqTAnuhh/SXjnfMPUuZfNG2KJkJIb58GSstzL2iankSCFz2vtCTQT47nYigFcJczrl2EoAMjBNlKmeoeinVNVHJe/ILJZQsWl/UrbCvJDdT1FLu6UFiue9AYFuuXD1W9XGfJF4slLzTigFt4Tm1t7cGUmV74/Drl2z94+tq/33r8pnI2vS/LInnjmjfjYgidLjc6Nzo7nZ1NOxuaeR3oe2NTYVPoEhi/GN1gBR2xk0BSYC5cG7lMbAZ8zM1pEj9V3ilSG3LCEBeEHnYw3YQnjO5M1JROZ9ONTRsttiY2uzW2243b7Ubv2yJCVDQ+VwC3gLYJhyJ2DI/blEj1j
HiOxh3RmYEyutBQfJSohfUtT+OllSoMEXoTbIb0BYiFis8svX2ti+d9OHGvM12AyoUhTOPaRWnS0IS1rUINHEl4pxKhTNJK8bOcCre2aa27J3UqZ2RBOreLXNVrVsTDx3Rd7G1GflLG4fyM2FMDP/OaxRyfYbDjcxK4oRSVZPFD5sBpudyoRjws5TgiB23utILTlwsx8cz3C9o0PXwmdjxypdebX2u07inep/2Ggv42rJXbhxqi801Ds0o5GxvdN7o2Oo3gLg0hFKppxHJdWZtfimgSQhAQp6AJYYHmAXaAG62stoTXLkhi7phkcJ5xTdl3KziVhEJjoxMLKS70FnFlk9p0QXvj/vLC/eWVvm/01hGbNAFF0VRMFcHmEdftIB26NmRv8DQexxGK0XLDxBA1uju9C3LrjDEZ6qelChcGCkNgqrA3xVTB5oKQRXS1lPTiL0USJSRdr4Bp5tTSABRMLTZYU3hUIsjIj8+fJX+eK4Qo6K+JjPyD2pUAn97Vz4AULbWX+LxSnDO2DXiJexZ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What is a hook in PyTorch?\n",
"1. Which layer does CAM use the outputs of?\n",
"1. Why does CAM require a hook?\n",
"1. Look at the source code of `ActivationStats` class and see how it uses hooks.\n",
"1. Write a hook that stores the activation of a given layer in a model (without peaking, if possible).\n",
"1. Why do we call `eval` before getting the activations? Why do we use `no_grad`?\n",
"1. Use `torch.einsum` to compute the \"dog\" or \"cat\" score of each of the locations in the last activation of the body of the model.\n",
"1. How do you check which orders the categories are in (i.e. the correspondence of index->category)?\n",
"1. Why are we using `decode` when displaying the input image?\n",
"1. What is a \"context manager\"? What special methods need to be defined to create one?\n",
"1. Why can't we use plain CAM for the inner layers of a network?\n",
"1. Why do we need to hook the backward pass in order to do GradCAM?\n",
"1. Why can't we call `output.backward()` when `output` is a rank-2 tensor of output activations per image per class?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Try removing `keepdim` and see what happens. Look up this parameter in the PyTorch docs. Why do we need it in this notebook?\n",
"1. Create a notebook like this one, but for NLP, and use it to find which words in a movie review are most significant in assessing sentiment of a particular movie review."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}