fastbook/clean/18_CAM.ipynb

485 lines
270 KiB
Plaintext
Raw Normal View History

2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN interpretation with CAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CAM and hooks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
2020-04-15 13:05:34 +00:00
" <td>0.141987</td>\n",
" <td>0.018823</td>\n",
" <td>0.007442</td>\n",
" <td>00:16</td>\n",
2020-03-06 18:19:03 +00:00
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
2020-04-15 13:05:34 +00:00
" <td>0.050934</td>\n",
" <td>0.015366</td>\n",
" <td>0.006766</td>\n",
" <td>00:21</td>\n",
2020-03-06 18:19:03 +00:00
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"path = untar_data(URLs.PETS)/'images'\n",
"def is_cat(x):\n",
"    \"Label an image as a cat when its file name starts with an uppercase letter.\"\n",
"    return x[0].isupper()\n",
"dls = ImageDataLoaders.from_name_func(\n",
2020-04-15 13:05:34 +00:00
" path, get_image_files(path), valid_pct=0.2, seed=21,\n",
2020-03-06 18:19:03 +00:00
" label_func=is_cat, item_tfms=Resize(224))\n",
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = PILImage.create('images/chapter1_cat_example.jpg')\n",
"x, = first(dls.test_dl([img]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hook():\n",
"    \"Forward-hook holder: keeps a detached copy of the hooked module's output in `self.stored`.\"\n",
"    def hook_func(self, m, i, o):\n",
"        self.stored = o.detach().clone()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook_output = Hook()\n",
"hook = learn.model[0].register_forward_hook(hook_output.hook_func)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): output = learn.model.eval()(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act = hook_output.stored[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2020-04-15 13:05:34 +00:00
"tensor([[7.3566e-07, 1.0000e+00]], device='cuda:0')"
2020-03-06 18:19:03 +00:00
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.softmax(output, dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [False,True]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 224, 224])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 7, 7])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)\n",
"cam_map.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
2020-04-15 13:05:34 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9S5IkSZIl9phFVM3dI7KquwhNmKFZAAQCroMLzAFwIyxwB+xxpsag0V2Z4eFmKsKMBT9mEYusjKjJWaAXaUme7uFuH1UR/jx+/BFxd/zx+OPxx+Pf30P//76APx5/PP54/O3HH8r5x+OPx7/Txx/K+cfjj8e/08cfyvnH44/Hv9PHH8r5x+OPx7/TR//eH//P//U/uzsgAP8XP6gKVBXJ9DoAd4e71e8EAhEBROAeT3IIRAB3wMxgDoR98O1v/CARAAqzCR8DNgfmGBjD4GYwc9g0wOK1cPDzwfd1QNY1iEhclch2zcbrjmsCAIMCUDgUjwv47/7pP+Cf/uk/4Pb657hWVzQ5MIbhGhMOQNqB3k9o74AK33Nizgs2LwAD4iN+tgsChyogCpgbWgOaClpD/B6G6/rA1y8/w9wATAAGEYdKXL2bwX1CVeBwwCfcJ2JFBRCFeNwnoNwHgYjGh6DBuTexZ7LtHWCIjZruUBGoNuSyxnJLPC8WEuYG95CLXGvjfgjiMwVSwuTw9XkigCxZivWJPTV3SAjPEkGP15ob1yfWUkSg/GyIQx2AW4oCRITvZTCLvW9N0ZTrAod5rKUAEBV01SWO7mgSFyJiUFj9Pn6n8T4iIZNcd4FAEfphc2Lys+HA//Tz5//jf/m//vf/7W/p33eVE9LiZlKjuEjfZl9SeQUCy98JbygVB+tF+Xp3h8gSiF15kN/hochzYo6JOSwWl8rIS4JsFyUAFyffKt7XnYYCoZDghvGT+WnxPJWO2+3Ap88/4fbyVoItQqPEr9h/w7CBNgVAq+VSVQgUcwJuqSkK+IRN43Ua4CFQuWEhqEsa8/fuTiHMe1KoCBxGoxYCJiWgIei80fiutTFxnbliInADjV1cUgizcl1yTb3WPC4k15yCqGsPxZTXSqHl1cTa8b1qz2T/hNonlbWPy4p6fbRS0SUvkmsnHmsuaQQ8rzeeuxQ1L8PjtfXc5WDiKy/KkEogkjLgEBe4UB73/RRdDmLfhyX0v/n4rnKGFaiVCUsIDxHwtbRpJ1y8cHIqpxk9U21qSG4Kq8O5ebJZnFgYR77e4Pk9vbOBFvTZWJRByL36dgXyNZLWeP09FJOGUAVvb5/w9voJrXVcU6HS4rp4LaKCJgKjAE2b3LzwPWF06QEsBEa4SWH5QzLCaIQQUS6WUORGUyjcAeW+0B5/Yyzlmx9TUWmgsCmlLEOW62Bl7dIY7e8oT58n/KOIQvO1iLVDqTUVFlROd66QU4gda9e3PYI/rcOO0tJfqAhREiBucFmyIPWVFs5oMEG0Idv7+WY0fHtt/q2sJB2CQWRClQ5CqIQGGCyvHOaJCAS9pdFexvFH2vkDz6lINfE0MUsjd8O3ryyfJ3k/taBPHm57fXoC4Pkt3QGbk8qQFnBZtvXk/MUmj/v11PNkCTIFIzc0lcV8Ldzb2yec5wvMQtRElK+bdEIC0YamHS4K84BhDofKpjT0KJj7Padl/Y2llwwdEhpS1NJF8HPMDJv41LqJOK9Z4LL2Ir57CU0umlNbyyAXRMOy+EaDsim97ErMaymVUKn3lG3tU6I0N7+uZFNAswoRch890Vu+gEr6JAf5O4lraELPb/EJyn2DONyVHnB56ny9bO/nVHxa2QgnxGD0pvlcl7ULRpna/52XSt/7I8f5A+UsYYjllM0SSH2Uf3Mj9AZwwPUJuMQ3WlvT8CwVi3CTNqSRwmdmcJuMszyvrGAKICVrBenyKnLRIU+LBV8G4fkKV3z6+vqK3jquaZCmyDhLVWAWiqgq0B7Boh
kwNthlcwBwqCqAHjHSnFAArWncz/Nq01sEJG7aMOfYhE1WzEkUMc2oK47n24n70NYoKLsMb8Y2IE+667iGVM5NAVUVJis4CXShFEgpr18fQfe8diP3IXV7ueRngxryZEZ0JEYkmxCJu+DPe7e/V8JRFUVrRDtcL0FAZSdEVW0AkntQGjWp7/HRvsFaFLJQCSiboZ/s9/2EWpYsOjHQWo3ffnwf1sLW/vkzzFgWY3uk0+TFeVlnAQhtJDEVIV/GcL59Tnx4QF8zw5wTPgxzGox3pvk+LnTUpVr0+DQWu2Ly+jKKCmFXuIQRUcvYoUFV0XuHA5jT0DS8hoIeDVROCqWKQhqgvP6mwMMn4KEgYg63gMRGkuUbbVrrnvFkU+CqzdjchMOFXtMt9397Gr2VprcvVaTXlBXvp2Kmd9ZlnLgNeCJaGLfYZvo9ld1BSFs3km4M5fnqPhDKtyT5SYbqnxtiqmd9u2z5GSl4VEyRBtUO0Yjz5zR867ekNX6mVciZaxjr6s+LC4l1VY/3hUJyffdY3/d7zxtD3cvfUzb7XeVUn1gsbFjz8AIOmxMZx8hufREQRwFIU1wW7l4kGNDd16ZldqelRFpfL2s3zWCT1jMlw50wYcWI6eUNSfw4ZnoMIQ/ijq7Bzo05MecE3Bk7hOKk/zqPE2bA4/GAo2O4QQ1ox1mYIT2HucNsYFiwm601xndGVpaxXhOIK3xOPK4Lx3HAfaIcA9ZXeM8GSKxNekwjQygq8GkhJCXNUgoZRk8wBllxybWK63UB1GnHACwWd2GiQicOGoKFTBReED51IrwtBXHzHi4MS5LpBMJIBSdcX+7JQlvB/p1BTs9eniz/DS9PGga3QbVBpMGIqtwl5K9i/XiukVQsT8x1zluPj5VSqtCBBhHH9PC80oDiMhKpUS/caLisLNlCgD/Qzx94zgV9hAIeuinhGhJKQpbFdI+Q2FFCUe/k6cXibY2Ln8KwPpeWFgAy3WFWr9ufw6WA06VGjJU7vtg2Le8QXqs3LZgyRwTx0hqaHIB0HMeJ1jpEG8QF0rgpqnBzwiFAm0IamUrCVvOJcKNUYJ8k0wIST1fYzPvLm44LrIChYB9hEePE3WlESJj/rdcG7EtPWRihDLnC4RKGVlTX3z3pvCBr9lAhmHjiEMJb8e3Sub8l+JrKknsUgC50IVCRMxBTKLTuLC7SbG7ykJ4sCaW4MC0vxcWAky1O8inCfJDECRKSPwMADA1SBHbExl5f5SrMAVjJEwCYpxEgwikDZrUGblYyvu6En7XDnd94fFc5PW++3jQp8QIqtTs7sysentXMob3T8tB6pcfj8yBABUW0wqT+4Bb5TbOJZNvSQ6aBAz1iQSsymIDXtaZxESyBAgStNQCCQasHD2Xr5w3n7QWtExLNgHVNWgm4aHhAm5F/lJb5X8E0w3VdaLoUyy2FVtFcYSKbBd2UCqmn8ZkqUuymbNoc95LChEIwoWRpDJcHDD1XOsJYpxRieBheTxKrXEBqHRiioOB4pXCwxfLuwUpnyLEEpRDTir2eJY0vR8LfVNxUyrQA+bpifWVbs+WDqSRMV9EohyPIvGkoqnBPd+iuyTtnWqlkX7brNKKUxtCIhtPi9VaxHa/aHMWPFGT+/uMHbO2mSEjLtSDGU2yAsgu1hBGjUThEYLVwvn8EN8wW4KdXnSw8sGmhyL5EOITRIwaVjR4vS0pqfzl0JNlhMxZPW8QkhihoMBIot5dX3F5f0Y8TIi1Il1QWwndVwXzMgJYONBzQ1iJOHY553eGMs2OzI9YJDkyB3uEWsFp4gQIywEikEsouJrx3K42L3JrvC782nj8rEvtQoZAohqHG0qmlaNv2J1wMQyokwFaRQQZp5TnBNA9Qr4s0rmNtBUMcydgP5TVlmfuSjf37/rObEcZtz9v2GiDagpdyRl4Y5Szggfg0lWVbu538UZVN+RPxJF4BCxc2a1RruwgmV/uhMn77+EGec6kbrzOUiBsTcVUS+U9Sgo
In054S03AqghsXJxPVK7YNC2Sb10R5jRTaeLu4eU2DyesyeliVjEdQHiYuISx8CKmH0dAGbQfOlze8fPqMl7dP6Oc
2020-03-06 18:19:03 +00:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_dec = TensorImage(dls.train.decode((x,))[0][0])\n",
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
2020-04-15 13:05:34 +00:00
"ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
2020-03-06 18:19:03 +00:00
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hook():\n",
"    \"Context manager that registers a forward hook on `m` and keeps a detached copy of its output in `self.stored`; the hook is removed on exit.\"\n",
"    def __init__(self, m):\n",
"        self.hook = m.register_forward_hook(self.hook_func)\n",
"    def hook_func(self, m, i, o): self.stored = o.detach().clone()\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with Hook(learn.model[0]) as hook:\n",
" with torch.no_grad(): output = learn.model.eval()(x.cuda())\n",
" act = hook.stored"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient CAM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class HookBwd():\n",
"    \"Context manager that registers a backward hook on `m` and keeps a detached copy of the gradient w.r.t. its output in `self.stored`; the hook is removed on exit.\"\n",
"    def __init__(self, m):\n",
"        # `register_backward_hook` is deprecated in recent PyTorch; use the full hook when it exists\n",
"        register = getattr(m, 'register_full_backward_hook', None) or m.register_backward_hook\n",
"        self.hook = register(self.hook_func)\n",
"    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cls = 1  # index into dls.vocab, i.e. the `True` (cat) class\n",
"with HookBwd(learn.model[0]) as hookg:\n",
"    with Hook(learn.model[0]) as hook:\n",
"        output = learn.model.eval()(x)\n",
"        act = hook.stored\n",
"    output[0,cls].backward()\n",
"    grad = hookg.stored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9W5IkSZIkxiKqZu4RmVnVvTsDWsIfgAPgAiBcZ++E++AE+MX/0mIfMztVFQ83UxF8MIuqRXZ1Ru8MiDAf5d1ZmRnpbm6mKg8WlodaZuKP1x+vP17/+l7+//cN/PH64/XH6/dffyjnH68/Xv9KX38o5x+vP17/Sl9/KOcfrz9e/0pffyjnH68/Xv9KX/1H//i//a//eyITMJs/MxjMDa05MhMJAAkkAhGJYn/NDGaGZg6kAUgYHGYGpCETQCbMfF7XrQEGIPlZt4aMRI4Ddg7EGYgR/FkEciQs+RFLBzKBBCwNiZz3wO8weP0dfJ5EwAJI8DH5aqDNchyn4e/+/t/h7/7+3+F+/xkGB9Kx+w3HeeI8AwnAW8e23eC9w8wQSEQExnhgjAPIgYgDESciDlgmzIF0rZcbmhvMAS5H4PF4x8vLL4gIBAYyAzAtD4CMQETAGhAZXI+I2iSuM8Dr6/kTgJkjHTDnMw5uHtdCi5DgokYOZGgdm68vt9D1ku+OxMgBAHDXHsMoD1x5wLXu+jeuecmKw2x9twFw5/0isgRMT6C/ZiAyEXl9Zkd6vSfgabyv5DPw/nmvEYHIQGsNXv8GYAT/DUjKubuek9fkHgzAEqZn5j0bYD5ljT/hcwOGhpx7NsaYz/5//d/v/8d//G//57//Pf37oXLCGm9Cu1LKqGekQJsWJg1umA9i2geTRGXqjbX5aRKcXJtmXOWlVDYFL8ZADP6OkNBlwkKfnfdWm2DXB4FrkbKsyfwnCk29O/Ues47bbceXLz/hfn9GpkmInIqSCWQgwXs8x4lmgLUGGIXPvVGItBmZay1jULEzU+um26k1nRJXwljCwec1M7hTYTwTAUNKIee6axnMLtc0wJyfT6Oepd7PJc0lbHAZC6x9196vJTb9X7Kia8uMUzym0HJ9y4jWq8yz/pXvy4sc1bPk/A+fX3c1lWcuFObaQuueuulaCn5ecmPGdZexmTboIovz8oi1CJfvzMvnKGc00kj/IM8pJQYCn71+qJzTq1ktViAzEbk2yhK0epa4rngpJxdoKSIfzJAZtc5677KstSi1SfQKOb1DpjY48rJhtFi1mYak8Fy8/mXv1jJ+/Of5HnPD8/MXPD09o/mGczjMOgxABD2ZuaHZUooR9D5pfD4vJUggoqy/zb0tq08ZMi1PLsNmEkvdWN26X26cymRTKOZD6NkN5TWWsvKtUjfH2hMDIrmmbkRIc22M3x8XA11GwrwUbHkc+UvuibdSU8pPff0S5Q97cjUQFOpSlvK2OeXDQeeatZD17NPbXvc2l2Er762fZyGIuj4+XGbKcmYgYsAQcIvljMpIyGCnFLTW2Vpbxg2QzFwdyF++PvGcXs4My3PmtIlLXC4f0VJYPXReBKasFPS0tQjTOn282czEGENwqywUr1ufWyu/nC83c4FX3ntdfe70/GxZ2ZgWvQGgct72OxVL1p8wmvCGAty58OaIBGIKUG1mKRs9bm2lm0s5Lz7pcl/lGYWcpse9epbMFOxdhk7Ls7yE+VSWVXCSF8G5bJzCCT5rGUpIoMv7lKu1+VzmJQ25vhwmr7tCifr68ro2w53rvvNNEYT+SFsKdFHo719LGnMafDefRimkWFYCYolIv6xZIZt5h7h6RRpQobgYVD8vS177lhMZfHCu+n6GPFf//OPXp8rJ35fV/7Bx18Uq6y6Bo5EyINt6z8UrcoMogHZZknplMp6MCNgIYIwJcW3+b152KaIVjOCCXJWT1y1vdLFqE5RcnhHA0/0JvXccZ8CbT+DS3GER8j
CAtwb3hkhgSMnNgDEqDmtoPTDOgTgrpnJ9/cdnZ6zNf2/NMcZVQVx3cFHMWNb7u83jdziWoGgfqRMXt2JaFzMADWZB1FTrYwZr8vAY0zzDyoPZFGCGKVkW8sM+ZS45Wt6QXMDUKKx9ywDCLjBSSKMM1XrSjypb8aW7wxVfZwRGVIxI7+YJWJM3DwAIGYy6lQX1pXPT89Ue5od70lpO57UcznQAmJH67+zZx9ePYW0uXFzewmoNL+7/O1Tyl8atYhBzKRJJIj5gk0UvKLAeKSSAOQYwBnKklJMW0bXRmQvmXMwYb+UDLNEzTa9tgp7OuC0ApCPhcG/ovSOTcBWNVtJBuJcArWhLuOUkPCwCTQZiZABB5UUEImxuTBmpBRLtctsU2IoppZ0LgmvtI0lqXD9anmGSX2YT8TAOKmGj8csP60ZlXvG+/HxdCwAcKBh9hYUhY1x7c/3cFIoSltLMeY16X36UIbvEi1kScjWqevu83sWYGAlFb9zftFC4IcchFFe8AJCIcO7lXI8P4qtb4r74/LnLaNh6o9lCIXWt6U0vEPoTBf2hcnYcyBAjlkBzp7UAWVP7bgMS8i5WHsBxaDO4QC6kIIgIQ8InnDStgAt20KKRMURW1ofXiIQYWm5RiAnm9cSYFqzWgnVLbJZo3hAjMc4BeMK8cQlD3++OfdsRkXg/Hgg0RJ5wA27bDsuOxIlsjjTDSMahZ/A7vXOdwsimzttsBmyGMQLHOLBtOxD0RJGAC8KlGdId1hqylMRScV9OGYgIuC/mb7KF7oLZjjhjWvAQmogchH3NltCZzTU2kR4yqVTiQUPtWuMEnzmkXlQLoaCCtg6kLT8RLi9jLqMbKAKvjGyK8XcTGVgGQoim4mKXsaeGFkGXgrNN698RF6vF548Lw7uY2ARj7kvgMI2hFwrLhLmj+Q4AOGLIOy9jPdXNKh6mx689sIz5jJ/Vtf9QOSko+ibTXZaHmO8wTGJAC1xWInAhI+ZNFv1XGL/iwQUnyi7jAt8s8hIiXb+fv0dUcNYmrNH6837d4aLIC0oXFBtSELcObx2wht43tN7pvdJh3uCtwxU70kjJiraKR8nMjkegb23FPYhJz7uM1IgTFUeVR5gxc23l1ZliOR5MoTE4XApQ66IryGjN/Rc5RAjpXA95Zu7WhVHNJaLzvuxCoFxvKBesnHE7CpwU/F2QltDuFMG3EAQmGuPrCtc/vufjGlj9txDGhbtYCpDy7kqhxJCjHXD4fFYqWcnnUrSZEiokgiLNnKkYdxJ8FhdPX86Bz3mF7Vb7Ydfd/cvXj2POUgKrBVo5LK8FvyxebURkIpU+8K2LDCnrqM/DEUUnZywPh6SHDSpcxImMIebZlree/8MH0olJBcFeYEKm2jcvpteA5sxLjlHxDWFO327Y73e0tsGsIZP52m4+c2LmhhgXw9GkLG44x8BxDHhTPHXJQbrTQw0jo0urb3PjJurQGvn3cjvXfHm2+cH6u95bilL7h6loAMyp2NPAAlc2tPxJQcj6u5stUk/rTtUmnMm4CGFOYHMxtyvUWK/88GcqeXnZi+KhYDD/3b+D0B+umInM84P3nWkQKSt5DTmQ75xOwCSXF1swF4+fh5dOkPuFu4xwfebCddRnLvL42esT5QQmiVM3V3da3rBuJKGEK5bnzERLfMh9IUCriZjxK68Sc49SGz3OA+M8gRhoaTC06U3cFlVPuDFB2AVm61avAm2EaAERBq0jkRgIQc+G29Mzbvdn9G0HvMEip7Lwnul1Ho935GAc17HBWke3jkTgOE4ggNEdJja1OBh3R986Y+nymvKgix7D9Lw+47/8KNkEI/O9tfFkYgtm5mVd+J7lZ+uSCVgsCFjaBBrNnEp62eMSUiyHhsvaSwyosPWjXPdBb/8dmbI+OX/+0Qxf3hGJcCnofJ5LYUQmIugtZ365DJ+ELpKIysymks48d91/5kIY5f11h3WdimO5LMtgEY0AEKv7AQUt9fmrr8/Z2mmx+J
9hXHB+qc9/W7kriltZmREx4V/B3mDOgV7Bmm5Uiq6FzBxIeU6fUEYkjpk8iqC0XwRw7RQXuMz3FTLpHnifiWAABm8
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with HookBwd(learn.model[0][-2]) as hookg:\n",
"    with Hook(learn.model[0][-2]) as hook:\n",
"        output = learn.model.eval()(x)\n",
"        act = hook.stored\n",
"    output[0,cls].backward()\n",
"    grad = hookg.stored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9y3YkSZKm94momrkDiIjsmuaweYbccLbccMNn4nPwNbjkk3HJ09Mz01UZGQF3UxXhQkTUDDlVWc2oBWuRfg4CCMAvZqpy+eWXi4q78/vj98fvj7+/h/7/fQG/P35//P7484/flfP3x++Pv9PH78r5++P3x9/p43fl/P3x++Pv9PG7cv7++P3xd/rov/XH/+1/+V/d3REBEOpfUaWpIrrR9c4mdzZ/YbOdzTZ2OrfWuHXlrkoX2NXZmnBrsIvTMA4DQxgGwwUDpgvTPb4jjGkcx5PjefA8BscYHHMwbGJzYm55tY6ZYe44xmSCgAsgnhdO/m5y8xfuvrOjbKLsCrem7E24aWNXQc35D//0T/xP/8M/8eX1E7dm3JpzazvPw3g/jDkF2g79Bq3jNA53jukc4+A5BtMmwyZjHAwbmIMDQ5zp4CKgAgIGGMb3x4Off/mZMSfTjGGG4zhg7kx3bE5QwXxiZkyzeGMBR+L57iCKE78TUZo2VDsqsSCSgqDEZagIGtuNYKgIW+uI5BLm31QAJz7D4xNUlabxvvF7QBTN30l+phP3cBEt/PKv5gVY3ROg+YPUfbmv/RcFEY1rzPfrHmsl7iASa+JgboxpuDvSGoLiEtc7zZlmxLIJIi1e6/GZIpXdMMQdl7z/vAhH8woVJy9GBHUHM6bFXrnFbv7r//3v/8//6//5P/73/8/KiTSEVE4nhSM2wc81zQWLhY8vUM3Ny/8jH5+PO7V3S/FFwFkbKLnBPg2bxhyDOQczldLdUijOBRKpX8VFS/1SwMVzkeP5Tl2foBKCK0jKU+e2d94+feF2f813aojU58Z3RzGfMEdsipYQOKoNbc60yXmzAm5Mc6w5ZsRi5f56rnWq0ilpeBqeeGgusktdMYjkvsiHlc57l8ta1J7E7yQVyL0+I3a3SQinpNFo56XltV42Na9DVGrjEa/9VFR17Ugpbcl5Keh5p6XUeZ8S8iLIWsdYG0dT0SVlar1Z7jN+KrKvlYr3NU65vgjmuqaSaURjbd1j3+v68ru7pdyEY2AZgdhbR2JdluWQ2Gd+tU+/evymcoroZfEdN1vXH1aZtYgiceu1N8vCpkCdShGvtatypnKYh9BILT6kN0xFcMMtLB75vZTyusAiggK2hMfXdX8wMuvaz0XK9UcQXl7eeH15o2mPrdLYLPeJe3gUVwVRJmA2cQ8UMI0U2BAIsxL53CAv60lYdi/BqftIQ5d7WIJWxqSUcO0Bl81elvPcCM3XcMpHaeXFAAjmji5vrpz6l69P5U1VTD3U9MIWcpD3Xdeg0vLzl+VMBbua5uu9nwJfclAKWptYeyakEpyCedlvOa/9YnR0KUnI4dL5FFDxq+tJZ1Se2iyMMU6TfH18VMqoLNRilvsjYLR0HBJr+294/BXPqefCmC33Lgsj/urp1IKnjcoF8T9zLeFh/bROXhbyfI7nYrgtrV4CEoZTcuHOy1neYb3Z6enJK0usW7+IzwGmGU0E1wYIb69v7PsNT+9WqiYhWWjCHtOARodJGp24oDJEHz0gITSiOCOUUk5BrZtQlfA2dhq6JajLYjtmp6FbxkfyjkVQFC/HvPSiINpH7LP2BUHS6EgaL9X0DdNqA1kGJJ+XF7WQT7xreR9ZntI/3I8vA+Cc++nrf/lWFmvkCYWvqhNP8cs9JnpbaCiMY8Hb+LuHE3AJ5fZT9z/Id0JnJ1yhuQXUxlFlOZ+PBozlQ0mFNfUPv7+q/196/BXlPK2TpwIt+JoL8FGb0kNiYTU03+OinfU6QTKu0BLfD29V1s
rM8muGZ8qFWR5l+aZ8g3S9WnECnN5Vfv0ZgFy3Ov4vCcvuLy/01hlzsulGLL6F0sjMpwu9NUQ6bgIWggDKMeM52hrdHbfJDKNL04bY/KCza4UEVJTWFEZdlqLiGB4eN1GEzfzduVPrfVQUUYm/uwf6cVuW/eQTLvdOeDnVi1KpxLX4VWVy/36FOgJal/cspFB/D8VTZNni5Rvl+s6hCH55YzdQLoZhQdj/VsiXQZHgRlQU84jJy8uaC+JO0/g+zRc7autdTwN7vZZ6qEqEJb/y5GsvvAwBp4fm49dvPX4b1q5AtyDrAmacVvKCkOo2yiz4xUnlJkfgHtBC84+WUPLyDgtmmBlzTMYIYqRIn6tj/ACOLpClYFqQASHU/OraVYSmSlehi9A1YZgqW99wwKZh26kEoiexgFvA+Lo/yPhKOSzcnjZFachUjIGVt1C5uIZT0sJzlfeSdVtrd8XX5xdNVDGPpHEKbxaeoxa3wghLYmO9smJV14/EygfvF7+7+LN1ybb0JDz1Kail4Lnm67uk18odLxj8K3E9w5a4cMMTPi88cu45l/dOmK3S0NbCSLmFz7LYRynZbS0MgRuGnc7DS0lPcYqt0hTnXGMtl1vXUCghjNeH9fLznv4tZbO/qZzdn4skwD2skCqOM4chLQJlzXVXVxoNlU6TRtcG6rgKNMVVsDQoJuAohoTwG0yCILkG8nPOIE9ccQ+4GTHfSRB4KauwLN1AkvktwCOoOHvixDYNGwOIjVwCiNBU2LcdM+PxPNibcPigO9x7A+m4WFKbGuwyk+eEYWAtGD4Tw9vFUva4+TI2bdthGuaXjfVEKaJoa6FMOCZFN9ipFGKozPSMsgQzwhHFXLDhGAnGJYRjETGaHykSaiWpnClIFDSEBWcrXgu9Kk9U65tMpSWCUcKQpPK72vpgM2Ne4kqReJ3ZxdhkXF4aIqoLCYnm54mXJadII5VG04Zpx+UkombuiZnF+6QHdkoeZTmXcirisc3TC8oqoh0Hhlkoa5MPIcbSUSTXyE9Hl3wJ/wYF/U3l9KulDDcDGoupcFr2slbaUGso8fuydFIQcxnBeM+Aq+nyyytAmtSySCcRhBdaX1dYfjYXPL2kOkgtpuApcUrQ8tqERqN5XNeYRkOQ3mmto63Ttz1+TlKkaaM1WcSHqqKEctIajiLp2Y9paD8RQnl8cDStuttJbi2vmGsuCfFOMoYTxl2clqTAnuhh/SXjnfMPUuZfNG2KJkJIb58GSstzL2iankSCFz2vtCTQT47nYigFcJczrl2EoAMjBNlKmeoeinVNVHJe/ILJZQsWl/UrbCvJDdT1FLu6UFiue9AYFuuXD1W9XGfJF4slLzTigFt4Tm1t7cGUmV74/Drl2z94+tq/33r8pnI2vS/LInnjmjfjYgidLjc6Nzo7nZ1NOxuaeR3oe2NTYVPoEhi/GN1gBR2xk0BSYC5cG7lMbAZ8zM1pEj9V3ilSG3LCEBeEHnYw3YQnjO5M1JROZ9ONTRsttiY2uzW2243b7Ubv2yJCVDQ+VwC3gLYJhyJ2DI/blEj1jHiOxh3RmYEyutBQfJSohfUtT+OllSoMEXoTbIb0BYiFis8svX2ti+d9OHGvM12AyoUhTOPaRWnS0IS1rUINHEl4pxKhTNJK8bOcCre2aa27J3UqZ2RBOreLXNVrVsTDx3Rd7G1GflLG4fyM2FMDP/OaxRyfYbDjcxK4oRSVZPFD5sBpudyoRjws5TgiB23utILTlwsx8cz3C9o0PXwmdjxypdebX2u07inep/2Ggv42rJXbhxqi801Ds0o5GxvdN7o2Oo3gLg0hFKppxHJdWZtfimgSQhAQp6AJYYHmAXaAG62stoTXLkhi7phkcJ5xTdl3KziVhEJjoxMLKS70FnFlk9p0QXvj/vLC/eWVvm/01hGbNAFF0VRMFcHmEdftIB26NmRv8DQexxGK0XLDxBA1uju9C3LrjDEZ6qelChcGCkNgqrA3xVTB5oKQRX
S1lPTiL0USJSRdr4Bp5tTSABRMLTZYU3hUIsjIj8+fJX+eK4Qo6K+JjPyD2pUAn97Vz4AULbWX+LxSnDO2DXiJexZ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
2020-04-23 18:24:16 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What is a hook in PyTorch?\n",
"1. Which layer does CAM use the outputs of?\n",
"1. Why does CAM require a hook?\n",
"1. Look at the source code of the `ActivationStats` class and see how it uses hooks.\n",
"1. Write a hook that stores the activation of a given layer in a model (without peeking, if possible).\n",
"1. Why do we call `eval` before getting the activations? Why do we use `no_grad`?\n",
"1. Use `torch.einsum` to compute the \"dog\" or \"cat\" score of each of the locations in the last activation of the body of the model.\n",
"1. How do you check which order the categories are in (i.e., the correspondence of index->category)?\n",
"1. Why are we using `decode` when displaying the input image?\n",
"1. What is a \"context manager\"? What special methods need to be defined to create one?\n",
"1. Why can't we use plain CAM for the inner layers of a network?\n",
"1. Why do we need to hook the backward pass in order to do GradCAM?\n",
"1. Why can't we call `output.backward()` when `output` is a rank-2 tensor of output activations per image per class?"
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Try removing `keepdim` and see what happens. Look up this parameter in the PyTorch docs. Why do we need it in this notebook?\n",
"1. Create a notebook like this one, but for NLP, and use it to find which words in a movie review are most significant in assessing its sentiment."
]
},
2020-03-06 18:19:03 +00:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}