2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The training process"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's start with SGD"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_data(url, presize, resize):\n",
" path = untar_data(url)\n",
" return DataBlock(\n",
" blocks=(ImageBlock, CategoryBlock), get_items=get_image_files, \n",
" splitter=GrandparentSplitter(valid_name='val'),\n",
" get_y=parent_label, item_tfms=Resize(presize),\n",
" batch_tfms=[*aug_transforms(min_scale=0.5, size=resize),\n",
" Normalize.from_stats(*imagenet_stats)],\n",
" ).dataloaders(path, bs=128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = get_data(URLs.IMAGENETTE_160, 160, 128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_learner(**kwargs):\n",
" return cnn_learner(dls, resnet34, pretrained=False,\n",
" metrics=accuracy, **kwargs).to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.571932</td>\n",
" <td>2.685040</td>\n",
" <td>0.322548</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.904674</td>\n",
" <td>1.852589</td>\n",
" <td>0.437452</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.586909</td>\n",
" <td>1.374908</td>\n",
" <td>0.594904</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner()\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = get_learner(opt_func=SGD)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.017378008365631102, 3.019951861915615e-07)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.969412</td>\n",
" <td>2.214596</td>\n",
" <td>0.242038</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.442730</td>\n",
" <td>1.845950</td>\n",
" <td>0.362548</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.157159</td>\n",
" <td>1.741143</td>\n",
" <td>0.408917</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(3, 0.03, moms=(0,0,0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A generic optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sgd_cb(p, lr, **kwargs): p.data.add_(p.grad.data, alpha=-lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[sgd_cb])"
]
},
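Each callback passed to the generic `Optimizer` receives a parameter together with the hyperparameters and that parameter's current state, and any dict it returns is merged back into the state. The sketch below imitates that control flow in plain Python on scalar "parameters", so it runs without PyTorch or fastai. `ScalarParam`, `TinyOptimizer`, and `scalar_sgd_cb` are hypothetical names, and this is not fastai's actual `Optimizer` implementation.

```python
from dataclasses import dataclass, field


@dataclass
class ScalarParam:
    """Hypothetical stand-in for a tensor parameter: a value plus its gradient."""
    data: float
    grad: float = 0.0


@dataclass
class TinyOptimizer:
    """Minimal callback-driven optimizer sketch (not fastai's Optimizer class)."""
    params: list
    cbs: list
    hypers: dict = field(default_factory=dict)
    state: dict = field(default_factory=dict)

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            # Per-parameter state (e.g. running averages), keyed by object id.
            st = self.state.setdefault(id(p), {})
            for cb in self.cbs:
                # Each callback receives the current state plus the
                # hyperparameters; any dict it returns is merged into the state.
                st.update(cb(p, **{**st, **self.hypers}) or {})


def scalar_sgd_cb(p, lr, **kwargs):
    # The SGD step written with `-=` instead of an in-place `add_` call.
    p.data -= lr * p.grad


w = ScalarParam(data=1.0)
opt = TinyOptimizer([w], cbs=[scalar_sgd_cb], hypers={"lr": 0.1})
w.grad = 2.0   # pretend backward() produced this gradient
opt.step()     # w.data is now 1.0 - 0.1 * 2.0
```

Because `scalar_sgd_cb` returns nothing, it leaves the state untouched; a callback such as `average_grad` instead returns a dict like `{'grad_avg': ...}`, which later callbacks then receive as a keyword argument.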
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.730918</td>\n",
" <td>2.009971</td>\n",
" <td>0.332739</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.204893</td>\n",
" <td>1.747202</td>\n",
" <td>0.441529</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.875621</td>\n",
" <td>1.684515</td>\n",
" <td>0.445350</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit(3, 0.03)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"plt.scatter(x1,y1)\n",
"idx = x1.argsort()\n",
"beta,avg,res = 0.7,0,[]\n",
"for t, i in enumerate(idx):\n",
"    avg = beta * avg + (1-beta) * y1[i]\n",
"    res.append(avg/(1-beta**(t+1)))\n",
"plt.plot(x1[idx],np.array(res), color='red');"
]
},
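The cell above smooths a noisy curve with an exponentially weighted moving average and divides by `1 - beta**t` to correct for the average starting at 0, where `t` is the number of values averaged so far. The plain-Python sketch below (the `ema` helper is hypothetical) computes the average with and without that correction; on a constant series the corrected version recovers the constant from the very first step, while the uncorrected one starts near zero.

```python
def ema(values, beta, debias=True):
    """Exponentially weighted moving average of `values` (hypothetical helper)."""
    avg, out = 0.0, []
    for t, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v
        # Dividing by 1 - beta**t undoes the pull towards the 0.0 starting value.
        out.append(avg / (1 - beta ** t) if debias else avg)
    return out


raw = ema([1.0, 1.0, 1.0], beta=0.9, debias=False)   # about 0.1, 0.19, 0.271
fixed = ema([1.0, 1.0, 1.0], beta=0.9)               # about 1.0 at every step
```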
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"_,axs = plt.subplots(2,2, figsize=(12,8))\n",
"betas = [0.5,0.7,0.9,0.99]\n",
"idx = x1.argsort()\n",
"for beta,ax in zip(betas, axs.flatten()):\n",
"    ax.scatter(x1,y1)\n",
"    avg,res = 0,[]\n",
"    for i in idx:\n",
"        avg = beta * avg + (1-beta) * y1[i]\n",
"        res.append(avg)  # bias correction deliberately omitted here\n",
"    ax.plot(x1[idx],np.array(res), color='red');\n",
"    ax.set_title(f'beta={beta}')"
]
},
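The four panels trade smoothness for responsiveness: the larger `beta` is, the more past values the average remembers and the longer it takes to react to a change. One way to quantify that, sketched below with a hypothetical `steps_to_half` helper, is to feed an average that starts at 0 a constant stream of 1s and count the updates until it reaches 0.5. No bias correction is applied, matching the uncorrected averages plotted in that cell.

```python
def steps_to_half(beta):
    """Updates needed for an EMA starting at 0 to reach 0.5 on a stream of 1s."""
    avg, n = 0.0, 0
    while avg < 0.5:
        avg = beta * avg + (1 - beta) * 1.0
        n += 1
    return n


lags = {beta: steps_to_half(beta) for beta in (0.5, 0.7, 0.9, 0.99)}
print(lags)  # {0.5: 1, 0.7: 2, 0.9: 7, 0.99: 69}
```

For `beta` close to 1 the count is roughly `ln(2) / (1 - beta)`, which is why `beta=0.99` lags so visibly behind the data.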
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_grad(p, mom, grad_avg=None, **kwargs):\n",
" if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)\n",
" return {'grad_avg': grad_avg*mom + p.grad.data}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def momentum_step(p, lr, grad_avg, **kwargs): p.data.add_(grad_avg, alpha=-lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[average_grad,momentum_step], mom=0.9)"
]
},
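`average_grad` keeps an undamped running sum, `grad_avg = mom*grad_avg + grad`, and `momentum_step` moves the parameter along it. The sketch below applies the same two updates in plain Python to the one-dimensional quadratic f(w) = (w - 3)**2; the function, learning rate, and tolerance are illustrative assumptions, and `steps_to_converge` is a hypothetical helper. With `mom=0` the running sum is just the current gradient, so the same code also gives plain SGD for comparison.

```python
def steps_to_converge(lr, mom, tol=1e-6, max_steps=10_000):
    """Count updates until |w - 3| < tol when minimizing f(w) = (w - 3)**2."""
    w, grad_avg = 0.0, 0.0
    for step in range(1, max_steps + 1):
        grad = 2 * (w - 3)                  # f'(w)
        grad_avg = grad_avg * mom + grad    # same update as average_grad
        w -= lr * grad_avg                  # same update as momentum_step
        if abs(w - 3) < tol:
            return step
    return max_steps


sgd_steps = steps_to_converge(lr=0.01, mom=0.0)
momentum_steps = steps_to_converge(lr=0.01, mom=0.9)
print(sgd_steps, momentum_steps)
```

On this quadratic the momentum run reaches the tolerance in considerably fewer steps than plain SGD at the same learning rate, although it overshoots and oscillates around the minimum on the way, much like a ball rolling into a valley.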
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.856000</td>\n",
" <td>2.493429</td>\n",
" <td>0.246115</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.504205</td>\n",
" <td>2.463813</td>\n",
" <td>0.348280</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.187387</td>\n",
" <td>1.755670</td>\n",
" <td>0.418853</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.03)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_sched()"
]
},
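The schedule plotted by `learn.recorder.plot_sched()` moves momentum in the opposite direction to the learning rate: while the learning rate warms up, momentum drops, and while the learning rate anneals, momentum climbs back. The function below sketches that shape with cosine interpolation in plain Python. `one_cycle` and `cos_anneal` are hypothetical helpers, and the defaults (`div=25`, `div_final=1e5`, `pct_start=0.25`, `moms=(0.95, 0.85, 0.95)`) are assumptions based on fastai's `fit_one_cycle` defaults, which may differ between versions.

```python
import math


def cos_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25,
              moms=(0.95, 0.85, 0.95)):
    """Learning rate and momentum at fraction `pct` (0..1) of training."""
    if pct < pct_start:
        pos = pct / pct_start                        # warm-up phase
        lr = cos_anneal(lr_max / div, lr_max, pos)
        mom = cos_anneal(moms[0], moms[1], pos)
    else:
        pos = (pct - pct_start) / (1 - pct_start)    # annealing phase
        lr = cos_anneal(lr_max, lr_max / div_final, pos)
        mom = cos_anneal(moms[1], moms[2], pos)
    return lr, mom


for pct in (0.0, 0.125, 0.25, 0.5, 1.0):
    print(pct, one_cycle(pct, lr_max=0.03))
```

Passing `moms=(0,0,0)`, as in the earlier SGD run, flattens the momentum curve to zero for the whole cycle.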
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RMSProp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
"    if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
"    return {'sqr_avg': sqr_mom*sqr_avg + (1-sqr_mom)*p.grad.data**2}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):\n",
"    denom = sqr_avg.sqrt().add_(eps)\n",
"    p.data.addcdiv_(p.grad, denom, value=-lr)\n",
"\n",
"opt_func = partial(Optimizer, cbs=[average_sqr_grad,rms_prop_step],\n",
"                   sqr_mom=0.99, eps=1e-7)"
]
},
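RMSProp divides each step by the square root of a moving average of squared gradients, so parameters whose gradients are consistently large and parameters whose gradients are consistently tiny end up moving at a similar rate. The plain-Python sketch below uses the damped update `sqr_avg = sqr_mom*sqr_avg + (1-sqr_mom)*grad**2` on two scalar parameters whose constant gradients (assumed values, chosen for illustration) differ by four orders of magnitude; `rmsprop_distance` is a hypothetical helper.

```python
import math


def rmsprop_distance(grad, steps=100, lr=1e-3, sqr_mom=0.99, eps=1e-7):
    """Total distance a parameter moves under RMSProp with a constant gradient."""
    w, sqr_avg = 0.0, 0.0
    for _ in range(steps):
        sqr_avg = sqr_mom * sqr_avg + (1 - sqr_mom) * grad ** 2
        w -= lr * grad / (math.sqrt(sqr_avg) + eps)
    return abs(w)


big, small = rmsprop_distance(100.0), rmsprop_distance(0.01)
first_step = rmsprop_distance(100.0, steps=1)
print(big, small, first_step)
```

Because `sqr_avg` starts at zero, the very first step here is about `10 * lr` (`sqrt(0.01 * grad**2)` is `0.1 * |grad|`); the bias correction used by Adam is one way to remove that effect.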
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.766912</td>\n",
" <td>1.845900</td>\n",
" <td>0.402548</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.194586</td>\n",
" <td>1.510269</td>\n",
" <td>0.504459</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.869099</td>\n",
" <td>1.447939</td>\n",
" <td>0.544968</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adam"
]
},
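Adam combines the two previous ideas: a damped moving average of the gradients (like momentum, but weighted by `1-beta1`) and a moving average of their squares (as in RMSProp), each divided by `1 - beta**t` to correct for starting at zero. The plain-Python sketch below is the standard Adam update for a single scalar; `ScalarAdam` is a hypothetical class, and its defaults (`beta1=0.9`, `beta2=0.99`, `eps=1e-5`) are assumed to mirror fastai's usual values. PyTorch's `torch.optim.Adam`, for comparison, defaults to `betas=(0.9, 0.999)` and `eps=1e-8`.

```python
import math


class ScalarAdam:
    """Adam for a single scalar parameter (hypothetical helper, not a library class)."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.99, eps=1e-5):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.avg = self.sqr_avg = 0.0
        self.t = 0

    def step(self, w, grad):
        self.t += 1
        # Damped moving averages of the gradient and of its square.
        self.avg = self.beta1 * self.avg + (1 - self.beta1) * grad
        self.sqr_avg = self.beta2 * self.sqr_avg + (1 - self.beta2) * grad ** 2
        # Bias correction: both averages start at 0, so rescale them.
        unbias_avg = self.avg / (1 - self.beta1 ** self.t)
        unbias_sqr = self.sqr_avg / (1 - self.beta2 ** self.t)
        return w - self.lr * unbias_avg / (math.sqrt(unbias_sqr) + self.eps)


opt = ScalarAdam(lr=0.01)
w1 = opt.step(1.0, grad=5.0)
```

Thanks to the bias correction, the first step has a magnitude very close to `lr` whatever the size of the gradient.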
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decoupled weight decay"
]
},
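With plain SGD, adding `wd * w` to the gradient (L2 regularization) and shrinking the weights directly by `lr * wd * w` (decoupled weight decay) produce identical updates. Once the step is divided by a gradient-dependent scale, as in Adam, they stop matching, because the L2 term gets rescaled along with the gradient. The sketch below checks both claims on one scalar. All function names are hypothetical, and `adam_first_step` relies on the fact that Adam's bias-corrected first update reduces to `lr * g / (|g| + eps)`.

```python
def sgd_l2(w, g, lr, wd):
    return w - lr * (g + wd * w)          # L2 penalty folded into the gradient


def sgd_decoupled(w, g, lr, wd):
    return w * (1 - lr * wd) - lr * g     # weights shrunk directly


def adam_first_step(w, g, lr, eps=1e-5):
    # Adam's bias-corrected first update simplifies to lr * g / (|g| + eps).
    return w - lr * g / (abs(g) + eps)


def adam_l2(w, g, lr, wd):
    return adam_first_step(w, g + wd * w, lr)


def adam_decoupled(w, g, lr, wd):
    return adam_first_step(w * (1 - lr * wd), g, lr)


w, g, lr, wd = 1.0, 0.5, 0.1, 0.1
print(sgd_l2(w, g, lr, wd), sgd_decoupled(w, g, lr, wd))
print(adam_l2(w, g, lr, wd), adam_decoupled(w, g, lr, wd))
```

In the L2 version Adam normalizes the decay term together with the gradient, so the decay's effect depends on the gradient's scale; in the decoupled version the weights shrink by a fraction `lr * wd` regardless of the gradient.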
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Callbacks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a callback"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ModelResetter(Callback):\n",
"    def begin_train(self): self.model.reset()\n",
"    def begin_validate(self): self.model.reset()"
]
},
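The model-resetting callback above works because the training loop calls, on every registered callback, the method named after the current event, and skips callbacks that do not define it. The sketch below reproduces that dispatch pattern in plain Python. `Callback`, `FakeModel`, `ResetCallback`, and `TinyLearner` are hypothetical stand-ins that are far simpler than fastai's classes, although the event names match the cell above. fastai additionally forwards unknown attributes such as `self.model` to the learner, which this sketch skips by going through `self.learn` explicitly.

```python
class Callback:
    learn = None


class FakeModel:
    """Hypothetical model that only records how often it was reset."""
    def __init__(self):
        self.resets = 0

    def reset(self):
        self.resets += 1


class ResetCallback(Callback):
    def begin_train(self): self.learn.model.reset()
    def begin_validate(self): self.learn.model.reset()


class TinyLearner:
    def __init__(self, model, cbs):
        self.model, self.cbs, self.events = model, cbs, []
        for cb in cbs:
            cb.learn = self              # give each callback access to the learner

    def __call__(self, event_name):
        self.events.append(event_name)
        for cb in self.cbs:
            # Only callbacks that define a method for this event are called.
            method = getattr(cb, event_name, None)
            if method is not None:
                method()

    def fit(self, n_epoch):
        self('begin_fit')
        for _ in range(n_epoch):
            self('begin_train')          # training batches would run here
            self('begin_validate')       # validation batches would run here
        self('after_fit')


learn = TinyLearner(FakeModel(), [ResetCallback()])
learn.fit(2)
print(learn.model.resets)  # 4
```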
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RNNRegularizer(Callback):\n",
" def __init__(self, alpha=0., beta=0.): self.alpha,self.beta = alpha,beta\n",
"\n",
" def after_pred(self):\n",
" self.raw_out,self.out = self.pred[1],self.pred[2]\n",
" self.learn.pred = self.pred[0]\n",
"\n",
" def after_loss(self):\n",
" if not self.training: return\n",
" if self.alpha != 0.:\n",
" self.learn.loss += self.alpha * self.out[-1].float().pow(2).mean()\n",
" if self.beta != 0.:\n",
" h = self.raw_out[-1]\n",
" if len(h)>1:\n",
" self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]\n",
" ).float().pow(2).mean()"
]
},
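`RNNRegularizer` adds two penalties to the loss: activation regularization, `alpha` times the mean of the squared activations in `out`, and temporal activation regularization, `beta` times the mean squared difference between activations in `raw_out` at consecutive time steps. The sketch below computes both on nested Python lists shaped `[batch][time]`, with a single feature per step to keep it small; the helper names, shapes, and numbers are illustrative.

```python
def mean(xs):
    return sum(xs) / len(xs)


def activation_reg(out, alpha):
    """alpha * mean of squared activations (AR)."""
    return alpha * mean([a ** 2 for seq in out for a in seq])


def temporal_activation_reg(raw_out, beta):
    """beta * mean squared change between consecutive time steps (TAR)."""
    diffs = [seq[t + 1] - seq[t] for seq in raw_out for t in range(len(seq) - 1)]
    return beta * mean([d ** 2 for d in diffs])


acts = [[1.0, 2.0, 4.0]]  # one sequence, three time steps
print(activation_reg(acts, alpha=1.0))           # (1 + 4 + 16) / 3 = 7.0
print(temporal_activation_reg(acts, beta=1.0))   # (1**2 + 2**2) / 2 = 2.5
```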
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Callback ordering and exceptions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TerminateOnNaNCallback(Callback):\n",
" run_before=Recorder\n",
" def after_batch(self):\n",
" if torch.isinf(self.loss) or torch.isnan(self.loss):\n",
" raise CancelFitException"
]
},
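`TerminateOnNaNCallback` stops training by raising an exception that the training loop catches, rather than by returning a flag. The plain-Python sketch below shows that control flow with a hypothetical `CancelFit` exception (standing in for `CancelFitException`) and a list of precomputed batch losses standing in for real training.

```python
import math


class CancelFit(Exception):
    """Hypothetical stand-in for fastai's CancelFitException."""


def terminate_on_nan(loss):
    if math.isinf(loss) or math.isnan(loss):
        raise CancelFit


def fit(losses):
    """Run 'batches' until a callback cancels; return how many completed."""
    done = 0
    try:
        for loss in losses:
            terminate_on_nan(loss)   # after_batch-style check
            done += 1
    except CancelFit:
        pass                         # the loop ends cleanly instead of crashing
    return done


print(fit([2.3, 1.9, float('nan'), 1.5]))  # 2
```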
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
2020-03-18 00:34:07 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What is the equation for a step of SGD, in math or code (as you prefer)?\n",
"1. What do we pass to `cnn_learner` to use a non-default optimizer?\n",
"1. What are optimizer callbacks?\n",
"1. What does `zero_grad` do in an optimizer?\n",
"1. What does `step` do in an optimizer? How is it implemented in the general optimizer?\n",
"1. Rewrite `sgd_cb` to use the `+=` operator, instead of `add_`.\n",
"1. What is momentum? Write out the equation.\n",
"1. What's a physical analogy for momentum? How does it apply in our model training settings?\n",
"1. What does a bigger value for momentum do to the gradients?\n",
"1. What are the default values of momentum for 1cycle training?\n",
"1. What is RMSProp? Write out the equation.\n",
"1. What do the squared values of the gradients indicate?\n",
"1. How does Adam differ from momentum and RMSProp?\n",
"1. Write out the equation for Adam.\n",
"1. Calculate the value of `unbias_avg` and `w.avg` for a few batches of dummy values.\n",
"1. What's the impact of having a high eps in Adam?\n",
"1. Read through the optimizer notebook in fastai's repo, and execute it.\n",
"1. In what situations do dynamic learning rate methods like Adam change the behaviour of weight decay?\n",
"1. What are the four steps of a training loop?\n",
"1. Why is the use of callbacks better than writing a new training loop for each tweak you want to add?\n",
"1. What are the necessary points in the design of fastai's callback system that make it as flexible as copying and pasting bits of code?\n",
"1. How can you get the list of events available to you when writing a callback?\n",
"1. Write the `ModelResetter` callback (without peeking).\n",
"1. How can you access the necessary attributes of the training loop inside a callback? When can you use or not use the shortcut that goes with it?\n",
"1. How can a callback influence the control flow of the training loop?\n",
"1. Write the `TerminateOnNaN` callback (without peeking if possible).\n",
"1. How do you make sure your callback runs after or before another callback?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Look up the \"rectified Adam\" paper and implement it using the general optimizer framework, and try it out. Search for other recent optimizers that work well in practice, and pick one to implement.\n",
"1. Look at the mixed-precision callback alongside its documentation. Try to understand what each event and line of code does.\n",
"1. Implement your own version of the learning rate finder from scratch. Compare it with fastai's version.\n",
"1. Look at the source code of the callbacks that ship with fastai. See if you can find one that's similar to what you're looking to do, to get some inspiration."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Foundations of Deep Learning: Wrap up"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Congratulations, you have made it to the end of the \"foundations of deep learning\" section. You now understand how all of fastai's applications and most important architectures are built, and the recommended ways to train them, and you have all the information you need to build these from scratch. Whilst you probably won't need to create your own training loop or batchnorm layer, for instance, knowing what is going on behind the scenes is very helpful for debugging, profiling, and deploying your solutions.\n",
"\n",
"Since you now understand the foundations of fastai's applications, be sure to spend some time digging through fastai's source notebooks and running and experimenting with parts of them: they show exactly how everything in fastai is developed.\n",
"\n",
"In the next section, we will be looking even further under the covers, to see how the actual forward and backward passes of a neural network are done, and we will see what tools are at our disposal to get better performance. We will then finish up with a project that brings together everything we have learned throughout the book, which we will use to build a method for interpreting convolutional neural networks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}