From b6d44705b931ccaeae1435eaf66efd0114b8a7be Mon Sep 17 00:00:00 2001 From: 3luka <94873742+MohamedAlaaAli@users.noreply.github.com> Date: Fri, 16 Aug 2024 13:35:52 +0300 Subject: [PATCH] change to cuda --- minigpt4/models/minigpt4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py index 7da627b..908562b 100644 --- a/minigpt4/models/minigpt4.py +++ b/minigpt4/models/minigpt4.py @@ -189,7 +189,7 @@ class MiniGPT4(MiniGPTBase): ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 if ckpt_path: print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) - ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt = torch.load(ckpt_path, map_location="cuda") msg = model.load_state_dict(ckpt['model'], strict=False) return model