diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py index d794ed0..ccb5706 100644 --- a/minigpt4/runners/runner_base.py +++ b/minigpt4/runners/runner_base.py @@ -627,14 +627,14 @@ class RunnerBase: cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) - checkpoint = torch.load(cached_file, map_location=self.device, strict=False) + checkpoint = torch.load(cached_file, map_location=self.device) elif os.path.isfile(url_or_filename): - checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False) + checkpoint = torch.load(url_or_filename, map_location=self.device) else: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] - self.unwrap_dist_model(self.model).load_state_dict(state_dict) + self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) self.optimizer.load_state_dict(checkpoint["optimizer"]) if self.scaler and "scaler" in checkpoint: