From 56a2fd6796e83d3b57b100096d501e9fe1ddf46d Mon Sep 17 00:00:00 2001 From: Jun Chen Date: Wed, 25 Oct 2023 00:39:26 -0700 Subject: [PATCH 1/4] Rename MiniGPTv2_Train .md to MiniGPTv2_Train.md --- MiniGPTv2_Train .md => MiniGPTv2_Train.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename MiniGPTv2_Train .md => MiniGPTv2_Train.md (100%) diff --git a/MiniGPTv2_Train .md b/MiniGPTv2_Train.md similarity index 100% rename from MiniGPTv2_Train .md rename to MiniGPTv2_Train.md From ada58a6f210793a82cd228c7aa2882b0b4f65f6c Mon Sep 17 00:00:00 2001 From: Deyao Zhu Date: Mon, 30 Oct 2023 21:48:25 +0300 Subject: [PATCH 2/4] fix the compatibility issue when low-resource is false in the inference mode. Small refator in the conversation.py to remove repeated codes --- environment.yml | 2 ++ minigpt4/conversation/conversation.py | 32 +++++++++------------------ minigpt4/models/minigpt_base.py | 3 ++- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/environment.yml b/environment.yml index 8f94afe..aee52f8 100644 --- a/environment.yml +++ b/environment.yml @@ -30,4 +30,6 @@ dependencies: - gradio==3.47.1 - accelerate==0.20.3 - bitsandbytes==0.37.0 + - scikit-image + - visual-genome - wandb diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 6e4a5c1..73a995e 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -151,7 +151,8 @@ class Chat: def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): conv.append_message(conv.roles[1], None) - embs = self.get_context_emb(conv, img_list) + prompt = conv.get_prompt() + embs = self.model.get_context_emb(prompt, img_list) current_max_len = embs.shape[1] + max_new_tokens if current_max_len - max_length > 0: @@ -176,8 +177,7 @@ class Chat: def answer(self, conv, img_list, **kargs): generation_dict = self.answer_prepare(conv, img_list, **kargs) - - output_token = self.model.llama_model.generate(**generation_dict)[0] + output_token = self.model_generate(**generation_dict)[0] output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) output_text = output_text.split('###')[0] # remove the stop sign '###' @@ -190,10 +190,16 @@ class Chat: generation_kwargs = self.answer_prepare(conv, img_list, **kargs) streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) generation_kwargs['streamer'] = streamer - thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) + thread = Thread(target=self.model_generate, kwargs=generation_kwargs) thread.start() return streamer + def model_generate(self, *args, **kwargs): + # for 8 bit and 16 bit compatibility + with self.model.maybe_autocast(): + output = self.model.llama_model.generate(*args, **kwargs) + return output + def encode_img(self, img_list): image = img_list[0] img_list.pop(0) @@ -218,21 +224,3 @@ class Chat: return msg - def get_context_emb(self, conv, img_list): - prompt = conv.get_prompt() - prompt_segs = prompt.split('') - assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." - seg_tokens = [ - self.model.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids - # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - print('debug device: ', self.device) - print('debug model device: ', self.model.device) - seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - - diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py index cd051ec..4ded50f 100644 --- a/minigpt4/models/minigpt_base.py +++ b/minigpt4/models/minigpt_base.py @@ -9,6 +9,8 @@ from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel from transformers import StoppingCriteria, StoppingCriteriaList +from minigpt4.conversation.conversation import StoppingCriteriaSub + class MiniGPTBase(BaseModel): @@ -314,7 +316,6 @@ class MiniGPTBase(BaseModel): embeds = self.llama_model.base_model.embed_tokens(token_ids) return embeds - @torch.no_grad() def generate( self, From 9d41619968aa761aadb4afcde514844a9676cda6 Mon Sep 17 00:00:00 2001 From: Jun Chen Date: Mon, 30 Oct 2023 13:28:39 -0700 Subject: [PATCH 3/4] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d24923d..c8a80ab 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Download the pretrained model checkpoints | MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo)| |------------------------------|------------------------------|------------------------------| -| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1jAbxUiyl04SFJMN4sF1vvUU69Etuz4qa/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) | +| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) | For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file From ca2314c206e6a5bcfc17e97ea5ecc3167583afa4 Mon Sep 17 00:00:00 2001 From: ZhuDeyao Date: Tue, 31 Oct 2023 13:59:22 +0300 Subject: [PATCH 4/4] Update stop criteria --- minigpt4/conversation/conversation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 73a995e..83c8bee 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: - if torch.all((stop == input_ids[0][-len(stop):])).item(): + if torch.all(input_ids[:, -len(stop):] == stop).item(): return True return False