diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 667edd5..a078ca0 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -144,7 +144,7 @@ class MiniGPT4(Blip2Base):
             )
 
             inputs_llama = self.llama_proj(query_output.last_hidden_state)
-            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(device)
         return inputs_llama, atts_llama
 
     def prompt_wrap(self, img_embeds, atts_img, prompt):