mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
Update runner_base.py
This commit is contained in:
parent
d93f715165
commit
3bd99950f0
@ -627,14 +627,14 @@ class RunnerBase:
|
|||||||
cached_file = download_cached_file(
|
cached_file = download_cached_file(
|
||||||
url_or_filename, check_hash=False, progress=True
|
url_or_filename, check_hash=False, progress=True
|
||||||
)
|
)
|
||||||
checkpoint = torch.load(cached_file, map_location=self.device, strict=False)
|
checkpoint = torch.load(cached_file, map_location=self.device)
|
||||||
elif os.path.isfile(url_or_filename):
|
elif os.path.isfile(url_or_filename):
|
||||||
checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False)
|
checkpoint = torch.load(url_or_filename, map_location=self.device)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("checkpoint url or path is invalid")
|
raise RuntimeError("checkpoint url or path is invalid")
|
||||||
|
|
||||||
state_dict = checkpoint["model"]
|
state_dict = checkpoint["model"]
|
||||||
self.unwrap_dist_model(self.model).load_state_dict(state_dict)
|
self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
|
||||||
|
|
||||||
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
if self.scaler and "scaler" in checkpoint:
|
if self.scaler and "scaler" in checkpoint:
|
||||||
|
Loading…
Reference in New Issue
Block a user