Update minigpt4.py

This commit is contained in:
3luka 2024-08-16 15:33:52 +03:00
parent b6d44705b9
commit 8709e9966d

View File

@ -189,7 +189,7 @@ class MiniGPT4(MiniGPTBase):
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cuda")
ckpt = torch.load(ckpt_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model