From bc72a1adeaa0ae64ae64ecabec212233ae680388 Mon Sep 17 00:00:00 2001 From: 152334H <54623771+152334H@users.noreply.github.com> Date: Mon, 17 Apr 2023 10:30:55 +0800 Subject: [PATCH 1/4] fix dependency error --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 552fe3a..d5cfcf8 100644 --- a/environment.yml +++ b/environment.yml @@ -25,7 +25,7 @@ dependencies: - filelock==3.9.0 - fonttools==4.38.0 - frozenlist==1.3.3 - - huggingface-hub==0.12.1 + - huggingface-hub==0.13.4 - importlib-resources==5.12.0 - kiwisolver==1.4.4 - matplotlib==3.7.0 From 017fa43d033ed9af27041192aa2c5712141c2c40 Mon Sep 17 00:00:00 2001 From: 152334H <54623771+152334H@users.noreply.github.com> Date: Mon, 17 Apr 2023 13:51:39 +0800 Subject: [PATCH 2/4] disable beam search to reduce vram use --- minigpt4/conversation/conversation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index b0c2711..1a1221e 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -135,14 +135,15 @@ class Chat: conv.append_message(conv.roles[0], text) def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9, - repetition_penalty=1.0, length_penalty=1, temperature=1): + repetition_penalty=1.0, length_penalty=1, temperature=1.0): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) outputs = self.model.llama_model.generate( inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, - num_beams=num_beams, + #num_beams=num_beams, + do_sample=True, min_length=min_length, top_p=top_p, repetition_penalty=repetition_penalty, From 6a8e6e49ffd24677742c7dd67370f3a49e0df7f2 Mon Sep 17 00:00:00 2001 From: 152334H <54623771+152334H@users.noreply.github.com> Date: Mon, 17 Apr 2023 13:52:09 +0800 Subject: [PATCH 3/4] load in 8bit and run ViT-L in cpu to reduce vram use --- minigpt4/models/mini_gpt4.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index 49dd467..5c223fe 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -84,7 +84,8 @@ class MiniGPT4(Blip2Base): self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token self.llama_model = LlamaForCausalLM.from_pretrained( - llama_model, torch_dtype=torch.float16 + llama_model, torch_dtype=torch.float16, + load_in_8bit=True, device_map="auto" ) for name, param in self.llama_model.named_parameters(): param.requires_grad = False @@ -107,12 +108,17 @@ class MiniGPT4(Blip2Base): self.prompt_list = [] def encode_img(self, image): - with self.maybe_autocast(): - image_embeds = self.ln_vision(self.visual_encoder(image)) - image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( - image.device - ) + device = image.device + self.ln_vision.to("cpu") + self.ln_vision.float() + self.visual_encoder.to("cpu") + self.visual_encoder.float() + image = image.to("cpu") + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + with self.maybe_autocast(): query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, From 700f05d0794d54441873b86efe19306524a8730a Mon Sep 17 00:00:00 2001 From: 152334H <54623771+152334H@users.noreply.github.com> Date: Mon, 17 Apr 2023 14:56:18 +0800 Subject: [PATCH 4/4] set num_beam default to 1 instead of completely removing it --- demo.py | 4 ++-- minigpt4/conversation/conversation.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/demo.py b/demo.py index 8ceea3f..5445bfb 100644 --- a/demo.py +++ b/demo.py @@ -113,8 +113,8 @@ with gr.Blocks() as demo: num_beams = gr.Slider( minimum=1, - maximum=16, - value=5, + maximum=10, + value=1, step=1, interactive=True, label="beam search numbers)", diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 1a1221e..84aea9a 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -134,7 +134,7 @@ class Chat: else: conv.append_message(conv.roles[0], text) - def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9, + def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) @@ -142,7 +142,7 @@ class Chat: inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, - #num_beams=num_beams, + num_beams=num_beams, do_sample=True, min_length=min_length, top_p=top_p,