fastbook/clean/16_accel_sgd.ipynb

708 lines
130 KiB
Plaintext
Raw Normal View History

2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The training process"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's start with SGD"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_data(url, presize, resize):\n",
" path = untar_data(url)\n",
" return DataBlock(\n",
" blocks=(ImageBlock, CategoryBlock), get_items=get_image_files, \n",
" splitter=GrandparentSplitter(valid_name='val'),\n",
" get_y=parent_label, item_tfms=Resize(presize),\n",
" batch_tfms=[*aug_transforms(min_scale=0.5, size=resize),\n",
" Normalize.from_stats(*imagenet_stats)],\n",
" ).dataloaders(path, bs=128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = get_data(URLs.IMAGENETTE_160, 160, 128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_learner(**kwargs):\n",
" return cnn_learner(dls, resnet34, pretrained=False,\n",
" metrics=accuracy, **kwargs).to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.571932</td>\n",
" <td>2.685040</td>\n",
" <td>0.322548</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.904674</td>\n",
" <td>1.852589</td>\n",
" <td>0.437452</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.586909</td>\n",
" <td>1.374908</td>\n",
" <td>0.594904</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner()\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = get_learner(opt_func=SGD)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.017378008365631102, 3.019951861915615e-07)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXhc9X3v8fdXI412a7HlfTfGYDYbGxNKYxxCEhIIJCHNpU3a0NBCkiYkN0/TNM29tOU+Cbmla5qnpQQuoQFSEghroEAbtkAM2JjF7LaxZEm2JWvfpdF87x8zkoUsyZLRmTmj+byeZx7PnDkz58NIzFe/81uOuTsiIpK9ctIdQERE0kuFQEQky6kQiIhkORUCEZEsp0IgIpLlVAhERLJcbroDTNWcOXN8+fLl6Y4hIpJRtm/ffsjdq8Z6LuMKwfLly9m2bVu6Y4iIZBQzqx7vOZ0aEhHJcioEIiJZToVARCTLqRCIiGQ5FQIRkSynQiAikuVUCEREMsCjrx1kV0NnIO+tQiAiEnLuzpdv285dL9QG8v4qBCIiIdfRF2Ng0Kksigby/ioEIiIh19LVD0BFsQqBiEhWakoWgtkqBCIi2UktAhGRLKcWgYhIllOLQEQkyzV39RPNzaE4Ggnk/VUIRERCrrmrn8qiKGYWyPurEIiIhFxzVz+VAZ0WAhUCEZHQa+5WIRARyWpqEYiIZDkVAhGRLDYwGKejN6ZCICKSrYKeQwAqBCIiodbcHeysYlAhEBEJtebOZIsgoCWoQYVARCTUhlsEJSoEIiJZqblLLQIRkax2uBDkBXYMFQIRkRBr7uqnrDCP3EhwX9cqBCIiIdbc1R/oiCFQIRARCbXmrv5A5xCACoGISKgFvbwEqBCIiITa0LUIgqRCICISUu5OS3c/lQHOIQAVAhGR0OrsizEw6GoRiIhkq+YULDgHkBvkm5vZXqADGARi7r5x1PNbgHuBd5KbfuHu1wSZSUQkUwwVgqCHjwZaCJI+4O6HJnj+KXe/MAU5REQySqpaBDo1JCISUqlqEQRdCBx4xMy2m9kV4+xzlpm9ZGYPmdlJAecREckYM6KPADjb3evNbC7wqJm94e5Pjnj+BWCZu3ea2ceAe4DVo98kWUSuAFi6dGnAkUVEwqG5u59obg7F0Uigxwm0ReDu9cl/G4C7gU2jnm93987k/QeBPDObM8b73ODuG919Y1VVVZCRRURCo7kzMZnMzAI9TmCFwMyKzax06D7wYWDnqH3mW/K/0Mw2JfM0BZVJRCSTtHQHv7wEBHtqaB5wd/J7Phe43d3/08y+CODu1wOfBr5kZjGgB7jU3T3ATCIiGaMpBesMQYCFwN33AKeNsf36Efd/CPwwqAwiIpmspaufJRVFgR9Hw0dFREIqVS0CFQIRkRAaGIzT0RtTIRARyVYtKZpDACoEIiKh1NydmlnFoEIgIhJKzZ3JFkHAS1CDCoGISCg1dPQBUFWaH/ixVAhERELoYHsvAPPLCgI/lgqBiEgIHWjvpTgaoSQ/+KsFqBCIiIRQQ3sf81LQGgAVAhGRUDrQ3sv8WSoEIiJZ62B7L/NUCEREspO7J04NqRCIiGSn5q5++gfjzJsV/NBRUCEQEQmdg+2JOQTqIxARyVJDcwjmqhCIiGSnVE4mAxUCEZHQOZAsBFUl6iMQEclKB9t7mVMSJZqbmq9oFQIRkZA5mMKho6BCICISOgfaUjeZDFQIRERCp6FDhUBEJGv1x+Ic6uxP2WQyUCEQEQmVxs7UTiYDFQIRkVA50JYYOqpTQyIiWWpoMpkKgYhIljpcCNRHICKSlQ609xKN5FBZHE3ZMVUIRERCpKG9j7mz8jGzlB1ThUBEJERSPZkMVAhERELlYAqvVTxEhUBEJEQOtvcyN4UdxaBCICISGh29A3T1D6pFICKSrYYuUak+AhGRLJWOyWQQcCEws71m9oqZvWhm28Z43szsB2a2y8xeNrPTg8wjIhJmQ8tLpOoSlUNyU3CMD7j7oXGe+yiwOnk7E/jX5L8iIlnnYEfqZxVD+k8NXQz8uydsBcrNbEGaM4mIpMXBtl5K83Mpiqbib/TDgi4EDjxiZtvN7Ioxnl8E7BvxuDa57V3M7Aoz22Zm2xobGwOKKiKSXnWtvSwsL0z5cYMuBGe7++kkTgH9iZltHvX8WHOo/YgN7je4+0Z331hVVRVEThGRtKtv7WFRxQwrBO5en/y3Abgb2DRql1pgyYjHi4H6IDOJiIRVXWsPC8tT21EMARYCMys2s9Kh+8CHgZ2jdrsP+IPk6KH3AW3uvj+oTCIiYdXZF6OtZ4BF5UUpP3aQPRLzgLuTK+jlAre7+3+a2RcB3P164EHgY8AuoBv4wwDziIiEVn1rD0BaWgSBFQJ33wOcNsb260fcd+BPgsogIpIp6pKFYPFM6yMQEZHJqWsZahGoEIiIZKX61h5yc4y5pTOos1hERCavrrWHBeUFRHJSd2WyISoEIiIhUNfSw8Ky1J8WAhUCEZFQSNdkMlAhEBFJu4HBOAfae1mUho5iUCEQEUm7g+29xB0VAhGRbJXOoaOgQiAiknb1bYlCoD4CEZEsNdwi0KghEZHsVNfay+ziKIXRSFqOP6lCYGarzCw/eX+LmV1lZuXBRhMRyQ51aRw6CpNvEdwFDJrZccBNwArg9sBSiYhkkfrW9E0mg8kXgri7x4BPAv/o7v8T0LWFRUTeI3enriUzWgQDZva7wOeBB5Lb8oKJJCKSPVq7B+gZGEzb0FGYfCH4Q+As4Lvu/o6ZrQBuDS6WiEh2GLoOQbomk8EkL0zj7q8BVwGYWQVQ6u7fDzKYiEg2CEMhmOyoocfNbJaZVQIvATeb2d8HG01EZOYbmkOQCX0EZe7eDnwKuNndNwDnBRdLRCQ71Lf2UJCXQ0VR+rpdJ1sIcs1sAfAZDncWi4jIe1TX2sOi8kLMUn9BmiGTLQTXAA8Du939eTNbCbwdXCwRkeyQmExWlNYMkyoE7v5zdz/V3b+UfLzH3S8JNpqIyMzm7uw91MXiNPYPwOQ7ixeb2d1m1mBmB83sLjNbHHQ4EZGZbM+hLtp7Y5y2uCytOSZ7auhm4D5gIbAIuD+5TUREjtGOmlYA1i+tSGuOyRaCKne/2d1jyduPgaoAc4mIzHg7aloozc/luKqStOaYbCE4ZGafM7NI8vY5oCnIYCIiM92OmlbWLS0nJyd9I4Zg8oXgCySGjh4A9gOfJrHshIiIHIOuvhhvHGhn/ZL0r+g/2VFDNe5+kbtXuftcd/8EicllIiJyDF6ubSPu6e8fgPd2hbJvTFsKEZEss2NfCwDrMqVFMI70ntQSEclgO2paWTmnmIriaLqjvKdC4NOWQkQki7g7O2paWLc0/a0BOMoy1GbWwdhf+AakdyqciEiGqm3p4VBnfyj6B+AohcDdS1MVREQkW7xQk+gfOD0kLYL3cmpoUpLzDnaY2RGrlprZZWbWaGYvJm9/FHQeEZF021HTSmFehDXzwvG39qSuUPYefQ14HZg1zvN3uPtXUpBDRCQUdtS0cOriMnIjgf8tPimBpkguTHcBcGOQxxERyRS9A4O8Wt8emv4BCP7U0D8CfwbEJ9jnEjN72czuNLMlY+1gZleY2TYz29bY2BhIUBGRVHjjQAexuLNuSXpXHB0psEJgZhcCDe6+fYLd7geWu/upwH8Bt4y1k7vf4O4b3X1jVZXWuhORzFXd1AXAyjQvNDdSkC2Cs4GLzGwv8B/AuWZ268gd3L3J3fuSD38EbAgwj4hI2lU3dQOwtDK9VyUbKbBC4O7fdvfF7r4cuBT4lbt/buQ+yesgD7mIRKeyiMiMVd3UzbxZ+RTkRdIdZVgqRg29i5ldA2xz9/uAq8zsIiAGNAOXpTqPiEgq1TR3sayyON0x3iUlhcDdHwceT96/esT2bwPfTkUGEZEwqG7qZvPx4errDMcgVhGRLNDTP0hDRx/LQtQ/ACoEIiIpU9Oc7CierUIgIpKVhguBWgQiItlpaA7Bstnh6ixWIRARSZGa5m5K83OpKMpLd5R3USEQEUmR6qZuls4uwixcF3h
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.969412</td>\n",
" <td>2.214596</td>\n",
" <td>0.242038</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.442730</td>\n",
" <td>1.845950</td>\n",
" <td>0.362548</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.157159</td>\n",
" <td>1.741143</td>\n",
" <td>0.408917</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(3, 0.03, moms=(0,0,0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A generic optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sgd_cb(p, lr, **kwargs): p.data.add_(-lr, p.grad.data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[sgd_step])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.730918</td>\n",
" <td>2.009971</td>\n",
" <td>0.332739</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.204893</td>\n",
" <td>1.747202</td>\n",
" <td>0.441529</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.875621</td>\n",
" <td>1.684515</td>\n",
" <td>0.445350</td>\n",
" <td>00:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit(3, 0.03)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hU1dbA4d8mBEhACEhNAEFRBGnRWDCCFAUuIgYVBXvnqlgRDVe5dkWxtw8VvIKiwBUMCCgXBKQjwURRBKkKoYUSigQIyf7+2DNhMpkWppwp630enpCZkzkrgaw5Z++111Zaa4QQQkS/SlYHIIQQIjQk4QshRIyQhC+EEDFCEr4QQsQISfhCCBEjKlsdgDt169bVzZo1szoMIYSIKCtXrtytta7n6rmwTfjNmjUjOzvb6jCEECKiKKX+dPecDOkIIUSMkIQvhBAxQhK+EELECEn4QggRIyThCyFEjJCEL4QQMUISvhBCxIiwrcMXIhpl5eQxctZathUUkpyUwNCeLclITbE6LBEjJOELESJZOXkMm7KKwqJiAPIKChk2ZRWAJH0REjKkI0SIjJy1tjTZ2xUWFTNy1lqLIhKxRhK+ECGyraCwQo8LEWiS8IUIkeSkhAo9LkSgScIXIkSG9mxJQnxcmccS4uMY2rOlRRGJWCOTtkKEiH1iVqp0hFUk4YuoECnljhmpKWEZl4gNkvBFxPO13DFS3hSECBZJ+CLieSp3tCf0k62BlzcJEU1k0lZEPF/KHU+mBt7+JpFXUIjmxJtEVk5eQOIWItTkCl9EvOSkBPJcJH3Hckdf3hScr+b/Pnrc452DXP2LSCNX+CLi+VLu6K0G3tXVfEFhkcuv2VZQyFNZq3hkYu7JXf0fOQK//+7T9yZEIEnCFxEvIzWFl69uS0pSAgpISUrg5avblrna9vam4GrIx51aCfGMX/YX2ulxd0NEWTl5pI+YS/PMGXR6aTbrOnaH1q358MJrSH/5exkiEiETkCEdpdQnQB9gl9a6jYvnFfA20Bs4DNymtf4pEOcWkSHYwx+O5Y72cz0yMbfcudzF4Gt7g4T4OJSiXLK3c34d58nigd98xJm5S1jWpA2DfpxCXs16DD3Yl2e/+Y2Cw0Vefzb27y2voJA4pSjWmhQZThI+CtQY/qfAe8A4N8//AzjT9udC4P9sH0UMCGWXSFfnemRiLtl/7uWFjLZuz+duHqB2YjyJVSqXeZN4ZGKu2/M7Dx053jn0XrOI+5Z9xfgOvXiqx318PPl5npo7hl8anUVucsvSeN39bJy/t2KtvX6NEI4CMqSjtV4A7PVwyFXAOG0sA5KUUo0CcW4R/nytkHEc+kgfMfekhjpcnUsD45f95fH13A35PH3lOSzO7MamEVewOLMbGakppUn9lKN/8/ScD2m6bzsAyvY6juxX/Gflb2bkzLdYmXw2z3YfhFaVGHLFo+w85VTezxpB2tbfuGRTDmjtdmjI07CTdN0UvgjVGH4KsMXh8622x8pQSt2jlMpWSmXn5+eHKDQRbL5WyASiBNLduTR4TIi+zAPYDe3ZkobHD/PPZV9x+8pv+O4/g2l0YDc3XtTUHF9cDFu2wKJF3LZpEfcvmcjoyc/zd5UE7s0YxrHK8QDsTziF+67KpO7hfXw1/gk+nzSct6a/RtXjx1x+H96GnVzdoZyMQLzxivAUqrJM5eKxcsOgWuuPgI8A0tLS3A2TigjjS9mkL4un3HGcH6hkG9d2pWhrHgweDNdfD506lXveY9uDbdtg4UJYsICMBQvI+PVX85qV4tCV4pgzdTjV17SARzebZH/8OABP2758a816/DPjX+w65dQyL7uq0Zncct3zNDi0h9P2bWfIovEkH8jnmdteKBeCu5+jXZxy9WtWMbJJS3QL1RX+VqCJw+eNgW0hOrewmC9lkyfbK975zsBdsu+yIZtZnz4I778PXbvC66+D87EbNsBNN5k3hG0O/z0ffhhSUmDAABg3DpKT4YUXYMEC4g//TfXZ31G9/qlw9Ch07AhDh8KHH8J338GaNUxbso7rh00gp3ErUpISuOmipqV3EkkJ8fzUvB3TWnfh3fSBDO77OO23r2PS2Edh3TqPP8fKxcepd+jESKq7793xZ+Xtyl02aYluobrCnwYMVkpNwEzW7tdabw/RuYXFfOkS6ctdgCveyinji4t4bMFnDPpxCvtbnA2ffQuvvgqPPQbLlsEnn5hE/fzz8H//B/Hx5o1g1ix480247TaYNw/S0szzHTpAZadfm06dICfHbQx9gb4dW7h93vEOJadjT5ZndKTz0Lvgootg6lS45BLgxM9xyKSfKdaaN6e/zpVrFnLuA+PZm1iLFA8/K1+v3GWTlugWqLLML4EuQF2l1FbMnWw8gNZ6FDATU5K5HlOWeXsgzisih7cukUN7tiyTkMC3XvGeElFaSQHDxz9L++3r2Nj/Fk4fOwoSEmDyZHjtNcjMhJUrYc8eOHQI7roLnnnG/P3OO+GOO2DiRNi0CW65xSR9J4EoN3X5s+nUBq64Arp3h/Hj4dprS48FmD7yU65csxCAW1dOZ1S3Wzz+rHwdMjvZN14RGQKS8LXWA708r4H7A3EuEZ1Otle8Y4KqVXiQ/dVqgFK0jj/GV58/BYUFMHkyp1999YkvUsoMu5x/vknkl14KI0ZA69Ynjpk/Hz74wLwp/P03NG1a7txBHe9u0QKWLIG+feGGG6BOHejWzbx2w0r0mv0OGxs0Iy+hNrflTOf0kc9ypYdz+nrlfrJvvCIyKO1l3M8qaWlpOjs72+owRJizJ924QwdZ/v4tvN/xOj7tdD3zZ79M/Z9XmInW888/+RNs2mSGdh59FJo1K/NU+oi5Lq+GU5ISWJzZzef4Pb7JFRRAejrk5cGiRVC/PnTpAn/9Zd4QDh40Qz7vvAMPPOD2PPZYqxYdJb6kmENVE93GKj2CIptSaqXWuvztKNI8TUQ4eyKa8XEW1YuOcP/yr+jcrwv1f1xkkqA/yR6geXPzOi74O97t0x1CUhJ8+60Zz+/d23y+ebN5rF07c8zFF8Mbb0CbNrBrl/mzc+eJv+/axawt22DXLmocK6QExZLT2jG9XXcuefyecnHJJi3RS67wRXQYO9ZMsIJJiocOQX6++XuQ+HuFX6Gvz8mBzp2hqAimT4fLLjvx3NSpkJFR9vhKlaBePXNHYPuzgUTm7NEcP3iIvmsW0mTfdg5Urc7ETv1p9O8n6NPpbJ++bxHe5ApfRL81a0yFzYAB8NlnJiEGMdmD/+PdFbpDSE2FxYtNff+555Z9rm9fmD3bVA/ZE3ydOibpOzjD9icrJ48ek3+h9eZfuefHKdw951P2LZ7CvDGT6Dqwp0+xi8gkCV9EhzVr4MwzTX381KlmotMLX8aqPR3j76bkFa6IsQ/hOFOq7BW/FyNnraXweAkrG7dmUOPWtNmxntGTn6PRg4PosmEUD1/hvueQiGyS8EVA+ZtET9qaNabKpmlTM25dpYrXOL2Nn/tyjD/j3VZVxDjfQfzasAWZvR7g06+epf+MTxh29A5AVtZGI+mHLwLGl344Ad828PBhePBBk/BTU81jVauaq14PfFlRGuxVpxXp3xNIru4g5p9xPhPbXs4/l0/mrD9X8/0HE6B6dbNGQUQNucIXAePL4h6vx+zfbxJ2tWreT7h8uamj/+MPk/SHDPE5Vl/Gz0Ox6tTTHUKwyiNd3VkAvND9Li7ZnMvrM94kr1Z982b6/fdw3XV+n1OEB7nCFwHjdxI9cAAaNYLERDjjDLPSdMgQGD3a1KDv3m0OPnYMnnzSlCMeOWKS0ttvm1W0PvK25aGvxwSLqzuhRybm0iwAHSwd7ywcHaxanSf+8SAt9m7l0k22/Yk2bvTjuxDhRhK+CBi/k+iuXVBYCH36mPr5vDyz2vXuu02/mnr1oG5dswr1pZfg1lvhl19KV6BWhC8N3Xw5Jljc9fWHAAyDYZL+4sxuvHV9hzLf46LmqUw4t/eJAyXhRxVJ+CJg/E6iBw6YB+66CyZMgNxcU0+/cSPMnAlvvMGmS3uxJDGZu68eTvpZN5G
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"plt.scatter(x1,y1)\n",
"idx = x1.argsort()\n",
"beta,avg,res = 0.7,0,[]\n",
"for i in idx:\n",
" avg = beta * avg + (1-beta) * y1[i]\n",
" res.append(avg/(1-beta**(i+1)))\n",
"plt.plot(x1[idx],np.array(res), color='red');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAtEAAAHiCAYAAAAuz5CZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3iTVRvH8e9pKTRllQ0tG9kgFCpDcLBBVkV8AVFBRUDEiSjgHkABxYELRMEByrSioAwBlVFZBZElG9qyocxC13n/eJKSpkm60qRp78919Wrz5El6CumvJ+c55z5Ka40QQgghhBAi83w83QAhhBBCCCG8jXSihRBCCCGEyCLpRAshhBBCCJFF0okWQgghhBAii6QTLYQQQgghRBZJJ1oIIYQQQogskk608ApKqSNKqY6ebocQQoiMSWaLgkA60SJfy60gV4ZJSqlz5o/JSinl4Ny7lVIpSqkrVh+DXN0mIYTwdnkks8fZ5HW8OcPLurpdwrsV8nQDhPBSQ4EwoAmggZXAIeBzB+fHaq0ru6ltQggh0sp0ZmutJwATLLeVUm8Ad2qtz7qlpcJryEi08Ca3KaV2K6UuKKVmKaX8AZRSPZRS25VScUqpDUqpW83HvwWqAj+bRxNeNB9foJQ6qZS6qJT6UynVMBttGQS8p7WO1lrHAO8Bg13yUwohRP7g9ZltHq1+CPg6G99T5HPSiRbeZCDQBagF1AFeUUo1A74ChgFlgOnAEqVUEa31Q8AxoKfWupjWerL5eX4FagPlgW3AHMs3UEqNMQe73Q+rtjQEdljd3mE+5kh5pdQppdRhpdT7SqmiOfh3EEIIb+DNmW1xB1ABWJTln17ke0pr7ek2CJEhpdQRIFxr/bn59j3ANGAFcFZr/arVufuAoVrrP8yPG6K1XuXgeQOBC0Cg1vpiFtqTDDTUWu81364N/Af4aJtfKqVURaA0sBeohjGisUdrPSyz308IIbyJN2e2zeO+BHy11oMz+71EwSEj0cKbHLf6+igQhNEpHWUz8lDFfF86SilfpVS4UuqgUuoScMR8V1YXjFwBSljdLgFcsRfGWuuTWuvdWusUrfVh4EWgbxa/nxBCeBuvzGyr720C7kemcggHpBMtvEkVq6+rArEYIT1eax1o9RGgtf7efJ5tQD4A9AY6AiWB6ubjCuyuyk7zYfU8uzAWqFg0MR/LDG35fkIIkY95e2b3Ac4DazPzw4qCRzrRwps8qZSqrJQqDYwD5gFfAMOVUi2VoahSqrtSqrj5MaeAmlbPURy4AZwDArBagQ3GqmzzXDy7H1anfgM8r5QKVkoFAaOA2fYarYwSd1XN7asChAM/5fQfQwgh8jivzGwrg4BvnI1Wi4JNOtHCm8zFmE93yPzxjtZ6C/A48DHGPLkDpF1xPRFjMUucUuoFjCA9CsQAu4HIbLZlOvAzsBP4F1hqPgaAeRTkDvPNZsBG4CqwwXz+09n8vkII4S28NbNRSgUD7c3fXwi7ZGGhEEIIIYQQWSQj0UIIIYQQQmSRdKKFEEIIIYTIIpd0opVSXymlTiul/nVw/0Cl1D/mjw1KqSZW9x1RSu1Uxu5FW1zRHiGEEI5JZgshRM65aiR6NtDVyf2Hgbu01rcCbwMzbO5vp7VuqrUOdVF7hBBCODYbyWwhhMiRQq54Eq31n0qp6k7u32B1MxKo7IrvK4QQIusks4UQIudc0onOoseAX61ua2CFUkoD07XWtiMe6ZQtW1ZXr149l5onhBC5Z+vWrWe11uU83Y4skMwWQhRYzjLbrZ1opVQ7jEBua3W4jdY6VilVHliplNqrtf7TzmOHAkMBqlatypYtMhVPCOF9lFJHPd2GzJLMFkIUdM4y223VOZRStwIzgd5a63OW41rrWPPn08CPQAt7j9daz9Bah2qtQ8uV86ZBHCGE8D6S2UII4ZxbOtFKqarAYuAhrfV/VseLWrb6VEoVBTpj7CQkhBDCQySzhRAiYy6ZzqGU+h64GyirlIoGXgf8ALTWnwOvAWWAT5VSAEnmVd0VgB/NxwoBc7XWv7miTUIIIeyTzBZCiJxzVXWOARncPwQYYuf4IaBJ+kcIIYTILZLZQgiRc7JjoRBCCCGEEFnkiRJ3QniViKgYpizfR2xcPEGBJkZ3qUtYSLCnmyWEEMIOyWzhLtKJFsKJiKgYxi7eSXxiMgAxcfGMXbwTQEJZCCHyGMls4U4ynUMIJ6Ys35caxhbxiclMWb7PQy0SQgjhiGS2cCfpRAvhRGxcfJaOCyGE8BzJbOFO0okWwomgQFOWjgshhPAcyWzhTtKJFsKJ0V3qYvLzTXPM5OfL6C51PdQiIYQQjkhmC3eShYVCOGFZiOJopbesAhdCiLxDMlu4k3SihchAWEiw3ZCVVeBCCJH3SGYLd5FOtBDZ5GwVeFYCWUZGhBAi90lmC1eTTrQQ2eSKVeAyMiKEEO4hmS1cTRYWigIpIiqGNuGrqTFmKW3CVxMRFeP8AbNmwe7daQ65YhW41DQVQojMyXJu25DMFq4mnWhR4FhGEmLi4tHcHElwGMjR0fDoo9CkCbzwAly6BLhmFbjUNBVCiIxlObftkMwWruaSTrRS6iul1Gml1L8O7ldKqY+UUgeUUv8opZpZ3TdIKbXf/DHIFe0RBVdmRiqyPJKQmGh8rlcPpk41Ps+ZQ1jTICb2aUxwoAkFBAeamNincZYu6UlNU+EJktkir8js6LIrRoDDQoIls4VLuWpO9GzgY+AbB/d3A2qbP1oCnwEtlVKlgdeBUEADW5VSS7TWF1zULlGAZHauWpZHErQ2Pr/wAtSvD08+CQ8+CNOnE/bxx4SNaZ/tNo/uUjdNm0Fqmgq3mI1ktvCwrMwvdtUIsKPKHZklmS2suWQkWmv9J3DeySm9gW+0IRIIVEpVAroAK7XW580hvBLo6oo2iYInsyMVWR5JSEkxPvv4QIsW8PffMGOGMUe6SRMoXhyWLs1Wm10xMiJEVklmi7wgK6PLWc7t8eNhwAB45x1YvBj27YOkpBy3WTJbWHNXdY5g4LjV7WjzMUfHhXDKXomhzI5U2BtJUBijIG3CV6crV7Ry1wk6Ac/N38GmmGDj/scfZ1Xp2nTs2w6uXCGpVy92vhxOyFujM91ey/fI6ciIELlAMlu4VE4yG7KW2xFRMbSe8iGBV+Io8sMPqecn+xXmcJlg9pUMJrJZO0Kfe4zet1XLUpsteS2ZLcB9nWhl55h2cjz9Eyg1FBgKULVqVde1TOR5tkHWrl45Fm2NSXcJMDDAjwvXEtM93nakwnpHq5i4eBQ3X3S2lxMjomKYvnwfnYBkpVLv33L0PKs2nqGj+XG7ytci5O0X2XsylnrTp4JSqW1/Y8ku4uJvtismLp7n5m1ny9HzvBPW2DhomTKi7P1KCOF2ktkiR6xzu6TJj6sJSSQmGy+VrGY2ZD63AcYu3sma5GQWN2zHWx2G0vBiDP2LXSZu83aGbFzILScP033fOmKXfs6/gx+n0RsvQKlSadpt73uky21R4LmrOkc0UMXqdmUg1snxdLTWM7TWoVrr0HLlyuVaQ0XeYm9F9pzIY3YvAWptzE2rcPksShtTMBzNVQsLCWb9mPYEB5rS9QCsLydOWb6PhETLJUCVev/3fx/nrCqc+pgH+o9nfuOO1PviA3jsMThwgF9XbeeNeVuIu5aQ7vtrYE7ksZuLaEaMgHbtICH9uUJ4gGS2yDbb3I6LT0ztQFtYZ7Y1Z/OLM5PblikiPlqjlQ/xhf3ZUq4WLwWE8M6dg1PP/7RVXw6XqkSjjyZA5crQtSvH7+nDpWFP0veXL6l17ni675Eut0WB565O9BLgYfOK71bARa31CWA50FkpVUopVQrobD4mCoDsVtKwO+wFXIxPZFqzANZNH8KgbUszNVct9sI1msTuY8yar5iy9APKXI0zjpsvJ549e4mmsf+l+77JWpPke/NCztUiAbzY7Rk+aDPAqClduzbdOoWwfVIY/717L712/wGAb0oyQZdOpz7flOX74MQJmDkT/vgD3n7b2T+ZEO4imS3SyUklDXsuxidma36xs2kglvt
"text/plain": [
"<Figure size 864x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"_,axs = plt.subplots(2,2, figsize=(12,8))\n",
"betas = [0.5,0.7,0.9,0.99]\n",
"idx = x1.argsort()\n",
"for beta,ax in zip(betas, axs.flatten()):\n",
" ax.scatter(x1,y1)\n",
" avg,res = 0,[]\n",
" for i in idx:\n",
" avg = beta * avg + (1-beta) * y1[i]\n",
" res.append(avg)#/(1-beta**(i+1)))\n",
" ax.plot(x1[idx],np.array(res), color='red');\n",
" ax.set_title(f'beta={beta}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_grad(p, mom, grad_avg=None, **kwargs):\n",
" if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)\n",
" return {'grad_avg': grad_avg*mom + p.grad.data}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def momentum_step(p, lr, grad_avg, **kwargs): p.data.add_(-lr, grad_avg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[average_grad,momentum_step], mom=0.9)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.856000</td>\n",
" <td>2.493429</td>\n",
" <td>0.246115</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.504205</td>\n",
" <td>2.463813</td>\n",
" <td>0.348280</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.187387</td>\n",
" <td>1.755670</td>\n",
" <td>0.418853</td>\n",
" <td>00:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.03)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAt0AAAD4CAYAAAAwyVpeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3yV9fn/8deVTfYkjEz2JpAQQETcolVRBAGto9pqtWqdraNaZx1UqbXan7TiqBNwoYI4EBVkJIywAyEJJKwkZJKdnM/vjxz8pjRAgJzcZ1zPx+M8PLnPfd95H4STK/d9fT4fMcaglFJKKaWUchwvqwMopZRSSinl7rToVkoppZRSysG06FZKKaWUUsrBtOhWSimllFLKwbToVkoppZRSysF8rA7QGaKjo01SUpLVMZRS6oStWbOmxBgTY3WOzqSf2UopV3Wsz2yPKLqTkpLIzMy0OoZSSp0wEdlldYbOpp/ZSilXdazPbG0vUUoppZRSysG06FZKKaWUUsrBtOhWSimllFLKwbToVkoppZRSysG06FZKKaWUUsrBHFp0i8hEEckWkRwRub+N1/1F5AP766tEJMm+PV1E1tsfWSJyeXvPqZRSSimllLNxWNEtIt7Ay8CFwCBghogMOmK3G4EyY0wfYBbwrH37JiDNGJMCTAReFRGfdp5TKaWUUkopp+LIebrTgRxjTC6AiLwPTAK2tNpnEvCo/fl84B8iIsaYmlb7BADmBM6pOlh9UzMrc0vJ3l9JTUMzgX7eBPn7EOzvQ3igHwmRgfQM74Kfj3YrKaWsMevr7UQG+TEsLoyU+HBExOpISikXUVHTyP7KOg5U1lFUVU9pdT01Dc3MSE8gNjSgw76PI4vunkBBq68LgdFH28cY0yQiFUAUUCIio4E5QCJwjf319pwTABG5CbgJICEh4dTfjQdqthneWbWLv32zg9LqhmPu6yXQPawL/buFMCI+nBEJEQyPDyMkwLeT0iqlPFVjs43ZP+RS29gMwIBuITx66WDG9IqyOJlSytkUVdWxMreUzXsr2LK3kq37qig5VN/mvuP7xrhM0d3WZQbT3n2MMauAwSIyEHhTRBa185zYj58NzAZIS0trcx91dBW1jdz5/jq+yy5mbK8ofnNGMqmJkQT7+1Db2Ex1fRPV9U2UHGpgd2kNuw9Ws6u0hs17K1myrQgAEUiJD+fcgbGcNyiWvl2D9eqTUqrD+Xp7seXxCyiuqmfJtiJeXprD9Nkruff8fvzurD76uaOUB2tqtrEi9yDfZxezLKeEbfurAPDz9qJvbDBn9o+hf2wI3cMDiA0NoGuIP5FBfnTx9cbHu2Pv4Duy6C4E4lt9HQfsPco+hSLiA4QBpa13MMZsFZFqYEg7z6lOUWVdI9e8toqt+yp54rIh/HJ0wn/90Aq2t5YA9IqB9OTI/zq+oraRrIJyMneVsTS7iJmLs5m5OJuEyEAuS+nB1LR44iMDO/U9KaXcm4jQNTSA6ekJXJrSg4c+3sRfv9pOVV0TD1w00Op4SqlOZIxh7e5yFqzfwxcb91FyqAE/Hy9GJUXwx4kDOL1PNAO6h+DbwUX18Tiy6M4A+opIMrAHmA5cdcQ+C4DrgBXAFGCJMcbYjymwt5QkAv2BfKC8HedUp6Cx2cZNb2WyZW8lr16TyjkDY0/4HGFdfDmjXwxn9Ivh7vP6caCyjm+3FrFo0z5e+i6Hl77LYVzvaKaNimfikG6d/pdeKeXeAv18eOHK4YQE+PDqD7l0Dwvg+nHJVsdSSjlYdX0TH64t5I3l+eSWVOPv48W5A2O5ZHgPJvSLoYuft6X5HFZ02wvm24DFgDcwxxizWUQeBzKNMQuA14D/iEgOLVe4p9sPPx24X0QaARtwqzGmBKCtczrqPXiiZxdtY2VuKS9cOfykCu62xIYGcNXoBK4ancCe8lrmZRYwL7OQ299bR1xEF246oxdTU+Mt/8eglHIfIsKfLxnMvoo6nvhiK8PiwxmZEGF1LKWUA+wtr+X15Xm8n1FAVV0Tw+PDmTllGBOHdHOqsWVijPu3O6elpZnMzEyrYzi9H7YXc+2c1Vw3NpHHJg1x6Pey2Qzfbivin0tzWLu7nKggP244PZkbxiVr8a1UKyKyxhiTZnWOztSRn9mVdY1c9OKPiMCXvz+DIH9H3uBVSnWmoso6Xlm6k3dX7abZGC4c0o0bTk+29BfsY31m66ePAqCqrpH7P9xA75igTul/9PISzhsUy7kDu5KRX8YrS3OYuTibt1bkc/d5/ZiSGo+3lw5+UkqdmtAAX2ZNS2Hq/1vB377ZzkO/0KUdlHJ1FTWNvLI0hzdX5NPYbJiaGsdtZ/chLsK5x4tp0a0AeGlJDvsq65j/29MI8O28K80iQnpyJOnJ6WTkl/KXhVv544cbmbMsn4d+MZAz+sV0WhallHsalRTJ9FHxzFmezxWpcQzoFmp1JKXUSbDZDHMzC3hucTZlNQ1cltKT35/Tl6ToIKujtYuOYFPklVTz+vI8poyMIzXRulsyo5Ii+eiW03jl6pHUNTVz7ZzV3PHeOoqr2p4/Uyml2uuPEwcQ7O/D0wu3WR1FKXUSsgrKufyV5dz/0UZ6xwTxxe3jmTUtxWUKbtCiWwEzF2/Dz9uL+yb2tzoKIsJFQ7vz1V1ncOe5ffly037OeX4p76/ejc3m/uMPlFKOERHkx61n9ub77cX8tLPE6jhKqXaqa2zmic+3cNkry9lbUcesacOZe/NYBvVwvTtWWnR7uM17K1i4cT83np5M15COW3XpVPn7eHPnuf1Y+PvxDOweyv0fbeSaOavYV1FrdTSllIu67rQkuocF8PxX2/GESQSUcnVrd5dx0Ys/8tqyPK5KT2DJPRO4fEScyy54pUW3h/vbNzsIDfDhxvG9rI7Spj5dg3n/pjE8PXko63aXM/FvP/LFhn1Wx1JKuaAAX29+O6E3a3aVsTqv9PgHKKUs0dBk45lF25jyz5+ob7Lx9o2jeeryoU41/d/J0KLbg+UWH+LrLQe4flwyYV2c9y+yiDAjPYGFd4wnOTqI3727lrvnrudQfZPV0ZRSLmbaqHiig/34x3c5VkdRSrWhoLSGqa+u4P99v5Mr0+L58s7xnN432upYHUKLbg/2+vJ8/Ly9uGZMotVR2iUpOoh5vx3L78/pyyfr9nDpP5ax40CV1bGUUi4kwNebX41L5scdJeQU6eeHUs7kq837+cXffyS36BCvXD2SZ64Y5vJXt1vTottDVdQ0Mn9NIZem9CAmxN/qOO3m6+3FXef1451fj6GytpFJLy/ns6y9VsdSSrmQaaPi8fP24u2Vu62OopQCmpptPPn5Fm76zxoSo4L4/I7TuWhod6tjdTgtuj3Uexm7qW1s5oZxyVZHOSlje0fx+e0tgyxvf28dj3+2haZmm9WxlFIuIDrYnwuHduPDNYXUNGibmlJWqqhp5FdvZPDvZXlcOzaR+beMJTHKdaYBPBFadHugxmYbb/6Uz9heUS455c5h3cICeO83Y7j+tCTmLM/jV29kUFnXaHUspZQLuGZMIlX1TXy6Xu+UKWWVnKIqJr28jJW5B3nuimE8PmkI/j6dt0BfZ9Oi2wN9veUA+yrquPF017zK3ZqfjxePXjqYZyYPZcXOg1zxyk8UlNZYHUsptyIiE0UkW0RyROT+Nl5PFJFvRWSDiCwVkbgjXg8VkT0i8o/OS31sqYkRDOgWwn9W7NLpA5WywNLsIi5/+ScO1Tfx3m/GcOWoeKsjOZwW3R5oXmYB3cMCOGtAV6ujdJjp6Qm8dUM6ByrruPyV5azdXWZ1JKXcgoh4Ay8DFwKDgBkiMuiI3f4KvGWMGQY8Djx9xOtPAN87OuuJEBF+OSaRLfsqWVdQbnUcpTzK3IwCbnwzk/jIQD697XTSkiKtjtQptOj2MEWVdXy/vZjJI3vi7eWak8sfzWl9ovno1nEE+fswY/ZKvtlywOpISrmDdCDHGJNrjGkA3gcmHbHPIOBb+/PvWr8uIqlALPBVJ2Q9IZeN6EmgnzdzMwqsjqKURzDG8OI3O/jDhxs4rXcUc387lp7hXay
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_sched()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RMSProp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
" if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
" return {'sqr_avg': sqr_avg*sqr_mom + p.grad.data**2}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):\n",
" denom = sqr_avg.sqrt().add_(eps)\n",
" p.data.addcdiv_(-lr, p.grad, denom)\n",
"\n",
"opt_func = partial(Optimizer, cbs=[average_sqr_grad,rms_prop_step],\n",
" sqr_mom=0.99, eps=1e-7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.766912</td>\n",
" <td>1.845900</td>\n",
" <td>0.402548</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.194586</td>\n",
" <td>1.510269</td>\n",
" <td>0.504459</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.869099</td>\n",
" <td>1.447939</td>\n",
" <td>0.544968</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adam"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decoupled weight_decay"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Callbacks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a callback"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ModelReseter(Callback):\n",
" def begin_train(self): self.model.reset()\n",
" def begin_validate(self): self.model.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RNNRegularizer(Callback):\n",
" def __init__(self, alpha=0., beta=0.): self.alpha,self.beta = alpha,beta\n",
"\n",
" def after_pred(self):\n",
" self.raw_out,self.out = self.pred[1],self.pred[2]\n",
" self.learn.pred = self.pred[0]\n",
"\n",
" def after_loss(self):\n",
" if not self.training: return\n",
" if self.alpha != 0.:\n",
" self.learn.loss += self.alpha * self.out[-1].float().pow(2).mean()\n",
" if self.beta != 0.:\n",
" h = self.raw_out[-1]\n",
" if len(h)>1:\n",
" self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]\n",
" ).float().pow(2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Callback ordering and exceptions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TerminateOnNaNCallback(Callback):\n",
" run_before=Recorder\n",
" def after_batch(self):\n",
" if torch.isinf(self.loss) or torch.isnan(self.loss):\n",
" raise CancelFitException"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Foundations of Deep Learning: Wrap up"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}