From 2d2d781469a698e725da606d7e3d18fe66ae20cb Mon Sep 17 00:00:00 2001 From: unknown <913556700@qq.com> Date: Wed, 24 May 2023 00:21:43 +0800 Subject: [PATCH] modify dataset & dataloader; modify the evaluation part --- demo.py | 45 ++++++++++++--------- eval_configs/bindgpt4_eval.yaml | 26 +++++++++++- minigpt4/common/config.py | 2 + minigpt4/conversation/conversation.py | 36 +++++++++-------- minigpt4/models/bind_gpt4.py | 7 +--- temp.py | 0 train_configs/bindgpt4_stage1_pretrain.yaml | 4 +- 7 files changed, 78 insertions(+), 42 deletions(-) create mode 100644 temp.py diff --git a/demo.py b/demo.py index b3659f1..d441571 100644 --- a/demo.py +++ b/demo.py @@ -28,8 +28,8 @@ def parse_args(): "--options", nargs="+", help="override some settings in the used config, the key-value pair " - "in xxx=yyy format will be merged into config file (deprecate), " - "change to --cfg-options instead.", + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", ) args = parser.parse_args() return args @@ -64,6 +64,7 @@ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) print('Initialization Finished') + # ======================================== # Gradio Setting # ======================================== @@ -73,7 +74,10 @@ def gradio_reset(chat_state, img_list): chat_state.messages = [] if img_list is not None: img_list = [] - return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list + return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', + interactive=False), gr.update( + value="Upload & Start Chat", interactive=True), chat_state, img_list + def upload_img(gr_img, text_input, chat_state): if gr_img is None: @@ -81,7 +85,9 @@ def upload_img(gr_img, text_input, chat_state): chat_state = CONV_VISION.copy() img_list = [] llm_message = chat.upload_img(gr_img, chat_state, img_list) - return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update( + value="Start Chatting", interactive=False), chat_state, img_list + def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: @@ -101,33 +107,34 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): chatbot[-1][1] = llm_message return chatbot, chat_state, img_list -title = """
-Demo of MiniGPT-4
-"""
-description = """
-This is the demo of MiniGPT-4. Upload your images and start chatting!
-"""
-article = """
-"""
-#TODO show examples below
+title = """
+Demo of BindGPT-4
+"""
+description = """
+This is the demo of BindGPT-4. Upload your images and start chatting!
+"""
+# article = """
+# """ + +# TODO show examples below with gr.Blocks() as demo: gr.Markdown(title) gr.Markdown(description) - gr.Markdown(article) + # gr.Markdown(article) with gr.Row(): with gr.Column(scale=0.5): image = gr.Image(type="pil") upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") clear = gr.Button("Restart") - + num_beams = gr.Slider( minimum=1, maximum=10, value=1, step=1, interactive=True, - label="beam search numbers)", + label="beam search numbers", ) - + temperature = gr.Slider( minimum=0.1, maximum=2.0, @@ -140,14 +147,16 @@ with gr.Blocks() as demo: with gr.Column(): chat_state = gr.State() img_list = gr.State() - chatbot = gr.Chatbot(label='MiniGPT-4') + chatbot = gr.Chatbot(label='BindGPT-4') text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False) - - upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list]) - + + upload_button.click(upload_img, [image, text_input, chat_state], + [image, text_input, upload_button, chat_state, img_list]) + text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] ) - clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False) + clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], + queue=False) demo.launch(share=True, enable_queue=True) diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml index 3a28b04..73d24e5 100644 --- a/eval_configs/bindgpt4_eval.yaml +++ b/eval_configs/bindgpt4_eval.yaml @@ -1 +1,25 @@ -# TODO: Finish the eval config of ImageBindGPT4 \ No newline at end of file +model: + arch: bind_gpt4 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: False + max_txt_len: 160 + end_sym: "###" + low_resource: False + prompt_path: "prompts/alignment.txt" + prompt_template: '###Human: {} ###Assistant: ' + ckpt: '/path/to/pretrained/ckpt/' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "imagebind_vision_eval" + image_size: 224 + text_processor: + train: + name: "imagebind_caption" + +run: + task: image_text_pretrain diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py index e184b1f..39db8e0 100644 --- a/minigpt4/common/config.py +++ b/minigpt4/common/config.py @@ -12,6 +12,8 @@ from typing import Dict from omegaconf import OmegaConf from minigpt4.common.registry import registry +# logging.info = print + class Config: def __init__(self, args): diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 676d89f..86d8b80 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -1,16 +1,12 @@ -import argparse -import time -from PIL import Image - -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer -from transformers import StoppingCriteria, StoppingCriteriaList - import dataclasses from enum import auto, Enum -from typing import List, Tuple, Any +from typing import List, Any -from minigpt4.common.registry import registry +import torch +from PIL import Image +from transformers import StoppingCriteria, StoppingCriteriaList + +from imagebind.models.image_bind import ModalityType class SeparatorStyle(Enum): @@ -107,7 +103,7 @@ class 
StoppingCriteriaSub(StoppingCriteria): CONV_VISION = Conversation( - system="Give the following image: ImageContent. " + system="Give the following image: ImageContent. " "You will be able to see the image once I provide it to you. Please answer my questions.", roles=("Human", "Assistant"), messages=[], @@ -116,6 +112,7 @@ CONV_VISION = Conversation( sep="###", ) +# TODO: If needed and possible, rewrite this file and re-organize the definition of components. class Chat: @@ -128,14 +125,18 @@ class Chat: self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) def ask(self, text, conv): + # NOTE: the hard code for postfix is removed. + # TODO: Need to be compatible with more modalities. + end_token = '' if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ - and conv.messages[-1][1][-6:] == '': # last message is image. + and conv.messages[-1][1][-len(end_token):] == end_token: # last message is image. conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) else: conv.append_message(conv.roles[0], text) def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + # Generate an answer written by LLaMA conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) @@ -160,7 +161,7 @@ class Chat: temperature=temperature, ) output_token = outputs[0] - if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it + if output_token[0] == 0: # the model might output a unknown token at the beginning. remove it output_token = output_token[1:] if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it output_token = output_token[1:] @@ -171,6 +172,7 @@ class Chat: return output_text, output_token.cpu().numpy() def upload_img(self, image, conv, img_list): + # Upload Image, Encode Image and Create a new message from human. if isinstance(image, str): # is a image path raw_image = Image.open(image).convert('RGB') image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) @@ -182,16 +184,18 @@ class Chat: image = image.unsqueeze(0) image = image.to(self.device) - image_emb, _ = self.model.encode_img(image) + all_embeddings = self.model.encode_inputs({ModalityType.VISION: image}) + image_emb = all_embeddings[ModalityType.VISION] img_list.append(image_emb) - conv.append_message(conv.roles[0], "") + conv.append_message(conv.roles[0], "") msg = "Received." # self.conv.append_message(self.conv.roles[1], msg) return msg def get_context_emb(self, conv, img_list): + # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor] prompt = conv.get_prompt() - prompt_segs = prompt.split('') + prompt_segs = prompt.split('') assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
seg_tokens = [ self.model.llama_tokenizer( diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py index 5ec79d7..b400c7e 100644 --- a/minigpt4/models/bind_gpt4.py +++ b/minigpt4/models/bind_gpt4.py @@ -47,7 +47,8 @@ class BindGPT4(BaseModel): print('Loading Q-Former and Adapter/Projector') self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token, - vision_qformer_frozen=freeze_qformer + vision_qformer_frozen=freeze_qformer, + vision_post_dims=[768, self.llama_model.config.hidden_size] # vision_qformer_model=q_former_model, # vision_pre_dims=(1280, 1408) ) @@ -66,9 +67,6 @@ class BindGPT4(BaseModel): param.requires_grad = False print('Loading LLAMA Done') - # TODO: remove hard-coding - self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size) - self.max_txt_len = max_txt_len self.end_sym = end_sym @@ -87,7 +85,6 @@ class BindGPT4(BaseModel): imagebind_outputs = self.multimodal_encoder(inputs) llama_inputs = self.multimodal_joiner(imagebind_outputs) # NOTE: only accept image here - llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION]) return llama_inputs def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]: diff --git a/temp.py b/temp.py new file mode 100644 index 0000000..e69de29 diff --git a/train_configs/bindgpt4_stage1_pretrain.yaml b/train_configs/bindgpt4_stage1_pretrain.yaml index 5f8d0f6..1f9aad1 100644 --- a/train_configs/bindgpt4_stage1_pretrain.yaml +++ b/train_configs/bindgpt4_stage1_pretrain.yaml @@ -9,11 +9,11 @@ datasets: cc12m: vis_processor: train: - name: "blip2_image_train" + name: "imagebind_vision_train" image_size: 224 text_processor: train: - name: "blip_caption" + name: "imagebind_caption" sample_ratio: 115 # cc_sbu: # vis_processor:
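
For reference, a minimal usage sketch of the reworked image path this patch introduces in Chat.upload_img and BindGPT4.encode_inputs. It assumes "model" is a BindGPT4 instance built from eval_configs/bindgpt4_eval.yaml and "vis_processor" is the "imagebind_vision_eval" processor named there; the helper name embed_image, the image path argument, and the device string are illustrative placeholders, not part of the patch.

import torch
from PIL import Image

from imagebind.models.image_bind import ModalityType


def embed_image(model, vis_processor, image_path, device="cuda:0"):
    # Mirrors the new Chat.upload_img flow: preprocess the image, then let
    # BindGPT4.encode_inputs (via the ImageBindJoiner) return LLaMA-ready embeddings.
    raw_image = Image.open(image_path).convert("RGB")
    image = vis_processor(raw_image).unsqueeze(0).to(device)  # (1, C, H, W)
    with torch.no_grad():
        all_embeddings = model.encode_inputs({ModalityType.VISION: image})
    # The joiner now projects to the LLaMA hidden size (vision_post_dims),
    # so the old standalone llama_proj layer removed by this patch is not applied here.
    return all_embeddings[ModalityType.VISION]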