diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml index b7298bb..b653eb7 100644 --- a/eval_configs/minigpt4_eval.yaml +++ b/eval_configs/minigpt4_eval.yaml @@ -5,7 +5,7 @@ model: end_sym: "###" low_resource: True prompt_template: '###Human: {} ###Assistant: ' - ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth' + ckpt: '/path/to/checkpoint/' datasets: diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml index d30d03a..eea99d3 100644 --- a/eval_configs/minigpt4_llama2_eval.yaml +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -5,7 +5,7 @@ model: end_sym: "" low_resource: True prompt_template: '[INST] {} [/INST] ' - ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth' + ckpt: '/path/to/checkpoint/' datasets: diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index 1cd2226..2a13393 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -24,7 +24,7 @@ class BaseModel(nn.Module): @property def device(self): - return list(self.parameters())[0].device + return list(self.parameters())[-1].device def load_checkpoint(self, url_or_filename): """