{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from fastbook import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN Interpretation with CAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CAM and Hooks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.145994</td>\n",
" <td>0.019272</td>\n",
" <td>0.006089</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.053405</td>\n",
" <td>0.052540</td>\n",
" <td>0.010825</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"path = untar_data(URLs.PETS)/'images'\n",
"def is_cat(x): return x[0].isupper()\n",
"dls = ImageDataLoaders.from_name_func(\n",
" path, get_image_files(path), valid_pct=0.2, seed=21,\n",
" label_func=is_cat, item_tfms=Resize(224))\n",
"learn = cnn_learner(dls, resnet34, metrics=error_rate)\n",
"learn.fine_tune(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = PILImage.create('images/chapter1_cat_example.jpg')\n",
"x, = first(dls.test_dl([img]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
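"class Hook():\n",
"    # PyTorch calls hook_func(module, input, output) right after the module's forward pass;\n",
"    # detach().clone() keeps a copy of the activations outside the autograd graph\n",
"    def hook_func(self, m, i, o): self.stored = o.detach().clone()"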
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook_output = Hook()\n",
"hook = learn.model[0].register_forward_hook(hook_output.hook_func)"
]
},
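{
"cell_type": "markdown",
"metadata": {},
"source": [
"A forward hook is a function that PyTorch calls with `(module, input, output)` every time the module's `forward` runs. The sketch below shows the same register / run / read / remove pattern as the cells above in plain PyTorch, using a small hypothetical `nn.Sequential` model as a stand-in for the pretrained ResNet; the layer sizes are arbitrary.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"# Hypothetical toy model: conv + ReLU as the 'body', pooling + linear as the 'head'\n",
"model = nn.Sequential(\n",
"    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # keeps the 8x8 spatial size\n",
"    nn.ReLU(),\n",
"    nn.AdaptiveAvgPool2d(1),\n",
"    nn.Flatten(),\n",
"    nn.Linear(8, 2),                            # 2 output classes\n",
")\n",
"\n",
"stored = {}\n",
"def save_output(m, inp, out):\n",
"    # Called by PyTorch right after m.forward; detach().clone() keeps a copy outside the graph\n",
"    stored['act'] = out.detach().clone()\n",
"\n",
"handle = model[1].register_forward_hook(save_output)  # hook the ReLU output\n",
"with torch.no_grad():\n",
"    preds = model.eval()(torch.randn(1, 3, 8, 8))\n",
"handle.remove()  # remove the hook once it is no longer needed\n",
"\n",
"print(stored['act'].shape)  # torch.Size([1, 8, 8, 8])\n",
"print(preds.shape)          # torch.Size([1, 2])\n",
"```"
]
},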
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): output = learn.model.eval()(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act = hook_output.stored[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0010, 0.9990]], device='cuda:0')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.softmax(output, dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [False,True]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 224, 224])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 7, 7])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)\n",
"cam_map.shape"
]
},
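{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `torch.einsum('ck,kij->cij', ...)`, `c` indexes classes, `k` channels and `i, j` spatial positions: for every class and location the channel index `k` is summed away, leaving one class score per location. The sketch below checks that reading against an explicit broadcast-and-sum on random tensors; the sizes (2 classes, 512 channels, a 7x7 grid) are assumptions chosen to resemble this model, and the values are arbitrary.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"weight = torch.randn(2, 512)  # (classes, channels), like a final Linear layer's weight\n",
"act = torch.randn(512, 7, 7)  # (channels, height, width) activation map\n",
"\n",
"# CAM via einsum: sum over channels k for every class c and location (i, j)\n",
"cam_einsum = torch.einsum('ck,kij->cij', weight, act)\n",
"\n",
"# Same computation with explicit broadcasting: (c, k, 1, 1) * (1, k, i, j), summed over k\n",
"cam_manual = (weight[:, :, None, None] * act[None]).sum(dim=1)\n",
"\n",
"print(cam_einsum.shape)                                   # torch.Size([2, 7, 7])\n",
"print(torch.allclose(cam_einsum, cam_manual, atol=1e-3))  # True\n",
"```"
]
},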
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_dec = TensorImage(dls.train.decode((x,))[0][0])\n",
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
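{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing `extent=(0,224,224,0)` and `interpolation='bilinear'` to `imshow` stretches the 7x7 map over the 224x224 image only at drawing time. If you want the upsampled heatmap as a tensor instead (for example, to save it), one option is `F.interpolate`. The sketch below uses a random stand-in for `cam_map[1]` and also min-max scales the result to [0, 1], an optional normalization that the plotting cell above does not perform.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"cam = torch.randn(7, 7)  # random stand-in for cam_map[1]\n",
"\n",
"# F.interpolate expects (batch, channels, height, width), so add two leading dims\n",
"up = F.interpolate(cam[None, None], size=(224, 224), mode='bilinear', align_corners=False)[0, 0]\n",
"\n",
"# Min-max scaling to [0, 1], e.g. to use the map as an alpha mask or save it as an image\n",
"up = (up - up.min()) / (up.max() - up.min())\n",
"\n",
"print(up.shape)                          # torch.Size([224, 224])\n",
"print(float(up.min()), float(up.max()))  # 0.0 1.0\n",
"```"
]
},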
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hook.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
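"class Hook():\n",
"    def __init__(self, m):\n",
"        # Register on creation and keep the handle so the hook can be removed later\n",
"        self.hook = m.register_forward_hook(self.hook_func)\n",
"    def hook_func(self, m, i, o): self.stored = o.detach().clone()\n",
"    # Context-manager protocol: `with Hook(m) as h:` gives h, and leaving the block removes the hook\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()"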
]
},
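{
"cell_type": "markdown",
"metadata": {},
"source": [
"A context manager is any object that defines `__enter__` and `__exit__`; Python calls `__exit__` when the `with` block ends, even if an exception is raised inside it, which is what guarantees the hook gets removed. The sketch below copies the `Hook` class above, adds a hypothetical `calls` counter, and attaches it to a toy `nn.Linear` layer to check that the hook stops firing once the block has ended.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class Hook():\n",
"    def __init__(self, m):\n",
"        self.calls = 0  # hypothetical counter, only for this demonstration\n",
"        self.hook = m.register_forward_hook(self.hook_func)\n",
"    def hook_func(self, m, i, o):\n",
"        self.calls += 1\n",
"        self.stored = o.detach().clone()\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()\n",
"\n",
"layer = nn.Linear(4, 3)  # hypothetical toy layer\n",
"with Hook(layer) as h:\n",
"    layer(torch.randn(2, 4))  # the hook fires here\n",
"print(h.stored.shape)         # torch.Size([2, 3])\n",
"\n",
"layer(torch.randn(2, 4))      # __exit__ removed the hook, so it does not fire again\n",
"print(h.calls)                # 1\n",
"```"
]
},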
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with Hook(learn.model[0]) as hook:\n",
"    # x already lives on the DataLoaders' device, so no .cuda() call is needed\n",
"    with torch.no_grad(): output = learn.model.eval()(x)\n",
"    act = hook.stored"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient CAM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
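"class HookBwd():\n",
"    def __init__(self, m):\n",
"        # register_backward_hook is deprecated in recent PyTorch versions in favor of register_full_backward_hook\n",
"        self.hook = m.register_backward_hook(self.hook_func)\n",
"    # go (grad_output) holds the gradients of the backpropagated value w.r.t. the module's output\n",
"    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()\n",
"    def __enter__(self, *args): return self\n",
"    def __exit__(self, *args): self.hook.remove()"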
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
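"cls = 1\n",
"with HookBwd(learn.model[0]) as hookg:\n",
"    with Hook(learn.model[0]) as hook:\n",
"        # No torch.no_grad() here: the backward pass below needs the autograd graph\n",
"        output = learn.model.eval()(x)\n",
"        act = hook.stored\n",
"    # Backpropagate from the (scalar) score of class `cls` for the first image\n",
"    output[0,cls].backward()\n",
"    grad = hookg.stored"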
]
},
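{
"cell_type": "markdown",
"metadata": {},
"source": [
"The backward hook stores the gradient of the chosen class score with respect to the hooked module's output. As a cross-check, the sketch below obtains the same kind of gradient without a backward hook, by keeping the activation in the autograd graph and calling `torch.autograd.grad` (a swapped-in technique, not what the cell above does), and then applies the Grad-CAM weighting on a hypothetical toy body/head split.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"# Hypothetical toy model split into a 'body' and a 'head'\n",
"body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())\n",
"head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))\n",
"\n",
"x = torch.randn(1, 3, 8, 8)\n",
"cls = 1\n",
"\n",
"act = body(x)       # not detached: act stays in the autograd graph\n",
"output = head(act)\n",
"# Gradient of the class-cls score w.r.t. the activations (what the backward hook stores)\n",
"grad, = torch.autograd.grad(output[0, cls], act)\n",
"\n",
"# Grad-CAM: one weight per channel (spatially averaged gradient), then a weighted sum of channels\n",
"w = grad[0].mean(dim=[1, 2], keepdim=True)  # shape (8, 1, 1)\n",
"cam_map = (w * act[0]).sum(0).detach()      # shape (8, 8)\n",
"print(cam_map.shape)  # torch.Size([8, 8])\n",
"```\n",
"\n",
"For this toy head (global average pooling followed by a single linear layer), each channel's averaged gradient equals that channel's class weight divided by the number of spatial positions (64 here), so Grad-CAM on the final layer reduces to CAM up to a constant factor. The fastai head used in this notebook has more layers, so the two maps need not coincide there."
]
},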
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
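{
"cell_type": "markdown",
"metadata": {},
"source": [
"`mean(dim=[1,2], keepdim=True)` averages each channel's gradient over the spatial grid and keeps the result shaped (channels, 1, 1), so it broadcasts against the (channels, height, width) activations; summing over dimension 0 then gives a weighted sum of the channels. The sketch below, on random stand-ins with assumed shapes, checks that this matches an `einsum` contraction over the channel index, the same kind of contraction used for plain CAM earlier.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"grad = torch.randn(1, 512, 7, 7)  # random stand-in for hookg.stored: (batch, channels, h, w)\n",
"act = torch.randn(1, 512, 7, 7)   # random stand-in for hook.stored\n",
"\n",
"w = grad[0].mean(dim=[1, 2], keepdim=True)  # (512, 1, 1): one weight per channel\n",
"cam_map = (w * act[0]).sum(0)               # (7, 7)\n",
"\n",
"# Equivalent contraction over the channel index c\n",
"cam_einsum = torch.einsum('c,cij->ij', w.flatten(), act[0])\n",
"print(w.shape, cam_map.shape)                          # torch.Size([512, 1, 1]) torch.Size([7, 7])\n",
"print(torch.allclose(cam_map, cam_einsum, atol=1e-4))  # True\n",
"```"
]
},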
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same as above, but hooking the second-to-last group of layers in the body\n",
"with HookBwd(learn.model[0][-2]) as hookg:\n",
"    with Hook(learn.model[0][-2]) as hook:\n",
"        output = learn.model.eval()(x)\n",
"        act = hook.stored\n",
"    output[0,cls].backward()\n",
"    grad = hookg.stored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w = grad[0].mean(dim=[1,2], keepdim=True)\n",
"cam_map = (w * act[0]).sum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_,ax = plt.subplots()\n",
"x_dec.show(ctx=ax)\n",
"ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),\n",
" interpolation='bilinear', cmap='magma');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What is a \"hook\" in PyTorch?\n",
"1. Which layer does CAM use the outputs of?\n",
"1. Why does CAM require a hook?\n",
"1. Look at the source code of the `ActivationStats` class and see how it uses hooks.\n",
"1. Write a hook that stores the activations of a given layer in a model (without peeking, if possible).\n",
"1. Why do we call `eval` before getting the activations? Why do we use `no_grad`?\n",
"1. Use `torch.einsum` to compute the \"dog\" or \"cat\" score of each of the locations in the last activation of the body of the model.\n",
"1. How do you check which order the categories are in (i.e., the correspondence of index->category)?\n",
"1. Why are we using `decode` when displaying the input image?\n",
"1. What is a \"context manager\"? What special methods need to be defined to create one?\n",
"1. Why can't we use plain CAM for the inner layers of a network?\n",
"1. Why do we need to register a hook on the backward pass in order to do Grad-CAM?\n",
"1. Why can't we call `output.backward()` when `output` is a rank-2 tensor of output activations per image per class?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Try removing `keepdim` and see what happens. Look up this parameter in the PyTorch docs. Why do we need it in this notebook?\n",
"1. Create a notebook like this one, but for NLP, and use it to find which words in a movie review are most significant in assessing that review's sentiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}