mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
fix warmup LR
LR resets every epoch during warmup to `warmup_start_lr` because it uses inner epoch step counter. We need to use `total_cur_step` to increment LR continuously from epoch to epoch.
This commit is contained in:
parent
22d8888ca2
commit
f90ae217ee
@ -80,7 +80,7 @@ class LinearWarmupCosineLRScheduler:
|
|||||||
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
||||||
if total_cur_step < self.warmup_steps:
|
if total_cur_step < self.warmup_steps:
|
||||||
warmup_lr_schedule(
|
warmup_lr_schedule(
|
||||||
step=cur_step,
|
step=total_cur_step,
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
max_step=self.warmup_steps,
|
max_step=self.warmup_steps,
|
||||||
init_lr=self.warmup_start_lr,
|
init_lr=self.warmup_start_lr,
|
||||||
|
Loading…
Reference in New Issue
Block a user