mirror of
https://github.com/fastai/fastbook.git
synced 2025-04-05 18:30:44 +00:00
16_accel_sgd: Fixed average_sqr_grad function (#342)
Co-authored-by: Kartikeya <kartkeyabhardwaj98@gmail.com>
This commit is contained in:
parent
4b4d127083
commit
2d72ffdee1
@ -746,7 +746,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
|
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
|
||||||
" if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
|
" if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
|
||||||
" return {'sqr_avg': sqr_avg*sqr_mom + p.grad.data**2}"
|
" return {'sqr_avg': sqr_mom*sqr_avg + (1-sqr_mom)*p.grad.data**2}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1321,4 +1321,4 @@
|
|||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 2
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user