mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 11:00:48 +00:00
change to cuda
This commit is contained in:
parent
79ead58faf
commit
b6d44705b9
@ -189,7 +189,7 @@ class MiniGPT4(MiniGPTBase):
|
|||||||
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||||
if ckpt_path:
|
if ckpt_path:
|
||||||
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
|
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
|
||||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
ckpt = torch.load(ckpt_path, map_location="cuda")
|
||||||
msg = model.load_state_dict(ckpt['model'], strict=False)
|
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user