diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index 667edd5..a078ca0 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -144,7 +144,7 @@ class MiniGPT4(Blip2Base): ) inputs_llama = self.llama_proj(query_output.last_hidden_state) - atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(device) return inputs_llama, atts_llama def prompt_wrap(self, img_embeds, atts_img, prompt):