From 6a8e6e49ffd24677742c7dd67370f3a49e0df7f2 Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Mon, 17 Apr 2023 13:52:09 +0800
Subject: [PATCH] load in 8bit and run ViT-L in cpu to reduce vram use

---
 minigpt4/models/mini_gpt4.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 49dd467..5c223fe 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -84,7 +84,8 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
 
         self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16
+            llama_model, torch_dtype=torch.float16,
+            load_in_8bit=True, device_map="auto"
         )
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
@@ -107,12 +108,17 @@ class MiniGPT4(Blip2Base):
             self.prompt_list = []
 
     def encode_img(self, image):
-        with self.maybe_autocast():
-            image_embeds = self.ln_vision(self.visual_encoder(image))
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
-                image.device
-            )
+        device = image.device
+        self.ln_vision.to("cpu")
+        self.ln_vision.float()
+        self.visual_encoder.to("cpu")
+        self.visual_encoder.float()
+        image = image.to("cpu")
+        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+        with self.maybe_autocast():
 
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
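
For reference, the load_in_8bit=True / device_map="auto" path used in the first hunk
relies on bitsandbytes and accelerate being installed alongside transformers. The
sketch below shows the same loading pattern in isolation; it is not part of the
patch, and "path/to/llama-7b" is a placeholder checkpoint path.

    # Minimal sketch of the 8-bit loading pattern used in the patch above.
    # Assumes transformers >= 4.28 with bitsandbytes and accelerate installed.
    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-7b", use_fast=False)
    model = LlamaForCausalLM.from_pretrained(
        "path/to/llama-7b",
        torch_dtype=torch.float16,   # non-quantized layers kept in fp16
        load_in_8bit=True,           # weights quantized to int8 via bitsandbytes
        device_map="auto",           # accelerate places layers across available devices
    )

The second hunk trades speed for VRAM: it keeps the ViT-L visual encoder and its
LayerNorm on CPU in float32, runs the image through them there, and only moves the
resulting embeddings back to the original device before the Q-Former runs under
autocast on the GPU.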