Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-06 19:10:45 +00:00)
Commit 2d2d781469 (parent d1527dd924)

modify dataset & dataloader; modify the evaluation part
demo.py (31 changes)
@@ -64,6 +64,7 @@ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config
 chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
 print('Initialization Finished')
 
+
 # ========================================
 # Gradio Setting
 # ========================================
@@ -73,7 +74,10 @@ def gradio_reset(chat_state, img_list):
         chat_state.messages = []
     if img_list is not None:
         img_list = []
-    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
+                                                                    interactive=False), gr.update(
+        value="Upload & Start Chat", interactive=True), chat_state, img_list
+
 
 def upload_img(gr_img, text_input, chat_state):
     if gr_img is None:
@@ -81,7 +85,9 @@ def upload_img(gr_img, text_input, chat_state):
     chat_state = CONV_VISION.copy()
     img_list = []
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
-    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
+        value="Start Chatting", interactive=False), chat_state, img_list
+
 
 def gradio_ask(user_message, chatbot, chat_state):
     if len(user_message) == 0:
@@ -101,17 +107,18 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list
 
-title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
-description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
-article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
-"""
+title = """<h1 align="center">Demo of BindGPT-4</h1>"""
+description = """<h3>This is the demo of BindGPT-4. Upload your images and start chatting!</h3>"""
+# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
+# """
+
 
 # TODO show examples below
 
 with gr.Blocks() as demo:
     gr.Markdown(title)
     gr.Markdown(description)
-    gr.Markdown(article)
+    # gr.Markdown(article)
 
     with gr.Row():
         with gr.Column(scale=0.5):
@@ -125,7 +132,7 @@ with gr.Blocks() as demo:
                 value=1,
                 step=1,
                 interactive=True,
-                label="beam search numbers)",
+                label="beam search numbers",
             )
 
             temperature = gr.Slider(
@@ -140,14 +147,16 @@ with gr.Blocks() as demo:
         with gr.Column():
             chat_state = gr.State()
            img_list = gr.State()
-            chatbot = gr.Chatbot(label='MiniGPT-4')
+            chatbot = gr.Chatbot(label='BindGPT-4')
             text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
 
-    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
+    upload_button.click(upload_img, [image, text_input, chat_state],
+                        [image, text_input, upload_button, chat_state, img_list])
 
     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
         gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
     )
-    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
+    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+                queue=False)
 
 demo.launch(share=True, enable_queue=True)
@@ -1 +1,25 @@
-# TODO: Finish the eval config of ImageBindGPT4
+model:
+  arch: bind_gpt4
+  model_type: pretrain_vicuna
+  freeze_vit: True
+  freeze_qformer: False
+  max_txt_len: 160
+  end_sym: "###"
+  low_resource: False
+  prompt_path: "prompts/alignment.txt"
+  prompt_template: '###Human: {} ###Assistant: '
+  ckpt: '/path/to/pretrained/ckpt/'
+
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "imagebind_vision_eval"
+        image_size: 224
+    text_processor:
+      train:
+        name: "imagebind_caption"
+
+run:
+  task: image_text_pretrain
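The hunk above fills in the previously stubbed eval config. For orientation, a config shaped like this is consumed through OmegaConf and the registry, mirroring the pattern already visible at the top of demo.py (`registry.get_processor_class(...).from_config(...)`). Below is a minimal sketch, assuming the file is saved at a path such as eval_configs/bindgpt4_eval.yaml; the actual filename is not shown in this diff.

from omegaconf import OmegaConf
from minigpt4.common.registry import registry

# Assumed path; the diff does not name the new config file.
cfg = OmegaConf.load("eval_configs/bindgpt4_eval.yaml")

# Build the eval-time visual processor from the new "imagebind_vision_eval" entry,
# following the registry pattern used in demo.py.
vis_cfg = cfg.datasets.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_cfg.name).from_config(vis_cfg)

print(cfg.model.arch)        # bind_gpt4
print(vis_cfg.image_size)    # 224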
@@ -12,6 +12,8 @@ from typing import Dict
 from omegaconf import OmegaConf
 from minigpt4.common.registry import registry
 
+# logging.info = print
+
 
 class Config:
     def __init__(self, args):
@@ -1,16 +1,12 @@
-import argparse
-import time
-from PIL import Image
-
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
-
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any
 
-from minigpt4.common.registry import registry
+import torch
+from PIL import Image
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from imagebind.models.image_bind import ModalityType
 
 
 class SeparatorStyle(Enum):
@@ -107,7 +103,7 @@ class StoppingCriteriaSub(StoppingCriteria):
 
 
 CONV_VISION = Conversation(
-    system="Give the following image: <Img>ImageContent</Img>. "
+    system="Give the following image: <Vision>ImageContent</Vision>. "
            "You will be able to see the image once I provide it to you. Please answer my questions.",
     roles=("Human", "Assistant"),
     messages=[],
@@ -116,6 +112,7 @@ CONV_VISION = Conversation(
     sep="###",
 )
 
+# TODO: If needed and possible, rewrite this file and re-organize the definition of components.
 
 
 class Chat:
@@ -128,14 +125,18 @@ class Chat:
         self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
 
     def ask(self, text, conv):
+        # NOTE: the hard code for postfix is removed.
+        # TODO: Need to be compatible with more modalities.
+        end_token = '</Vision>'
         if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
+                and conv.messages[-1][1][-len(end_token):] == end_token:  # last message is image.
             conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
         else:
             conv.append_message(conv.roles[0], text)
 
     def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+        # Generate an answer written by LLaMA
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
 
@@ -160,7 +161,7 @@ class Chat:
             temperature=temperature,
         )
         output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
+        if output_token[0] == 0:  # the model might output a unknown token <unk> at the beginning. remove it
             output_token = output_token[1:]
         if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
             output_token = output_token[1:]
@@ -171,6 +172,7 @@ class Chat:
         return output_text, output_token.cpu().numpy()
 
     def upload_img(self, image, conv, img_list):
+        # Upload Image, Encode Image and Create a new message from human.
         if isinstance(image, str):  # is a image path
             raw_image = Image.open(image).convert('RGB')
             image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -182,16 +184,18 @@ class Chat:
             image = image.unsqueeze(0)
             image = image.to(self.device)
 
-        image_emb, _ = self.model.encode_img(image)
+        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
+        image_emb = all_embeddings[ModalityType.VISION]
         img_list.append(image_emb)
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+        conv.append_message(conv.roles[0], "<Vision><VisionHere></Vision>")
         msg = "Received."
         # self.conv.append_message(self.conv.roles[1], msg)
         return msg
 
     def get_context_emb(self, conv, img_list):
+        # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor]
         prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
+        prompt_segs = prompt.split('<VisionHere>')
         assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
         seg_tokens = [
             self.model.llama_tokenizer(
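Taken together, the conversation changes above rename the image placeholders from <Img>/<ImageHere> to <Vision>/<VisionHere> and route encoding through `encode_inputs` keyed by `ModalityType.VISION`. The interleaving that `get_context_emb` performs on the renamed placeholder can be illustrated in isolation; the snippet below is a self-contained sketch with a dummy text embedder standing in for `self.model.llama_tokenizer` plus the LLaMA embedding table, not the project's actual implementation.

import torch

def interleave_vision(prompt, img_embs, embed_text):
    # Split the prompt on <VisionHere> and weave image embeddings between text segments.
    segs = prompt.split('<VisionHere>')
    assert len(segs) == len(img_embs) + 1, "Unmatched numbers of image placeholders and images."
    seg_embs = [embed_text(s) for s in segs]
    mixed = [seg_embs[0]]
    for img_emb, seg_emb in zip(img_embs, seg_embs[1:]):
        mixed += [img_emb, seg_emb]
    return torch.cat(mixed, dim=1)

# Toy usage: 4-dim embeddings, one "token" per character stands in for a real tokenizer.
embed_text = lambda s: torch.zeros(1, max(len(s), 1), 4)
image_emb = torch.randn(1, 32, 4)  # e.g. 32 query tokens produced by the joiner
prompt = "###Human: <Vision><VisionHere></Vision> Describe the image.###Assistant: "
print(interleave_vision(prompt, [image_emb], embed_text).shape)  # torch.Size([1, N, 4])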
@@ -47,7 +47,8 @@ class BindGPT4(BaseModel):
 
         print('Loading Q-Former and Adapter/Projector')
         self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
-                                                 vision_qformer_frozen=freeze_qformer
+                                                 vision_qformer_frozen=freeze_qformer,
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
                                                  # vision_qformer_model=q_former_model,
                                                  # vision_pre_dims=(1280, 1408)
                                                  )
@@ -66,9 +67,6 @@ class BindGPT4(BaseModel):
             param.requires_grad = False
         print('Loading LLAMA Done')
 
-        # TODO: remove hard-coding
-        self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
-
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
@@ -87,7 +85,6 @@ class BindGPT4(BaseModel):
         imagebind_outputs = self.multimodal_encoder(inputs)
         llama_inputs = self.multimodal_joiner(imagebind_outputs)
         # NOTE: only accept image here
-        llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
         return llama_inputs
 
     def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
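The deleted `self.llama_proj` and the new `vision_post_dims=[768, self.llama_model.config.hidden_size]` argument indicate that the projection to the LLaMA hidden size now lives inside `ImageBindJoiner` instead of being hard-coded in the model. The joiner's internals are not part of this diff, so the class below is only a simplified stand-in that shows the configuration pattern, not the project's implementation.

import torch
from torch import nn

class ToyJoiner(nn.Module):
    """Simplified stand-in for ImageBindJoiner: Q-Former-style vision tokens
    followed by an optional post-projection described by vision_post_dims."""
    def __init__(self, vision_query_token_num, vision_post_dims=None):
        super().__init__()
        self.vision_query_token_num = vision_query_token_num
        # The post-projection is owned by the joiner, so callers no longer need
        # a separate llama_proj layer after encode_inputs().
        self.vision_post = (nn.Linear(*vision_post_dims)
                            if vision_post_dims is not None else nn.Identity())

    def forward(self, vision_feats):
        # vision_feats: (batch, num_query_tokens, 768) pretend Q-Former output
        return self.vision_post(vision_feats)

joiner = ToyJoiner(vision_query_token_num=32, vision_post_dims=[768, 4096])
print(joiner(torch.randn(2, 32, 768)).shape)  # torch.Size([2, 32, 4096])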
@@ -9,11 +9,11 @@ datasets:
   cc12m:
     vis_processor:
       train:
-        name: "blip2_image_train"
+        name: "imagebind_vision_train"
         image_size: 224
     text_processor:
       train:
-        name: "blip_caption"
+        name: "imagebind_caption"
     sample_ratio: 115
 # cc_sbu:
 #   vis_processor:
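The renamed processor entries above ("imagebind_vision_train", "imagebind_caption") only resolve if processor classes are registered under those names. MiniGPT-4 binds a processor to its config name with a registry decorator; the sketch below is schematic only, with an illustrative name so it does not clash with the real registration, and the actual ImageBind processors are defined elsewhere in the repository rather than in this diff.

from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor

@registry.register_processor("imagebind_vision_train_example")  # illustrative name, not the real one
class ImageBindVisionTrainExample(BaseProcessor):
    def __init__(self, image_size=224):
        super().__init__()
        self.image_size = image_size

    @classmethod
    def from_config(cls, cfg=None):
        # from_config is what registry.get_processor_class(name).from_config(cfg) calls.
        image_size = cfg.get("image_size", 224) if cfg is not None else 224
        return cls(image_size=image_size)

    def __call__(self, item):
        # A real processor would resize/normalize the image here; omitted in this sketch.
        return item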