From 2d2d781469a698e725da606d7e3d18fe66ae20cb Mon Sep 17 00:00:00 2001
From: unknown <913556700@qq.com>
Date: Wed, 24 May 2023 00:21:43 +0800
Subject: [PATCH] modify dataset & dataloader; modify the evaluation part
---
demo.py | 45 ++++++++++++---------
eval_configs/bindgpt4_eval.yaml | 26 +++++++++++-
minigpt4/common/config.py | 2 +
minigpt4/conversation/conversation.py | 36 +++++++++--------
minigpt4/models/bind_gpt4.py | 7 +---
temp.py | 0
train_configs/bindgpt4_stage1_pretrain.yaml | 4 +-
7 files changed, 78 insertions(+), 42 deletions(-)
create mode 100644 temp.py
diff --git a/demo.py b/demo.py
index b3659f1..d441571 100644
--- a/demo.py
+++ b/demo.py
@@ -28,8 +28,8 @@ def parse_args():
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
- "in xxx=yyy format will be merged into config file (deprecate), "
- "change to --cfg-options instead.",
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
)
args = parser.parse_args()
return args
@@ -64,6 +64,7 @@ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
+
# ========================================
# Gradio Setting
# ========================================
@@ -73,7 +74,10 @@ def gradio_reset(chat_state, img_list):
chat_state.messages = []
if img_list is not None:
img_list = []
- return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
+ interactive=False), gr.update(
+ value="Upload & Start Chat", interactive=True), chat_state, img_list
+
def upload_img(gr_img, text_input, chat_state):
if gr_img is None:
@@ -81,7 +85,9 @@ def upload_img(gr_img, text_input, chat_state):
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
- return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
+ value="Start Chatting", interactive=False), chat_state, img_list
+
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
@@ -101,33 +107,34 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
-title = """
Demo of MiniGPT-4
"""
-description = """This is the demo of MiniGPT-4. Upload your images and start chatting!
"""
-article = """


-"""
-#TODO show examples below
+title = """Demo of BindGPT-4
"""
+description = """This is the demo of BindGPT-4. Upload your images and start chatting!
"""
+# article = """


+# """
+
+# TODO show examples below
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
- gr.Markdown(article)
+ # gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
-
+
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=1,
step=1,
interactive=True,
- label="beam search numbers)",
+ label="beam search numbers",
)
-
+
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
@@ -140,14 +147,16 @@ with gr.Blocks() as demo:
with gr.Column():
chat_state = gr.State()
img_list = gr.State()
- chatbot = gr.Chatbot(label='MiniGPT-4')
+ chatbot = gr.Chatbot(label='BindGPT-4')
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
-
- upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
-
+
+ upload_button.click(upload_img, [image, text_input, chat_state],
+ [image, text_input, upload_button, chat_state, img_list])
+
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
)
- clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+ queue=False)
demo.launch(share=True, enable_queue=True)
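Note: with the eval config below filled in (in particular ckpt), the demo above is expected to be launched the same way as the upstream MiniGPT-4 demo, e.g.

    python demo.py --cfg-path eval_configs/bindgpt4_eval.yaml --gpu-id 0

The --cfg-path / --gpu-id flags come from demo.py's parse_args(); treat the exact invocation as an assumption if parse_args() diverges from upstream.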
diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml
index 3a28b04..73d24e5 100644
--- a/eval_configs/bindgpt4_eval.yaml
+++ b/eval_configs/bindgpt4_eval.yaml
@@ -1 +1,25 @@
-# TODO: Finish the eval config of ImageBindGPT4
\ No newline at end of file
+model:
+ arch: bind_gpt4
+ model_type: pretrain_vicuna
+ freeze_vit: True
+ freeze_qformer: False
+ max_txt_len: 160
+ end_sym: "###"
+ low_resource: False
+ prompt_path: "prompts/alignment.txt"
+ prompt_template: '###Human: {} ###Assistant: '
+ ckpt: '/path/to/pretrained/ckpt/'
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "imagebind_vision_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "imagebind_caption"
+
+run:
+ task: image_text_pretrain
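For reference, a rough sketch of how demo.py is expected to consume this config. The cfg.model_cfg / cfg.datasets_cfg attributes and model_cls.from_config follow the upstream MiniGPT-4 code path and are assumptions here; the processor lookup mirrors the demo.py lines shown above.

args = parse_args()                      # --cfg-path eval_configs/bindgpt4_eval.yaml
cfg = Config(args)

model_config = cfg.model_cfg             # arch: bind_gpt4, ckpt: '/path/to/pretrained/ckpt/'
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
# vis_processor_cfg.name == "imagebind_vision_eval"; a processor class must be
# registered under exactly that key (see the note at the end of this patch).
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))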
diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py
index e184b1f..39db8e0 100644
--- a/minigpt4/common/config.py
+++ b/minigpt4/common/config.py
@@ -12,6 +12,8 @@ from typing import Dict
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
+# logging.info = print
+
class Config:
def __init__(self, args):
diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 676d89f..86d8b80 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -1,16 +1,12 @@
-import argparse
-import time
-from PIL import Image
-
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
-
import dataclasses
from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any
-from minigpt4.common.registry import registry
+import torch
+from PIL import Image
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from imagebind.models.image_bind import ModalityType
class SeparatorStyle(Enum):
@@ -107,7 +103,7 @@ class StoppingCriteriaSub(StoppingCriteria):
CONV_VISION = Conversation(
- system="Give the following image:
ImageContent. "
+ system="Give the following image: ImageContent. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
messages=[],
@@ -116,6 +112,7 @@ CONV_VISION = Conversation(
sep="###",
)
+# TODO: If needed and possible, rewrite this file and re-organize the definition of components.
class Chat:
@@ -128,14 +125,18 @@ class Chat:
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
def ask(self, text, conv):
+        # NOTE: the hard-coded postfix is removed.
+        # TODO: needs to be made compatible with more modalities.
+        end_token = '<ImageHere>'
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
+ and conv.messages[-1][1][-len(end_token):] == end_token: # last message is image.
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+ # Generate an answer written by LLaMA
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
@@ -160,7 +161,7 @@ class Chat:
temperature=temperature,
)
output_token = outputs[0]
- if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it
+        if output_token[0] == 0:  # the model might output an unknown token at the beginning. remove it
output_token = output_token[1:]
if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
output_token = output_token[1:]
@@ -171,6 +172,7 @@ class Chat:
return output_text, output_token.cpu().numpy()
def upload_img(self, image, conv, img_list):
+        # Upload and encode the image, then create a new message for the Human role.
if isinstance(image, str): # is a image path
raw_image = Image.open(image).convert('RGB')
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -182,16 +184,18 @@ class Chat:
image = image.unsqueeze(0)
image = image.to(self.device)
- image_emb, _ = self.model.encode_img(image)
+ all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
+ image_emb = all_embeddings[ModalityType.VISION]
img_list.append(image_emb)
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+        conv.append_message(conv.roles[0], "<ImageHere>")
msg = "Received."
# self.conv.append_message(self.conv.roles[1], msg)
return msg
def get_context_emb(self, conv, img_list):
+        # Insert the image embeddings into the prompt at the image placeholders; img_list is a List[Tensor].
prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
+        prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.model.llama_tokenizer(
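Usage sketch for the Chat API touched above, run headless outside Gradio. 'example.jpg' and the device string are placeholders, and model / vis_processor are assumed to be built exactly as demo.py builds them.

from minigpt4.conversation.conversation import Chat, CONV_VISION

chat = Chat(model, vis_processor, device='cuda:0')

conv = CONV_VISION.copy()    # fresh conversation state, as gradio's upload_img does
img_list = []                # one embedding tensor per uploaded image

chat.upload_img('example.jpg', conv, img_list)   # path, PIL image or tensor; encodes via model.encode_inputs()
chat.ask("Describe the image.", conv)            # appends (or merges into) the Human turn
answer, _ = chat.answer(conv, img_list, num_beams=1, temperature=1.0)
print(answer)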
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
index 5ec79d7..b400c7e 100644
--- a/minigpt4/models/bind_gpt4.py
+++ b/minigpt4/models/bind_gpt4.py
@@ -47,7 +47,8 @@ class BindGPT4(BaseModel):
print('Loading Q-Former and Adapter/Projector')
self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
- vision_qformer_frozen=freeze_qformer
+ vision_qformer_frozen=freeze_qformer,
+ vision_post_dims=[768, self.llama_model.config.hidden_size]
# vision_qformer_model=q_former_model,
# vision_pre_dims=(1280, 1408)
)
@@ -66,9 +67,6 @@ class BindGPT4(BaseModel):
param.requires_grad = False
print('Loading LLAMA Done')
- # TODO: remove hard-coding
- self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
-
self.max_txt_len = max_txt_len
self.end_sym = end_sym
@@ -87,7 +85,6 @@ class BindGPT4(BaseModel):
imagebind_outputs = self.multimodal_encoder(inputs)
llama_inputs = self.multimodal_joiner(imagebind_outputs)
# NOTE: only accept image here
- llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
return llama_inputs
def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
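The net effect of the bind_gpt4.py hunks above, sketched with assumed shapes (768 is the Q-Former output width from the removed hard-coded projection; hidden_size is the LLaMA embedding width):

def encode_inputs(self, inputs):
    # e.g. inputs = {ModalityType.VISION: image}, image: (B, 3, 224, 224)
    imagebind_outputs = self.multimodal_encoder(inputs)
    llama_inputs = self.multimodal_joiner(imagebind_outputs)
    # vision_post_dims=[768, hidden_size] makes the joiner project the vision
    # queries itself, so llama_inputs[ModalityType.VISION] is assumed to come
    # back as (B, num_query_token, hidden_size) and the separate llama_proj
    # layer (and its extra projection here) can be dropped.
    return llama_inputs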
diff --git a/temp.py b/temp.py
new file mode 100644
index 0000000..e69de29
diff --git a/train_configs/bindgpt4_stage1_pretrain.yaml b/train_configs/bindgpt4_stage1_pretrain.yaml
index 5f8d0f6..1f9aad1 100644
--- a/train_configs/bindgpt4_stage1_pretrain.yaml
+++ b/train_configs/bindgpt4_stage1_pretrain.yaml
@@ -9,11 +9,11 @@ datasets:
cc12m:
vis_processor:
train:
- name: "blip2_image_train"
+ name: "imagebind_vision_train"
image_size: 224
text_processor:
train:
- name: "blip_caption"
+ name: "imagebind_caption"
sample_ratio: 115
# cc_sbu:
# vis_processor:
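The renamed processors above (imagebind_vision_train, imagebind_caption) and imagebind_vision_eval from the eval config must be registered under exactly those keys. A hypothetical registration sketch, mirroring how the BLIP-2 processors are registered; the class names are illustrative only:

from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor


@registry.register_processor("imagebind_vision_train")
class ImageBindVisionTrainProcessor(BaseProcessor):  # hypothetical name
    ...


@registry.register_processor("imagebind_caption")
class ImageBindCaptionProcessor(BaseProcessor):      # hypothetical name
    ...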