Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 02:20:47 +00:00)

Commit 2d2d781469 (parent d1527dd924): modify dataset & dataloader; modify the evaluation part

demo.py: 45 changed lines
@@ -28,8 +28,8 @@ def parse_args():
         "--options",
         nargs="+",
         help="override some settings in the used config, the key-value pair "
              "in xxx=yyy format will be merged into config file (deprecate), "
              "change to --cfg-options instead.",
     )
     args = parser.parse_args()
     return args
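For context, the --options flag accepts space-separated key=value pairs that are merged into the loaded config at startup. A hypothetical invocation (the config path here is an assumption; this diff does not name one): python demo.py --cfg-path eval_configs/bindgpt4_eval.yaml --gpu-id 0 --options model.low_resource=True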
@@ -64,6 +64,7 @@ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config
 chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
 print('Initialization Finished')
+

 # ========================================
 # Gradio Setting
 # ========================================
@@ -73,7 +74,10 @@ def gradio_reset(chat_state, img_list):
         chat_state.messages = []
     if img_list is not None:
         img_list = []
-    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
+                                                                    interactive=False), gr.update(
+        value="Upload & Start Chat", interactive=True), chat_state, img_list


 def upload_img(gr_img, text_input, chat_state):
     if gr_img is None:
@@ -81,7 +85,9 @@ def upload_img(gr_img, text_input, chat_state):
     chat_state = CONV_VISION.copy()
     img_list = []
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
-    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
+        value="Start Chatting", interactive=False), chat_state, img_list


 def gradio_ask(user_message, chatbot, chat_state):
     if len(user_message) == 0:
@@ -101,33 +107,34 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list

-title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
-description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
-article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
-"""
-
-#TODO show examples below
+title = """<h1 align="center">Demo of BindGPT-4</h1>"""
+description = """<h3>This is the demo of BindGPT-4. Upload your images and start chatting!</h3>"""
+# article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
+# """
+
+# TODO show examples below

 with gr.Blocks() as demo:
     gr.Markdown(title)
     gr.Markdown(description)
-    gr.Markdown(article)
+    # gr.Markdown(article)

     with gr.Row():
         with gr.Column(scale=0.5):
             image = gr.Image(type="pil")
             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
             clear = gr.Button("Restart")

             num_beams = gr.Slider(
                 minimum=1,
                 maximum=10,
                 value=1,
                 step=1,
                 interactive=True,
-                label="beam search numbers)",
+                label="beam search numbers",
             )

             temperature = gr.Slider(
                 minimum=0.1,
                 maximum=2.0,
@@ -140,14 +147,16 @@ with gr.Blocks() as demo:
         with gr.Column():
             chat_state = gr.State()
             img_list = gr.State()
-            chatbot = gr.Chatbot(label='MiniGPT-4')
+            chatbot = gr.Chatbot(label='BindGPT-4')
             text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

-    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
-
+    upload_button.click(upload_img, [image, text_input, chat_state],
+                        [image, text_input, upload_button, chat_state, img_list])

     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
         gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
     )
-    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
+    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+                queue=False)

 demo.launch(share=True, enable_queue=True)
@@ -1 +1,25 @@
+# TODO: Finish the eval config of ImageBindGPT4
+model:
+  arch: bind_gpt4
+  model_type: pretrain_vicuna
+  freeze_vit: True
+  freeze_qformer: False
+  max_txt_len: 160
+  end_sym: "###"
+  low_resource: False
+  prompt_path: "prompts/alignment.txt"
+  prompt_template: '###Human: {} ###Assistant: '
+  ckpt: '/path/to/pretrained/ckpt/'
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "imagebind_vision_eval"
+        image_size: 224
+    text_processor:
+      train:
+        name: "imagebind_caption"
+
+run:
+  task: image_text_pretrain
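For reference, a config of this shape can be inspected with OmegaConf, which the repo already uses (see the config.py hunk below). A minimal sketch, assuming a file path that this diff never names:

from omegaconf import OmegaConf

cfg = OmegaConf.load("eval_configs/bindgpt4_eval.yaml")    # assumed path, for illustration
print(cfg.model.arch)                                      # -> bind_gpt4
print(cfg.datasets.cc_sbu_align.vis_processor.train.name)  # -> imagebind_vision_eval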
@@ -12,6 +12,8 @@ from typing import Dict
 from omegaconf import OmegaConf
 from minigpt4.common.registry import registry

+# logging.info = print
+

 class Config:
     def __init__(self, args):
@@ -1,16 +1,12 @@
-import argparse
-import time
-from PIL import Image
-
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
-
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple, Any
+from typing import List, Any

 from minigpt4.common.registry import registry
+import torch
+from PIL import Image
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from imagebind.models.image_bind import ModalityType


 class SeparatorStyle(Enum):
@@ -107,7 +103,7 @@ class StoppingCriteriaSub(StoppingCriteria):


 CONV_VISION = Conversation(
-    system="Give the following image: <Img>ImageContent</Img>. "
+    system="Give the following image: <Vision>ImageContent</Vision>. "
            "You will be able to see the image once I provide it to you. Please answer my questions.",
     roles=("Human", "Assistant"),
     messages=[],
@@ -116,6 +112,7 @@ CONV_VISION = Conversation(
     sep="###",
 )

+# TODO: If needed and possible, rewrite this file and re-organize the definition of components.


 class Chat:
@@ -128,14 +125,18 @@ class Chat:
         self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

     def ask(self, text, conv):
+        # NOTE: the hard-coded postfix is removed.
+        # TODO: needs to be compatible with more modalities.
+        end_token = '</Vision>'
         if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
+                and conv.messages[-1][1][-len(end_token):] == end_token:  # last message is image.
             conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
         else:
             conv.append_message(conv.roles[0], text)

     def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+        # Generate an answer written by LLaMA.
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
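Taken together with the Gradio handlers earlier in this diff, the intended calling sequence around ask() and answer() looks roughly like this; a sketch, assuming the chat object and CONV_VISION template constructed in demo.py, with a placeholder image path:

# Sketch of the conversation flow wired up by the Gradio handlers in demo.py.
chat_state = CONV_VISION.copy()
img_list = []
chat.upload_img("example.jpg", chat_state, img_list)  # placeholder path; encodes the image
chat.ask("What is shown in this image?", chat_state)  # appended after the </Vision> message
llm_message, _ = chat.answer(chat_state, img_list, num_beams=1, temperature=1.0)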
@@ -160,7 +161,7 @@ class Chat:
             temperature=temperature,
         )
         output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
+        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
             output_token = output_token[1:]
         if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
             output_token = output_token[1:]
@@ -171,6 +172,7 @@ class Chat:
         return output_text, output_token.cpu().numpy()

     def upload_img(self, image, conv, img_list):
+        # Upload an image, encode it, and create a new message from the human role.
         if isinstance(image, str):  # is an image path
             raw_image = Image.open(image).convert('RGB')
             image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -182,16 +184,18 @@ class Chat:
             image = image.unsqueeze(0)
             image = image.to(self.device)

-        image_emb, _ = self.model.encode_img(image)
+        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
+        image_emb = all_embeddings[ModalityType.VISION]
         img_list.append(image_emb)
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+        conv.append_message(conv.roles[0], "<Vision><VisionHere></Vision>")
         msg = "Received."
         # self.conv.append_message(self.conv.roles[1], msg)
         return msg

     def get_context_emb(self, conv, img_list):
+        # Insert the image embeddings into the prompt; note that img_list is a List[Tensor].
         prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
+        prompt_segs = prompt.split('<VisionHere>')
         assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
         seg_tokens = [
             self.model.llama_tokenizer(
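The renaming above generalizes the placeholder from <ImageHere> to <VisionHere>. A minimal, self-contained sketch of the split-and-interleave idea behind get_context_emb (plain strings stand in for the tokenized segments and vision embedding tensors):

# Toy illustration of get_context_emb's placeholder handling.
prompt = "###Human: <Vision><VisionHere></Vision> Describe it.###Assistant:"
img_embs = ["<vision-emb-0>"]        # stand-in for one vision embedding tensor
segs = prompt.split('<VisionHere>')  # N images yield N+1 text segments
assert len(segs) == len(img_embs) + 1, "Unmatched numbers of image placeholders and images."
mixed = [x for pair in zip(segs[:-1], img_embs) for x in pair] + [segs[-1]]
print(mixed)  # text segments and vision embeddings interleaved in prompt order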
@@ -47,7 +47,8 @@ class BindGPT4(BaseModel):

         print('Loading Q-Former and Adapter/Projector')
         self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
-                                                 vision_qformer_frozen=freeze_qformer
+                                                 vision_qformer_frozen=freeze_qformer,
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
                                                  # vision_qformer_model=q_former_model,
                                                  # vision_pre_dims=(1280, 1408)
                                                  )
@@ -66,9 +67,6 @@ class BindGPT4(BaseModel):
             param.requires_grad = False
         print('Loading LLAMA Done')

-        # TODO: remove hard-coding
-        self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
-
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
@@ -87,7 +85,6 @@ class BindGPT4(BaseModel):
         imagebind_outputs = self.multimodal_encoder(inputs)
         llama_inputs = self.multimodal_joiner(imagebind_outputs)
         # NOTE: only accept image here
-        llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
         return llama_inputs

     def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
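Net effect of the three hunks above: the external llama_proj layer is deleted and the same 768 -> hidden_size projection is handed to ImageBindJoiner via vision_post_dims. A standalone sketch of that projection, assuming a LLaMA hidden size of 4096 and 32 query tokens (both illustrative, not taken from this diff):

import torch
import torch.nn as nn

vision_post_dims = [768, 4096]            # [Q-Former output dim, assumed LLaMA hidden size]
post_proj = nn.Linear(*vision_post_dims)  # what vision_post_dims configures inside the joiner

qformer_out = torch.randn(1, 32, 768)     # (batch, assumed query tokens, Q-Former dim)
llama_input = post_proj(qformer_out)      # -> (1, 32, 4096), ready to feed the LLM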
@@ -9,11 +9,11 @@ datasets:
   cc12m:
     vis_processor:
       train:
-        name: "blip2_image_train"
+        name: "imagebind_vision_train"
         image_size: 224
     text_processor:
       train:
-        name: "blip_caption"
+        name: "imagebind_caption"
     sample_ratio: 115
 # cc_sbu:
 #   vis_processor:
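These renamed processors are resolved through the registry at startup, as in the demo line quoted in an earlier hunk (registry.get_processor_class(...).from_config(...)). A hedged sketch of that lookup for the new name; vis_processor_cfg stands for the vis_processor.train node of this YAML:

from minigpt4.common.registry import registry

# Resolve the processor named in the YAML to its registered class and build it.
# The concrete class behind "imagebind_vision_train" is not shown in this diff.
vis_cls = registry.get_processor_class("imagebind_vision_train")
vis_processor = vis_cls.from_config(vis_processor_cfg)  # vis_processor_cfg: assumed config node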