change to cuda

This commit is contained in:
3luka 2024-08-16 13:35:52 +03:00
parent 79ead58faf
commit b6d44705b9

View File

@ -189,7 +189,7 @@ class MiniGPT4(MiniGPTBase):
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt = torch.load(ckpt_path, map_location="cuda")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model