diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py index 7da627b..908562b 100644 --- a/minigpt4/models/minigpt4.py +++ b/minigpt4/models/minigpt4.py @@ -189,7 +189,7 @@ class MiniGPT4(MiniGPTBase): ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 if ckpt_path: print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) - ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt = torch.load(ckpt_path, map_location="cuda") msg = model.load_state_dict(ckpt['model'], strict=False) return model