From b76d5c5a50f3902c8806a1754d2466ca2309abee Mon Sep 17 00:00:00 2001
From: Deyao Zhu
Date: Mon, 17 Apr 2023 16:54:00 +0300
Subject: [PATCH] add argument to switch 8bit

---
 eval_configs/minigpt4_eval.yaml |  1 +
 minigpt4/models/mini_gpt4.py    | 35 ++++++++++++++++++++++++---------
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index 5ac9fad..f9e55a3 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 5c223fe..db88916 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -36,11 +36,13 @@ class MiniGPT4(Blip2Base):
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()
 
         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource
 
         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -83,10 +85,19 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
 
-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
@@ -107,18 +118,22 @@ class MiniGPT4(Blip2Base):
         else:
             self.prompt_list = []
 
-    def encode_img(self, image):
-        device = image.device
+    def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()
-        image = image.to("cpu")
 
-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")
 
         with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
@@ -216,6 +231,7 @@ class MiniGPT4(Blip2Base):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)
 
         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
@@ -236,6 +252,7 @@ class MiniGPT4(Blip2Base):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )
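
For reference, a minimal standalone sketch of the conditional loading path this patch introduces, isolated from the MiniGPT4 class. It uses only the transformers API shown in the diff; llama_path is a placeholder, and the 8-bit branch additionally requires the bitsandbytes and accelerate packages.

    # Sketch of the low_resource loading path added above (llama_path is a placeholder).
    import torch
    from transformers import LlamaForCausalLM

    def load_llama(llama_path, low_resource=False):
        if low_resource:
            # 8-bit weights roughly halve memory relative to fp16;
            # device_map="auto" lets accelerate place layers across available devices.
            return LlamaForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto",
            )
        # Plain fp16 load; the caller moves the model to the target device.
        return LlamaForCausalLM.from_pretrained(llama_path, torch_dtype=torch.float16)

Enabling the mode is then a one-line config change (low_resource: True in eval_configs/minigpt4_eval.yaml, as the first hunk shows). With the flag set, encode_img also runs the ViT and ln_vision on CPU in fp32 via vit_to_cpu() and moves the resulting image embeddings back to the input image's device.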