Update runner_base.py

This commit is contained in:
Jun Chen 2023-04-23 15:49:37 +03:00 committed by GitHub
parent d93f715165
commit 3bd99950f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -627,14 +627,14 @@ class RunnerBase:
cached_file = download_cached_file( cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True url_or_filename, check_hash=False, progress=True
) )
checkpoint = torch.load(cached_file, map_location=self.device, strict=False) checkpoint = torch.load(cached_file, map_location=self.device)
elif os.path.isfile(url_or_filename): elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False) checkpoint = torch.load(url_or_filename, map_location=self.device)
else: else:
raise RuntimeError("checkpoint url or path is invalid") raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
self.unwrap_dist_model(self.model).load_state_dict(state_dict) self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
self.optimizer.load_state_dict(checkpoint["optimizer"]) self.optimizer.load_state_dict(checkpoint["optimizer"])
if self.scaler and "scaler" in checkpoint: if self.scaler and "scaler" in checkpoint: