modify dataset & dataloader; modify the evaluation part

unknown 2023-05-24 00:21:43 +08:00
parent d1527dd924
commit 2d2d781469
7 changed files with 78 additions and 42 deletions

45
demo.py
View File

@@ -28,8 +28,8 @@ def parse_args():
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
args = parser.parse_args()
return args
@@ -64,6 +64,7 @@ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================
@@ -73,7 +74,10 @@ def gradio_reset(chat_state, img_list):
chat_state.messages = []
if img_list is not None:
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
interactive=False), gr.update(
value="Upload & Start Chat", interactive=True), chat_state, img_list
def upload_img(gr_img, text_input, chat_state):
if gr_img is None:
@@ -81,7 +85,9 @@ def upload_img(gr_img, text_input, chat_state):
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
value="Start Chatting", interactive=False), chat_state, img_list
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
@@ -101,33 +107,34 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
"""
#TODO show examples below
title = """<h1 align="center">Demo of BindGPT-4</h1>"""
description = """<h3>This is the demo of BindGPT-4. Upload your images and start chatting!</h3>"""
# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
# """
# TODO show examples below
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
gr.Markdown(article)
# gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=1,
step=1,
interactive=True,
label="beam search numbers)",
label="beam search numbers",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
@@ -140,14 +147,16 @@ with gr.Blocks() as demo:
with gr.Column():
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='MiniGPT-4')
chatbot = gr.Chatbot(label='BindGPT-4')
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
upload_button.click(upload_img, [image, text_input, chat_state],
[image, text_input, upload_button, chat_state, img_list])
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
queue=False)
demo.launch(share=True, enable_queue=True)
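For reference, the Chat object wired into this Gradio demo can also be driven directly. A minimal sketch using only the methods visible in this diff (the import path and the image file name are assumptions):

from minigpt4.conversation.conversation import Chat, CONV_VISION  # import path assumed

chat = Chat(model, vis_processor, device='cuda:0')
chat_state, img_list = CONV_VISION.copy(), []
chat.upload_img('example.jpg', chat_state, img_list)   # encodes the image into img_list
chat.ask('Describe this image.', chat_state)
llm_message, _ = chat.answer(chat_state, img_list, num_beams=1, temperature=1.0)
print(llm_message)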

View File

@@ -1 +1,25 @@
# TODO: Finish the eval config of ImageBindGPT4
model:
arch: bind_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: False
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/pretrained/ckpt/'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "imagebind_vision_eval"
image_size: 224
text_processor:
train:
name: "imagebind_caption"
run:
task: image_text_pretrain
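A rough sketch of how an eval script might consume this config, assuming the same OmegaConf/registry pattern already used in demo.py (the YAML path and device are placeholders):

from omegaconf import OmegaConf
from minigpt4.common.registry import registry

cfg = OmegaConf.load('eval_configs/bindgpt4_eval.yaml')  # path is an assumption
model_cls = registry.get_model_class(cfg.model.arch)     # resolves "bind_gpt4"
model = model_cls.from_config(cfg.model).to('cuda:0').eval()

vis_cfg = cfg.datasets.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_cfg.name).from_config(vis_cfg)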

View File

@@ -12,6 +12,8 @@ from typing import Dict
from omegaconf import OmegaConf
from minigpt4.common.registry import registry
# logging.info = print
class Config:
def __init__(self, args):

View File

@@ -1,16 +1,12 @@
import argparse
import time
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
from typing import List, Any
from minigpt4.common.registry import registry
import torch
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList
from imagebind.models.image_bind import ModalityType
class SeparatorStyle(Enum):
@@ -107,7 +103,7 @@ class StoppingCriteriaSub(StoppingCriteria):
CONV_VISION = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
system="Give the following image: <Vision>ImageContent</Vision>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
messages=[],
@@ -116,6 +112,7 @@ CONV_VISION = Conversation(
sep="###",
)
# TODO: If needed and possible, rewrite this file and re-organize the definition of components.
class Chat:
@@ -128,14 +125,18 @@ class Chat:
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
def ask(self, text, conv):
# NOTE: the hard code for postfix is removed.
# TODO: Need to be compatible with more modalities.
end_token = '</Vision>'
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
and conv.messages[-1][1][-len(end_token):] == end_token: # last message is image.
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
# Generate an answer written by LLaMA
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
@@ -160,7 +161,7 @@ class Chat:
temperature=temperature,
)
output_token = outputs[0]
if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
if output_token[0] == 0: # the model might output a unknown token <unk> at the beginning. remove it
output_token = output_token[1:]
if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
output_token = output_token[1:]
@@ -171,6 +172,7 @@ class Chat:
return output_text, output_token.cpu().numpy()
def upload_img(self, image, conv, img_list):
# Upload Image, Encode Image and Create a new message from human.
if isinstance(image, str): # is a image path
raw_image = Image.open(image).convert('RGB')
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -182,16 +184,18 @@ class Chat:
image = image.unsqueeze(0)
image = image.to(self.device)
image_emb, _ = self.model.encode_img(image)
all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
image_emb = all_embeddings[ModalityType.VISION]
img_list.append(image_emb)
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
conv.append_message(conv.roles[0], "<Vision><VisionHere></Vision>")
msg = "Received."
# self.conv.append_message(self.conv.roles[1], msg)
return msg
def get_context_emb(self, conv, img_list):
# Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor]
prompt = conv.get_prompt()
prompt_segs = prompt.split('<ImageHere>')
prompt_segs = prompt.split('<VisionHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.model.llama_tokenizer(
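get_context_emb interleaves the tokenized text segments with the vision embeddings at each '<VisionHere>' placeholder. A simplified sketch of that pattern (not the verbatim function body; tokenizer arguments and tensor shapes are assumptions):

import torch

def interleave_prompt(prompt, img_list, tokenizer, embed_tokens, device):
    segs = prompt.split('<VisionHere>')
    # One vision embedding is expected per placeholder.
    assert len(segs) == len(img_list) + 1
    seg_embs = []
    for i, seg in enumerate(segs):
        ids = tokenizer(seg, return_tensors='pt',
                        add_special_tokens=(i == 0)).input_ids.to(device)
        seg_embs.append(embed_tokens(ids))                 # [1, seq_len, hidden]
    # text, vision, text, vision, ..., text
    mixed = [t for pair in zip(seg_embs[:-1], img_list) for t in pair]
    mixed.append(seg_embs[-1])
    return torch.cat(mixed, dim=1)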

View File

@@ -47,7 +47,8 @@ class BindGPT4(BaseModel):
print('Loading Q-Former and Adapter/Projector')
self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
vision_qformer_frozen=freeze_qformer
vision_qformer_frozen=freeze_qformer,
vision_post_dims=[768, self.llama_model.config.hidden_size]
# vision_qformer_model=q_former_model,
# vision_pre_dims=(1280, 1408)
)
@@ -66,9 +67,6 @@ class BindGPT4(BaseModel):
param.requires_grad = False
print('Loading LLAMA Done')
# TODO: remove hard-coding
self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
@@ -87,7 +85,6 @@ class BindGPT4(BaseModel):
imagebind_outputs = self.multimodal_encoder(inputs)
llama_inputs = self.multimodal_joiner(imagebind_outputs)
# NOTE: only accept image here
llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
return llama_inputs
def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
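Net effect of these hunks: the 768-to-LLaMA-hidden-size projection now lives inside ImageBindJoiner (via vision_post_dims) instead of a separate self.llama_proj, so the vision features returned by encode_inputs are already in LLaMA's embedding space. A sketch of the call site, with the tensor shapes as assumptions:

from imagebind.models.image_bind import ModalityType

outputs = model.encode_inputs({ModalityType.VISION: image})  # image: [B, 3, 224, 224]
vision_emb = outputs[ModalityType.VISION]                    # [B, num_query_token, llama_hidden]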

0
temp.py Normal file
View File

View File

@@ -9,11 +9,11 @@ datasets:
cc12m:
vis_processor:
train:
name: "blip2_image_train"
name: "imagebind_vision_train"
image_size: 224
text_processor:
train:
name: "blip_caption"
name: "imagebind_caption"
sample_ratio: 115
# cc_sbu:
# vis_processor:
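The renamed processors have to resolve through the registry when the dataset is built; a sketch of that lookup, mirroring the registry call already used in demo.py (the helper function itself is illustrative, not part of the repo):

from minigpt4.common.registry import registry

def build_train_processors(ds_cfg):
    vis_cfg = ds_cfg.vis_processor.train    # name: "imagebind_vision_train"
    txt_cfg = ds_cfg.text_processor.train   # name: "imagebind_caption"
    vis_proc = registry.get_processor_class(vis_cfg.name).from_config(vis_cfg)
    txt_proc = registry.get_processor_class(txt_cfg.name).from_config(txt_cfg)
    return vis_proc, txt_proc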