{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"[[chapter_accel_sgd]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Variants of SGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that you know all about how the architectures are put together, it's time to start exploring the training process.\n",
"\n",
"We explained earlier the basics of stochastic gradient descent (SGD): pass a minibatch into the model, compare the model's output to our targets with the loss function, then compute the gradients of that loss with respect to each weight, before updating the weights with the formula:\n",
"\n",
"```python\n",
"new_weight = weight - lr * weight.grad\n",
"```\n",
"\n",
"We implemented this from scratch in a training loop, and also saw that PyTorch provides a simple `optim.SGD` class that does this calculation for each parameter for us. Let's now build some faster optimizers, using a flexible foundation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's start with SGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we'll create a baseline, using plain SGD, and compare it to fastai's default optimizer. We'll start by grabbing Imagenette with the same `get_data` we used in <<chapter_resnet>>:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide_input\n",
"def get_data(url, presize, resize):\n",
"    path = untar_data(url)\n",
"    return DataBlock(\n",
"        blocks=(ImageBlock, CategoryBlock), get_items=get_image_files,\n",
"        splitter=GrandparentSplitter(valid_name='val'),\n",
"        get_y=parent_label, item_tfms=Resize(presize),\n",
"        batch_tfms=[*aug_transforms(min_scale=0.5, size=resize),\n",
"                    Normalize.from_stats(*imagenet_stats)],\n",
"    ).dataloaders(path, bs=128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = get_data(URLs.IMAGENETTE_160, 160, 128)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll create a ResNet34 without pretraining, and pass along any arguments received:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_learner(**kwargs):\n",
"    return cnn_learner(dls, resnet34, pretrained=False,\n",
"                       metrics=accuracy, **kwargs).to_fp16()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's the default fastai optimizer, with the usual 3e-3 learning rate:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.571932</td>\n",
"      <td>2.685040</td>\n",
"      <td>0.322548</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>1.904674</td>\n",
"      <td>1.852589</td>\n",
"      <td>0.437452</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>1.586909</td>\n",
"      <td>1.374908</td>\n",
"      <td>0.594904</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner()\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's try plain SGD. We can pass `opt_func` (optimization function) to `cnn_learner` to get fastai to use any optimizer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = get_learner(opt_func=SGD)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first thing to look at is `lr_find`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.017378008365631102, 3.019951861915615e-07)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It looks like we'll need to use a higher learning rate than we normally use:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.969412</td>\n",
"      <td>2.214596</td>\n",
"      <td>0.242038</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>2.442730</td>\n",
"      <td>1.845950</td>\n",
"      <td>0.362548</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>2.157159</td>\n",
"      <td>1.741143</td>\n",
"      <td>0.408917</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(3, 0.03, moms=(0,0,0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Because accelerated SGD using momentum is such a good idea, fastai uses it by default in `fit_one_cycle`, so we turn it off with `moms=(0,0,0)`; we'll be learning about momentum shortly.)\n",
"\n",
"Clearly, plain SGD isn't training as fast as we'd like. So let's learn the tricks to get accelerated training!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A generic optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to build up our accelerated SGD tricks, we'll need to start with a nice, flexible optimizer foundation. No library prior to fastai provided such a foundation, but during fastai's development we realized that all the optimizer improvements we'd seen in the academic literature could be handled using *optimizer callbacks*. These are small pieces of code that we can compose together to build the optimizer `step`; they are called by fastai's `Optimizer` class. This is a small class (less than a screen of code); these are the definitions in `Optimizer` of the two key methods that we've been using in this book:\n",
"\n",
"```python\n",
"def zero_grad(self):\n",
"    for p,*_ in self.all_params():\n",
"        p.grad.detach_()\n",
"        p.grad.zero_()\n",
"\n",
"def step(self):\n",
"    for p,pg,state,hyper in self.all_params():\n",
"        for cb in self.cbs:\n",
"            state = _update(state, cb(p, **{**state, **hyper}))\n",
"        self.state[p] = state\n",
"```\n",
"\n",
"As we saw when training an MNIST model from scratch, `zero_grad` just loops through the parameters of the model and sets the gradients to zero. It also calls `detach_`, which removes any history of gradient computation, since it won't be needed after `zero_grad`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The more interesting method is `step`, which loops through the callbacks (`cbs`) and calls them to update the parameters (the `_update` function just calls `state.update` if there's anything returned by `cb(...)`). As you can see, `Optimizer` doesn't actually do any SGD steps itself."
]
},
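{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make that state handling concrete, here is a minimal sketch of what `_update` does, based on the description above (an illustration, not necessarily fastai's exact source):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _update(state, new=None):\n",
"    # A callback may return None (nothing to store) or a dict of new state\n",
"    if new is not None: state.update(new)\n",
"    return state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how we can add SGD to `Optimizer`. Here's an optimizer callback that does a single SGD step, by multiplying `-lr` by the gradients, and adding that to the parameter (when `Tensor.add_` in PyTorch is passed two parameters, they are multiplied together before the addition):"
]
},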
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sgd_cb(p, lr, **kwargs): p.data.add_(-lr, p.grad.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can pass this to `Optimizer` using the `cbs` parameter; we'll need to use `partial` since `Learner` will call this function to create our optimizer later:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[sgd_cb])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if this trains:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.730918</td>\n",
"      <td>2.009971</td>\n",
"      <td>0.332739</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>2.204893</td>\n",
"      <td>1.747202</td>\n",
"      <td>0.441529</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>1.875621</td>\n",
"      <td>1.684515</td>\n",
"      <td>0.445350</td>\n",
"      <td>00:09</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit(3, 0.03)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's working! So that's how we create SGD from scratch in fastai."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SGD is the idea of taking a step in the direction of the steepest slope at each point in time. But what if we have a ball rolling down the mountain? It won't, at each given point, exactly follow the direction of the gradient, as it will have *momentum*. A ball with more momentum (for instance, a heavier ball) will skip over little bumps and holes, and be more likely to get to the bottom of a bumpy mountain. A ping pong ball, on the other hand, will get stuck in every little crevice.\n",
"\n",
"So how could we bring this idea over to SGD? We can use a moving average, instead of only the current gradient, to make our step:\n",
"\n",
"```python\n",
"weight.avg = beta * weight.avg + (1-beta) * weight.grad\n",
"new_weight = weight - lr * weight.avg\n",
"```\n",
"\n",
"Here `beta` is some number we choose which defines how much momentum to use. If `beta` is zero, then the first equation above becomes `weight.avg = weight.grad`, so we end up with plain SGD. But if it's a number close to one, then the main direction chosen is an average of the previous steps. (If you have done a bit of statistics, you may recognize in the first equation an *exponentially weighted moving average*, which is very often used to denoise data and get the underlying tendency.)\n",
"\n",
"Note that we are writing `weight.avg` to highlight the fact that we need to store the moving average for each parameter of the model (and each parameter has its own independent moving average).\n",
"\n",
"Here is an example of noisy data for a single parameter, with the momentum curve plotted in red, and the gradients of the parameter plotted in blue. The gradients increase, and then decrease, and the momentum does a good job of following the general trend, without getting too influenced by noise:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"plt.scatter(x1,y1)\n",
"idx = x1.argsort()\n",
"beta,avg,res = 0.7,0,[]\n",
"for i in idx:\n",
"    avg = beta * avg + (1-beta) * y1[i]\n",
"    res.append(avg/(1-beta**(i+1)))\n",
"plt.plot(x1[idx],np.array(res), color='red');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It works particularly well if the loss function has narrow canyons we need to navigate: vanilla SGD would send us bouncing from one side to the other, while SGD with momentum will average those oscillations out and roll smoothly down the middle. The parameter `beta` determines the strength of the momentum we are using: with a small `beta` we stay closer to the actual gradient values, whereas with a high `beta` we will mostly go in the direction of the average of the gradients, and it will take a while before any change in the gradients makes that trend move.\n",
"\n",
"With a large `beta`, we might miss that the gradients have changed directions and roll over a small local minimum. This is a desired side effect: intuitively, when we show a new picture/text/data point to our model, it will look like something in the training set but won't be *exactly* like it. That means it will correspond to a point in the loss function that is close to the minimum we ended up with at the end of training, but not exactly *at* that minimum. So, we would rather end up training in a wide minimum, where nearby points have approximately the same loss (or if you prefer, a point where the loss is as flat as possible). Here's how the chart above varies as we change `beta`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-4, 4, 100)\n",
"y = 1 - (x/3) ** 2\n",
"x1 = x + np.random.randn(100) * 0.1\n",
"y1 = y + np.random.randn(100) * 0.1\n",
"_,axs = plt.subplots(2,2, figsize=(12,8))\n",
"betas = [0.5,0.7,0.9,0.99]\n",
"idx = x1.argsort()\n",
"for beta,ax in zip(betas, axs.flatten()):\n",
"    ax.scatter(x1,y1)\n",
"    avg,res = 0,[]\n",
"    for i in idx:\n",
"        avg = beta * avg + (1-beta) * y1[i]\n",
"        res.append(avg)#/(1-beta**(i+1)))\n",
"    ax.plot(x1[idx],np.array(res), color='red');\n",
"    ax.set_title(f'beta={beta}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see in these examples that a `beta` that's too high results in the overall changes in gradient getting ignored. In SGD with momentum, a value of `beta` that is often used is 0.9.\n",
"\n",
"`fit_one_cycle` by default starts with a `beta` of 0.95, gradually adjusts it to 0.85, and then gradually moves it back to 0.95 at the end of training. Let's see how our training goes with momentum added to plain SGD:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to add momentum to our optimizer, we'll first need to keep track of the moving average gradient, which we can do with another callback. When an optimizer callback returns a dict, it is used to update the state of the optimizer, and is passed back to the optimizer on the next step. So this callback will keep track of the gradient averages in a parameter called `grad_avg`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_grad(p, mom, grad_avg=None, **kwargs):\n",
"    if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)\n",
"    return {'grad_avg': grad_avg*mom + p.grad.data}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use it, we just have to replace `p.grad.data` with `grad_avg` in our step function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def momentum_step(p, lr, grad_avg, **kwargs): p.data.add_(-lr, grad_avg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_func = partial(Optimizer, cbs=[average_grad,momentum_step], mom=0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Learner` will automatically schedule `mom` and `lr`, so `fit_one_cycle` will even work with our custom `Optimizer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.856000</td>\n",
"      <td>2.493429</td>\n",
"      <td>0.246115</td>\n",
"      <td>00:10</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>2.504205</td>\n",
"      <td>2.463813</td>\n",
"      <td>0.348280</td>\n",
"      <td>00:10</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>2.187387</td>\n",
"      <td>1.755670</td>\n",
"      <td>0.418853</td>\n",
"      <td>00:10</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.03)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_sched()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're still not getting great results, so let's see what else we can do."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RMSProp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMSProp is another variant of SGD, introduced by Geoffrey Hinton in [Lecture 6e of his Coursera class](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). The main difference from SGD is that it uses an adaptive learning rate: instead of using the same learning rate for every parameter, each parameter gets its own specific learning rate, controlled by a global learning rate. That way we can speed up training by giving a higher learning rate to the weights that need to change a lot, while the ones that are good enough get a lower one.\n",
"\n",
"How do we decide which parameters should have a high learning rate and which should not? We can look at the gradients to get an idea: not just the one we just computed, but their history. If a parameter's gradients have been close to zero for a while, that parameter will need a higher learning rate, because the loss is very flat there. Conversely, if the gradients are all over the place, we should probably be careful and pick a low learning rate to avoid divergence. We can't just average the gradients to see if they're changing a lot, since the average of a large positive and a large negative number is close to zero. Instead we can use the usual trick of either taking the absolute values, or the squared values (and then taking the square root after the mean).\n",
"\n",
"Once again, to pick out the general tendency behind the noise, we will use a moving average, specifically the moving average of the gradients squared. Then we will update the corresponding weight using the current gradient (for the direction) divided by the square root of this moving average (that way, if the average is low, the effective learning rate will be higher, and if it's big, the effective learning rate will be lower).\n",
"\n",
"```python\n",
"w.square_avg = alpha * w.square_avg + (1-alpha) * (w.grad ** 2)\n",
"new_w = w - lr * w.grad / math.sqrt(w.square_avg + eps)\n",
"```\n",
"\n",
"The `eps` (*epsilon*) is added for numerical stability (usually set to 1e-8), and the default value for `alpha` is usually 0.99."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can add this to `Optimizer` by doing much the same thing we did for `average_grad`, but with an extra `**2`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
"    if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
"    return {'sqr_avg': sqr_avg*sqr_mom + p.grad.data**2}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we can define our step function and optimizer as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):\n",
"    denom = sqr_avg.sqrt().add_(eps)\n",
"    p.data.addcdiv_(-lr, p.grad, denom)\n",
"\n",
"opt_func = partial(Optimizer, cbs=[average_sqr_grad,rms_prop_step],\n",
"                   sqr_mom=0.99, eps=1e-7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try it out:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>2.766912</td>\n",
"      <td>1.845900</td>\n",
"      <td>0.402548</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>2.194586</td>\n",
"      <td>1.510269</td>\n",
"      <td>0.504459</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>1.869099</td>\n",
"      <td>1.447939</td>\n",
"      <td>0.544968</td>\n",
"      <td>00:11</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = get_learner(opt_func=opt_func)\n",
"learn.fit_one_cycle(3, 0.003)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Much better! Now we just have to bring these ideas together, and we have Adam, fastai's default optimizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adam"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adam mixes the ideas of SGD with momentum and RMSProp together: it uses the moving average of the gradients as a direction and divides by the square root of the moving average of the gradients squared to give an adaptive learning rate to each parameter.\n",
"\n",
"There is one other difference in how Adam calculates moving averages: it takes the *unbiased* moving average, which is:\n",
"\n",
"``` python\n",
"w.avg = beta * w.avg + (1-beta) * w.grad\n",
"unbias_avg = w.avg / (1 - (beta**(i+1)))\n",
"```\n",
"\n",
"where `i` is the index of the iteration (starting at 0, as Python does). This divisor of `1 - (beta**(i+1))` makes sure the unbiased average looks more like the gradients at the beginning (since `beta < 1`, the denominator very quickly gets close to 1).\n",
"\n",
"Putting everything together, our update step looks like:\n",
"``` python\n",
"w.avg = beta1 * w.avg + (1-beta1) * w.grad\n",
"unbias_avg = w.avg / (1 - (beta1**(i+1)))\n",
"w.sqr_avg = beta2 * w.sqr_avg + (1-beta2) * (w.grad ** 2)\n",
"new_w = w - lr * unbias_avg / sqrt(w.sqr_avg + eps)\n",
"```\n",
"\n",
"As with RMSProp, `eps` is usually set to 1e-8, and the default for `(beta1,beta2)` suggested by the literature is `(0.9,0.999)`.\n",
"\n",
"In fastai, Adam is the default optimizer we use, since it allows faster training; but we've found that `beta2=0.99` is better suited to the type of schedule we are using. `beta1` is the momentum parameter, which we specify with the argument `moms` in our call to `fit_one_cycle`. As for `eps`, fastai uses a default of 1e-5. `eps` is not just useful for numerical stability: a higher `eps` limits the maximum value of the adjusted learning rate. To take an extreme example, if `eps` is 1, then the adjusted learning rate will never be higher than the base learning rate.\n",
"\n",
"Rather than show all the code for this in the book, we'll let you look at the optimizer notebook in fastai's GitHub repository--you'll see all the code we've seen so far, along with Adam and other optimizers, and lots of examples and tests."
]
},
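{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build some intuition for that debiasing step, here is a tiny numerical sketch (plain Python, with an arbitrarily chosen `beta` and a constant dummy gradient of 0.5): the raw moving average starts near zero and only slowly approaches the true value, while the unbiased average is correct from the very first step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"beta,avg = 0.9,0.\n",
"for i in range(4):\n",
"    grad = 0.5                      # pretend every batch gives this gradient\n",
"    avg = beta*avg + (1-beta)*grad  # raw moving average, biased toward 0 at first\n",
"    unbias_avg = avg / (1 - beta**(i+1))\n",
"    print(f'step {i}: avg={avg:.4f}, unbias_avg={unbias_avg:.4f}')"
]
},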
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decoupled Weight Decay"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've discussed weight decay before, which is equivalent to (in the case of vanilla SGD) updating the parameters with:\n",
"\n",
"``` python\n",
"new_weight = weight - lr*weight.grad - lr*wd*weight\n",
"```\n",
"\n",
"This last formula explains why the name of this technique is weight decay, as each weight is decayed by a factor `lr * wd`.\n",
"\n",
"However, this only works correctly for standard SGD, because we have seen that with momentum, RMSProp or in Adam, the update has some additional formulas around the gradient. In those cases, the formula that comes from L2 regularization:\n",
|
||
|
"\n",
|
||
|
"``` python\n",
|
||
|
"weight.grad += wd*weight\n",
|
||
|
"```\n",
|
||
|
"\n",
|
||
|
"is different than weight decay:\n",
|
||
|
"\n",
|
||
|
"``` python\n",
|
||
|
"new_weight = weight - lr*weight.grad - lr*wd*weight\n",
|
||
|
"```\n",
|
||
|
"\n",
|
||
|
"Most libraries use the first formulation, but it was pointed out in [Decoupled Weight Regularization](https://arxiv.org/pdf/1711.05101.pdf) by Ilya Loshchilov and Frank Hutter, second one is the only correct approach with the Adam optimizer or momentum, which is why fastai makes it its default.\n",
|
||
|
"\n",
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now you know everything that is hidden behind the line `learn.fit_one_cycle`!"
]
},
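{
"cell_type": "markdown",
"metadata": {},
"source": [
"As one final illustration of the callback framework, true (decoupled) weight decay could be written as just one more optimizer callback that shrinks each weight before the step. This is only a sketch in the style of the callbacks above, not fastai's actual implementation, and it assumes `wd` is passed to `Optimizer` as a hyperparameter, like `lr` and `mom`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def weight_decay_cb(p, lr, wd, **kwargs):\n",
"    # Decoupled weight decay: shrink the weight directly, rather than\n",
"    # adding wd*weight to the gradient (the L2 regularization form)\n",
"    p.data.mul_(1 - lr*wd)\n",
"\n",
"opt_func = partial(Optimizer, cbs=[weight_decay_cb, average_grad, momentum_step],\n",
"                   mom=0.9, wd=0.01)"
]
},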
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What is the equation for a step of SGD, in math or code (as you prefer)?\n",
"1. What do we pass to `cnn_learner` to use a non-default optimizer?\n",
"1. What are optimizer callbacks?\n",
"1. What does `zero_grad` do in an optimizer?\n",
"1. What does `step` do in an optimizer? How is it implemented in the general optimizer?\n",
"1. Rewrite `sgd_cb` to use the `+=` operator, instead of `add_`.\n",
"1. What is momentum? Write out the equation.\n",
"1. What's a physical analogy for momentum? How does it apply in our model training settings?\n",
"1. What does a bigger value for momentum do to the gradients?\n",
"1. What are the default values of momentum for 1cycle training?\n",
"1. What is RMSProp? Write out the equation.\n",
"1. What do the squared values of the gradients indicate?\n",
"1. How does Adam differ from momentum and RMSProp?\n",
"1. Write out the equation for Adam.\n",
"1. Calculate the value of `unbias_avg` and `w.avg` for a few batches of dummy values.\n",
"1. What's the impact of having a high eps in Adam?\n",
"1. Read through the optimizer notebook in fastai's repo, and execute it.\n",
"1. In what situations do dynamic learning rate methods like Adam change the behaviour of weight decay?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Look up the \"rectified Adam\" paper and implement it using the general optimizer framework, and try it out. Search for other recent optimizers that work well in practice, and pick one to implement."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}