16_accel_sgd: Fixed average_sqr_grad function (#342)

Co-authored-by: Kartikeya <kartkeyabhardwaj98@gmail.com>
This commit is contained in:
Kartikeya Bhardwaj 2020-11-29 19:32:32 +05:30 committed by GitHub
parent 4b4d127083
commit 2d72ffdee1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -746,7 +746,7 @@
"source": [ "source": [
"def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n", "def average_sqr_grad(p, sqr_mom, sqr_avg=None, **kwargs):\n",
" if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n", " if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n",
" return {'sqr_avg': sqr_avg*sqr_mom + p.grad.data**2}" " return {'sqr_avg': sqr_mom*sqr_avg + (1-sqr_mom)*p.grad.data**2}"
] ]
}, },
{ {
@ -1321,4 +1321,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 2
} }