Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-06 19:10:45 +00:00)
Evaluation (#3)

Evaluation (Version 1), including quantitative and qualitative evaluation.
WIP: reconstruct the eval part and the config part.

Co-authored-by: unknown <913556700@qq.com>
Co-authored-by: bingyikang <bingyikang@bytedance.com>

parent 64472dedb1
commit 3efda2ac76
@@ -87,7 +87,7 @@ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Lin
 
 ### Launching Demo Locally
 
-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running
 
 ```
 python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
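(Note: this hunk re-points the demo link at `eval_scripts/qualitative_eval.py` while the quoted command still reads `python demo.py`; presumably the qualitative evaluation is launched with the same flags, e.g. `python eval_scripts/qualitative_eval.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`, but that exact invocation is an assumption, not something this diff states.)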
@@ -1,18 +1,18 @@
 model:
   arch: bind_gpt4
   model_type: pretrain_vicuna
-  freeze_vit: True
+  freeze_imagebind: True
   freeze_qformer: False
   max_txt_len: 160
   end_sym: "###"
   low_resource: False
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/pretrained/ckpt/'
+  ckpt: 'minigpt4/output/minigpt4_stage1_pretrain/20230524192/checkpoint_0.pth'
 
 
 datasets:
-  cc_sbu_align:
+  cc12m: # Double check
     vis_processor:
       train:
         name: "imagebind_vision_eval"
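(Note: the new `ckpt` value is a run-specific stage-1 output path from the authors' own workspace; anyone reproducing the evaluation would presumably point it at their own checkpoint. The dataset rename to `cc12m` is marked provisional by the authors' own `# Double check` comment.)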
eval_scripts/conversation.py (new file, 170 lines)

import dataclasses
from copy import deepcopy
from types import SimpleNamespace
from typing import List, Union, Dict

import torch
from PIL import Image
from torch import nn, Tensor
from transformers import StoppingCriteria, StoppingCriteriaList

from eval_scripts.eval_utils import load_image
from imagebind.models.image_bind import ModalityType
from minigpt4 import BaseProcessor


Roles = SimpleNamespace(
    HUMAN="Human",
    ASSISTANT="Assistant"
)


class Message:
    def __init__(self, role: str, content: Union[str, None]):
        self.role = role
        self.content = content


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    messages: List[Message]
    sep: str = "###"

    def get_prompt(self):
        ret = self.system + self.sep
        for message in self.messages:
            if message.content:
                ret += message.role + ": " + message.content + self.sep
            else:
                ret += message.role + ":"
        return ret

    def append_message(self, role, content):
        self.messages.append(Message(role, content))

    def copy(self):
        return Conversation(
            system=self.system,
            messages=deepcopy(self.messages),
            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "messages": [(msg.role, msg.content) for msg in self.messages],
            "sep": self.sep
        }


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="Give the following image: <Vision>ImageContent</Vision>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    messages=[],
    sep="###",
)


# TODO: If needed and possible, rewrite this file and re-organize the definition of components.


class Chat:
    def __init__(self,
                 model: nn.Module,
                 processors: Dict[str, BaseProcessor],
                 device: str = 'cuda:0'
                 ):
        self.device = device
        self.model = model
        self.processors = processors
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
        self.just_uploaded = False

    def ask(self, text, conversation):
        # NOTE: the hard code for postfix is removed.
        # end_token = '</Vision>'
        # if len(conversation.messages) > 0 and conversation.messages[-1].role == Roles.HUMAN \
        #         and conversation.messages[-1].content[-len(end_token):] == end_token:
        if self.just_uploaded:
            conversation.messages[-1].content = ' '.join([conversation.messages[-1].content, text])
            self.just_uploaded = False
        else:
            conversation.append_message(Roles.HUMAN, text)

    def answer(self, conversation, emb_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        # Generate an answer written by LLaMA
        conversation.append_message(Roles.ASSISTANT, None)
        embs = self.get_context_emb(conversation, emb_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conversation.messages[-1].content = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]):
        # Upload Image, Encode Image and Create a new message from human.
        image = load_image(image, self.processors[ModalityType.VISION]).to(self.device)
        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
        image_emb = all_embeddings[ModalityType.VISION]
        emb_list.append(image_emb)
        conversation.append_message(Roles.HUMAN, "<Vision><ModalityHere></Vision>")
        self.just_uploaded = True

    def get_context_emb(self, conversation: Conversation, emb_list: List[Tensor]):
        # Insert the embeddings into the prompts and queries.
        # NOTE: Assume the placeholders have been aligned to the embeddings!
        prompt = conversation.get_prompt()
        prompt_segs = prompt.split('<ModalityHere>')
        assert len(prompt_segs) == len(emb_list) + 1, "Unmatched numbers of placeholders and embeddings."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], emb_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
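For orientation, here is a minimal usage sketch of the new `Chat` API as the eval scripts in this commit use it: an image is encoded once into `emb_list`, the question is attached to a copy of `CONV_VISION`, and `answer()` splices the embeddings into the tokenized prompt via `get_context_emb`. The `model` and `vis_processor` objects are placeholders assumed to be built through the minigpt4 registry, as in `qualitative_eval.py` and `quantitative_eval.py`; this is not code from the commit itself.

```python
# Minimal sketch (not part of the commit): `model` and `vis_processor` are assumed
# to be constructed via the minigpt4 registry, as done in the eval scripts.
from eval_scripts.conversation import Chat, CONV_VISION
from imagebind.models.image_bind import ModalityType

processors = {ModalityType.VISION: vis_processor}   # placeholder vision processor
chat = Chat(model, processors, device='cuda:0')      # placeholder BindGPT-4 model

state = CONV_VISION.copy()   # fresh conversation carrying the system prompt
emb_list = []                # one embedding per <ModalityHere> placeholder

chat.upload_img('example.jpg', state, emb_list)    # appends "<Vision><ModalityHere></Vision>"
chat.ask('Describe this image in detail.', state)  # merged into the upload message (just_uploaded)
answer, _ = chat.answer(state, emb_list, max_new_tokens=300)
print(answer)
```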
eval_scripts/eval_utils.py (new file, 15 lines)

import torch
from PIL import Image


def load_image(image, image_processor):
    if isinstance(image, str):  # is a image path
        raw_image = Image.open(image).convert('RGB')
        image = image_processor(raw_image).unsqueeze(0)
    elif isinstance(image, Image.Image):
        raw_image = image
        image = image_processor(raw_image).unsqueeze(0)
    elif isinstance(image, torch.Tensor):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
    return image
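`load_image` normalizes the three accepted input types (path string, `PIL.Image`, or tensor) into a batched tensor. A brief illustrative sketch, where `processor` stands in for the registered `imagebind_vision_eval` processor and the file name is made up:

```python
# Illustrative only; `processor` is a placeholder for the registered vision processor.
import torch
from PIL import Image

from eval_scripts.eval_utils import load_image

from_path = load_image('cat.jpg', processor)                   # str: opened, processed, batched
from_pil = load_image(Image.open('cat.jpg'), processor)        # PIL.Image: processed, batched
from_tensor = load_image(torch.rand(3, 224, 224), processor)   # 3-D tensor: batched, not re-processed
```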
@@ -1,5 +1,4 @@
 import argparse
-import os
 import random
 
 import numpy as np
@@ -10,18 +9,16 @@ import gradio as gr
 from minigpt4.common.config import Config
 from minigpt4.common.dist_utils import get_rank
 from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from eval_scripts.conversation import Chat, CONV_VISION
+# NOTE&TODO: put this before minigpt4 import will cause circular import
+# possibly because `imagebind` imports `minigpt4` and `minigpt4` also imports `imagebind`
+from imagebind.models.image_bind import ModalityType
 
 # imports modules for registration
-from minigpt4.datasets.builders import *
-from minigpt4.models import *
-from minigpt4.processors import *
-from minigpt4.runners import *
-from minigpt4.tasks import *
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="Demo")
+    parser = argparse.ArgumentParser(description="Qualitative")
     parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
     parser.add_argument(
@@ -59,9 +56,11 @@ model_config.device_8bit = args.gpu_id
 model_cls = registry.get_model_class(model_config.arch)
 model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
 
-vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+# TODO: Fix hard-coding `cc12m`
+vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+processors = {ModalityType.VISION: vis_processor}
+chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
 print('Initialization Finished')
 
 
@@ -69,24 +68,27 @@ print('Initialization Finished')
 # Gradio Setting
 # ========================================
 
-def gradio_reset(chat_state, img_list):
+def gradio_reset(chat_state, emb_list):
     if chat_state is not None:
         chat_state.messages = []
-    if img_list is not None:
-        img_list = []
-    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
-                                                                    interactive=False), gr.update(
-        value="Upload & Start Chat", interactive=True), chat_state, img_list
+    if emb_list is not None:
+        emb_list = []
+    return None, gr.update(value=None, interactive=True), \
+        gr.update(placeholder='Please upload your image first', interactive=False), \
+        gr.update(value="Upload & Start Chat", interactive=True), \
+        chat_state, emb_list
 
 
 def upload_img(gr_img, text_input, chat_state):
     if gr_img is None:
         return None, None, gr.update(interactive=True), chat_state, None
     chat_state = CONV_VISION.copy()
-    img_list = []
-    llm_message = chat.upload_img(gr_img, chat_state, img_list)
-    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
-        value="Start Chatting", interactive=False), chat_state, img_list
+    emb_list = []
+    chat.upload_img(gr_img, chat_state, emb_list)
+    return gr.update(interactive=False), \
+        gr.update(interactive=True, placeholder='Type and press Enter'), \
+        gr.update(value="Start Chatting", interactive=False), \
+        chat_state, emb_list
 
 
 def gradio_ask(user_message, chatbot, chat_state):
@@ -97,15 +99,15 @@ def gradio_ask(user_message, chatbot, chat_state):
     return '', chatbot, chat_state
 
 
-def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = chat.answer(conv=chat_state,
-                              img_list=img_list,
+def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
+    llm_message = chat.answer(conversation=chat_state,
+                              emb_list=emb_list,
                               num_beams=num_beams,
                              temperature=temperature,
                              max_new_tokens=300,
                              max_length=2000)[0]
     chatbot[-1][1] = llm_message
-    return chatbot, chat_state, img_list
+    return chatbot, chat_state, emb_list
 
 
 title = """<h1 align="center">Demo of BindGPT-4</h1>"""
@@ -146,17 +148,17 @@ with gr.Blocks() as demo:
 
     with gr.Column():
         chat_state = gr.State()
-        img_list = gr.State()
+        emb_list = gr.State()
         chatbot = gr.Chatbot(label='BindGPT-4')
         text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
 
     upload_button.click(upload_img, [image, text_input, chat_state],
-                        [image, text_input, upload_button, chat_state, img_list])
+                        [image, text_input, upload_button, chat_state, emb_list])
 
     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+        gradio_answer, [chatbot, chat_state, emb_list, num_beams, temperature], [chatbot, chat_state, emb_list]
     )
-    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+    clear.click(gradio_reset, [chat_state, emb_list], [chatbot, image, text_input, upload_button, chat_state, emb_list],
                 queue=False)
 
 demo.launch(share=True, enable_queue=True)
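(Taken together, these hunks appear to turn the Gradio demo into the qualitative evaluation script: the argparse description changes from "Demo" to "Qualitative" and the README now links to `eval_scripts/qualitative_eval.py`. The state formerly named `img_list` now holds modality embeddings produced by `Chat.upload_img`, hence the rename to `emb_list`; `chat.answer(conv=..., img_list=...)` becomes `chat.answer(conversation=..., emb_list=...)` to match the new `Chat` signature; and `upload_img` no longer captures a return value because the new `Chat.upload_img` returns nothing. The multi-line return statements are only reflowed onto backslash continuations, not changed in content.)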
eval_scripts/quantitative_eval.py (new file, 82 lines)

import argparse
import json
import os

import shortuuid
from tqdm import tqdm

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
# TODO: check the circular import problem in `eval_scripts.conversation`
from eval_scripts.conversation import Chat, CONV_VISION
from imagebind.models.image_bind import ModalityType


def parse_args():
    parser = argparse.ArgumentParser(description="Quantitative")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    parser.add_argument("--question-path", required=True, help="path to question file.")
    parser.add_argument("--answer-path", required=True, help="path to answer result file.")
    parser.add_argument("--image-folder", required=True, help="path to the image queries (COCO 2014 val).")
    args = parser.parse_args()
    return args


# ========================================
# Model Initialization
# ========================================
print('Initializing Chat')
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

# TODO: fix hard-coding `cc12m`
vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
processors = {ModalityType.VISION: vis_processor}
chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

# ========================================
# Prompt Setting
# ========================================
prompt = "Give the following image: <Vision>ImageContent</Vision>. " \
         "You will be able to see the image once I provide it to you. Please answer my question."

# ========================================
# Question Loading
# ========================================
import pdb; pdb.set_trace()
questions = [json.loads(q) for q in open(args.question_path, "r")]
answer_file = open(args.answer_path, "w")
for i, line in enumerate(tqdm(questions)):
    idx = line["question_id"]
    image_file = os.path.join(args.image_folder, "COCO_val2014_" + line["image"])
    question = line["text"]
    state = CONV_VISION.copy()
    emb_list = []
    chat.upload_img(image_file, state, emb_list)
    chat.ask(question, state)
    answer, _ = chat.answer(state, emb_list)
    ans_id = shortuuid.uuid()
    answer_file.write(json.dumps({"question_id": idx,
                                  "prompt": question,
                                  "text": answer,
                                  "answer_id": ans_id,
                                  "model_id": model_config.arch,
                                  "metadata": {}}) + "\n")
    answer_file.flush()
answer_file.close()
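The script expects `--question-path` to be a JSONL file whose lines carry `question_id`, `image`, and `text` fields, and it appends one JSON answer object per question to `--answer-path` (the field names come straight from the loop above; the concrete values below are made up for illustration). Note also the leftover `import pdb; pdb.set_trace()` before question loading, which will pause every run and reads like WIP debugging.

```python
# Illustrative sketch of the expected I/O format; file names and values are invented.
import json

# --question-path: one JSON object per line
with open('questions.jsonl', 'w') as f:
    f.write(json.dumps({"question_id": 0,
                        "image": "000000391895.jpg",   # prefixed with "COCO_val2014_" by the script
                        "text": "What is the man doing?"}) + "\n")

# --answer-path: the script writes lines shaped like
# {"question_id": 0, "prompt": "...", "text": "...", "answer_id": "<shortuuid>",
#  "model_id": "bind_gpt4", "metadata": {}}
```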
Deleted file, 211 lines (the old `Chat`/`Conversation` implementation that the demo previously imported from `minigpt4.conversation.conversation`)

import dataclasses
from enum import auto, Enum
from typing import List, Any

import torch
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList

from imagebind.models.image_bind import ModalityType


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="Give the following image: <Vision>ImageContent</Vision>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


# TODO: If needed and possible, rewrite this file and re-organize the definition of components.


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        # NOTE: the hard code for postfix is removed.
        # TODO: Need to be compatible with more modalities.
        end_token = '</Vision>'
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-len(end_token):] == end_token:  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        # Generate an answer written by LLaMA
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        # Upload Image, Encode Image and Create a new message from human.
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
        image_emb = all_embeddings[ModalityType.VISION]
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Vision><VisionHere></Vision>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def get_context_emb(self, conv, img_list):
        # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor]
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<VisionHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
@@ -45,15 +45,6 @@ class BindGPT4(BaseModel):
         self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind)
         print('Loading ImageBind Done')
 
-        print('Loading Q-Former and Adapter/Projector')
-        self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
-                                                 vision_qformer_frozen=freeze_qformer,
-                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
-                                                 # vision_qformer_model=q_former_model,
-                                                 # vision_pre_dims=(1280, 1408)
-                                                 )
-        print('Loading Q-Former and Adapter/Projector Done')
-
         print('Loading LLAMA')
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
@@ -67,6 +58,15 @@ class BindGPT4(BaseModel):
             param.requires_grad = False
         print('Loading LLAMA Done')
 
+        print('Loading Q-Former and Adapter/Projector')
+        self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
+                                                 vision_qformer_frozen=freeze_qformer,
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
+                                                 # vision_qformer_model=q_former_model,
+                                                 # vision_pre_dims=(1280, 1408)
+                                                 )
+        print('Loading Q-Former and Adapter/Projector Done')
+
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
@@ -93,7 +93,7 @@ class BindGPT4(BaseModel):
         attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
         if prompt:
             batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
+            p_before, p_after = prompt.split('<ModalityHere>')
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
             p_after_tokens = self.llama_tokenizer(
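Two things happen in this file: the Q-Former/projector construction moves to after LLAMA loading, presumably because `vision_post_dims` references `self.llama_model.config.hidden_size`, which only exists once the LLaMA model has been built; and the per-modality placeholder (`<VisionHere>`, built via `'<{}Here>'.format(modality_name.title())`) is replaced by a single uniform `<ModalityHere>` token. The plain-Python sketch below (not from the commit) shows what that split does to a prompt built from the config's `prompt_template` and one of the alignment prompts:

```python
# Plain-Python sketch of the placeholder convention after this commit (illustration only).
prompt_template = '###Human: {} ###Assistant: '
prompt = prompt_template.format("<Vision><ModalityHere></Vision> Describe this image in detail.")

p_before, p_after = prompt.split('<ModalityHere>')
print(repr(p_before))  # '###Human: <Vision>'
print(repr(p_after))   # '</Vision> Describe this image in detail. ###Assistant: '
# BindGPT4 and Chat.get_context_emb insert the modality embedding between the
# token embeddings of p_before and p_after.
```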
@ -1,4 +1,4 @@
|
|||||||
<Vision><VisionHere></Vision> Describe this image in detail.
|
<Vision><ModalityHere></Vision> Describe this image in detail.
|
||||||
<Vision><VisionHere></Vision> Take a look at this image and describe what you notice.
|
<Vision><ModalityHere></Vision> Take a look at this image and describe what you notice.
|
||||||
<Vision><VisionHere></Vision> Please provide a detailed description of the picture.
|
<Vision><ModalityHere></Vision> Please provide a detailed description of the picture.
|
||||||
<Vision><VisionHere></Vision> Could you describe the contents of this image for me?
|
<Vision><ModalityHere></Vision> Could you describe the contents of this image for me?
|
@@ -1,7 +1,7 @@
 model:
   arch: bind_gpt4
   model_type: pretrain_vicuna
-  freeze_vit: True
+  freeze_imagebind: True
   freeze_qformer: False
 
 