Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
Evaluation (#3)

Evaluation (Version 1), including quantitative and qualitative evaluation.
WIP: reconstruct the eval and config parts.

Co-authored-by: unknown <913556700@qq.com>
Co-authored-by: bingyikang <bingyikang@bytedance.com>
parent 64472dedb1 · commit 3efda2ac76
@@ -87,7 +87,7 @@ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Lin

### Launching Demo Locally

-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running

```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
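Note that the hunk above still shows the old `python demo.py ...` command while the link now points at the relocated script; assuming the moved script keeps the same CLI, the equivalent invocation would presumably be `python eval_scripts/qualitative_eval.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`, run from the repository root so that the `eval_scripts` package is importable.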
@@ -1,18 +1,18 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
-  freeze_vit: True
+  freeze_imagebind: True
  freeze_qformer: False
  max_txt_len: 160
  end_sym: "###"
  low_resource: False
  prompt_path: "prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/pretrained/ckpt/'
+  ckpt: 'minigpt4/output/minigpt4_stage1_pretrain/20230524192/checkpoint_0.pth'


datasets:
-  cc_sbu_align:
+  cc12m: # Double check
    vis_processor:
      train:
        name: "imagebind_vision_eval"
eval_scripts/conversation.py (new file, 170 lines)
@@ -0,0 +1,170 @@
```python
import dataclasses
from copy import deepcopy
from types import SimpleNamespace
from typing import List, Union, Dict

import torch
from PIL import Image
from torch import nn, Tensor
from transformers import StoppingCriteria, StoppingCriteriaList

from eval_scripts.eval_utils import load_image
from imagebind.models.image_bind import ModalityType
from minigpt4 import BaseProcessor

Roles = SimpleNamespace(
    HUMAN="Human",
    ASSISTANT="Assistant"
)


class Message:
    def __init__(self, role: str, content: Union[str, None]):
        self.role = role
        self.content = content


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    messages: List[Message]
    sep: str = "###"

    def get_prompt(self):
        ret = self.system + self.sep
        for message in self.messages:
            if message.content:
                ret += message.role + ": " + message.content + self.sep
            else:
                ret += message.role + ":"
        return ret

    def append_message(self, role, content):
        self.messages.append(Message(role, content))

    def copy(self):
        return Conversation(
            system=self.system,
            messages=deepcopy(self.messages),
            sep=self.sep)

    def dict(self):
        return {
            "system": self.system,
            "messages": [(msg.role, msg.content) for msg in self.messages],
            "sep": self.sep
        }


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="Give the following image: <Vision>ImageContent</Vision>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    messages=[],
    sep="###",
)


# TODO: If needed and possible, rewrite this file and re-organize the definition of components.


class Chat:
    def __init__(self,
                 model: nn.Module,
                 processors: Dict[str, BaseProcessor],
                 device: str = 'cuda:0'
                 ):
        self.device = device
        self.model = model
        self.processors = processors
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
        self.just_uploaded = False

    def ask(self, text, conversation):
        # NOTE: the hard code for postfix is removed.
        # end_token = '</Vision>'
        # if len(conversation.messages) > 0 and conversation.messages[-1].role == Roles.HUMAN \
        #         and conversation.messages[-1].content[-len(end_token):] == end_token:
        if self.just_uploaded:
            conversation.messages[-1].content = ' '.join([conversation.messages[-1].content, text])
            self.just_uploaded = False
        else:
            conversation.append_message(Roles.HUMAN, text)

    def answer(self, conversation, emb_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        # Generate an answer written by LLaMA
        conversation.append_message(Roles.ASSISTANT, None)
        embs = self.get_context_emb(conversation, emb_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conversation.messages[-1].content = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]):
        # Upload Image, Encode Image and Create a new message from human.
        image = load_image(image, self.processors[ModalityType.VISION]).to(self.device)
        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
        image_emb = all_embeddings[ModalityType.VISION]
        emb_list.append(image_emb)
        conversation.append_message(Roles.HUMAN, "<Vision><ModalityHere></Vision>")
        self.just_uploaded = True

    def get_context_emb(self, conversation: Conversation, emb_list: List[Tensor]):
        # Insert the embeddings into the prompts and queries.
        # NOTE: Assume the placeholders have been aligned to the embeddings!
        prompt = conversation.get_prompt()
        prompt_segs = prompt.split('<ModalityHere>')
        assert len(prompt_segs) == len(emb_list) + 1, "Unmatched numbers of placeholders and embeddings."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], emb_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
```
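For orientation, here is a minimal sketch of how this `Chat` class is driven end to end; it mirrors the flow used by `eval_scripts/quantitative_eval.py` later in this commit, and assumes `model` and `processors` have already been built from the config as in that script.

```python
# Hypothetical driver for eval_scripts/conversation.Chat, following the
# upload -> ask -> answer flow used by the eval scripts in this commit.
# `model` and `processors` are assumed to be constructed elsewhere.
state = CONV_VISION.copy()      # fresh Conversation carrying the system prompt
emb_list = []                   # one embedding per <ModalityHere> placeholder
chat = Chat(model, processors, device='cuda:0')

chat.upload_img('path/to/image.jpg', state, emb_list)  # encodes the image, appends the placeholder message
chat.ask('What is shown in this image?', state)        # merged into the upload message (just_uploaded is True)
answer_text, answer_tokens = chat.answer(state, emb_list, max_new_tokens=300)
print(answer_text)
```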
eval_scripts/eval_utils.py (new file, 15 lines)
@@ -0,0 +1,15 @@
```python
import torch
from PIL import Image


def load_image(image, image_processor):
    if isinstance(image, str):  # is a image path
        raw_image = Image.open(image).convert('RGB')
        image = image_processor(raw_image).unsqueeze(0)
    elif isinstance(image, Image.Image):
        raw_image = image
        image = image_processor(raw_image).unsqueeze(0)
    elif isinstance(image, torch.Tensor):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
    return image
```
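A quick sketch of the three input forms this helper accepts; `vis_processor` stands in for whichever ImageBind vision processor the config registers, so treat the names and file paths as illustrative.

```python
from PIL import Image
import torch

# str paths and PIL images are run through the processor and batched;
# a raw 3-D tensor only gets a batch dimension added, with no processing.
batch_from_path = load_image('samples/example.jpg', vis_processor)
batch_from_pil = load_image(Image.open('samples/example.jpg'), vis_processor)
batch_from_tensor = load_image(torch.rand(3, 224, 224), vis_processor)  # -> shape (1, 3, 224, 224)
```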
@@ -1,5 +1,4 @@
import argparse
import os
import random

import numpy as np
@@ -10,18 +9,16 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from eval_scripts.conversation import Chat, CONV_VISION
# NOTE&TODO: put this before minigpt4 import will cause circular import
# possibly because `imagebind` imports `minigpt4` and `minigpt4` also imports `imagebind`
from imagebind.models.image_bind import ModalityType

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def parse_args():
-    parser = argparse.ArgumentParser(description="Demo")
+    parser = argparse.ArgumentParser(description="Qualitative")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
@@ -59,9 +56,11 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

-vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+# TODO: Fix hard-coding `cc12m`
+vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+processors = {ModalityType.VISION: vis_processor}
+chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')


@@ -69,24 +68,27 @@ print('Initialization Finished')
# Gradio Setting
# ========================================

-def gradio_reset(chat_state, img_list):
+def gradio_reset(chat_state, emb_list):
    if chat_state is not None:
        chat_state.messages = []
-    if img_list is not None:
-        img_list = []
-    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
-                                                                    interactive=False), gr.update(
-        value="Upload & Start Chat", interactive=True), chat_state, img_list
+    if emb_list is not None:
+        emb_list = []
+    return None, gr.update(value=None, interactive=True), \
+        gr.update(placeholder='Please upload your image first', interactive=False), \
+        gr.update(value="Upload & Start Chat", interactive=True), \
+        chat_state, emb_list


def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    chat_state = CONV_VISION.copy()
-    img_list = []
-    llm_message = chat.upload_img(gr_img, chat_state, img_list)
-    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
-        value="Start Chatting", interactive=False), chat_state, img_list
+    emb_list = []
+    chat.upload_img(gr_img, chat_state, emb_list)
+    return gr.update(interactive=False), \
+        gr.update(interactive=True, placeholder='Type and press Enter'), \
+        gr.update(value="Start Chatting", interactive=False), \
+        chat_state, emb_list


def gradio_ask(user_message, chatbot, chat_state):
@@ -97,15 +99,15 @@ def gradio_ask(user_message, chatbot, chat_state):
    return '', chatbot, chat_state


-def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = chat.answer(conv=chat_state,
-                              img_list=img_list,
+def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
+    llm_message = chat.answer(conversation=chat_state,
+                              emb_list=emb_list,
                              num_beams=num_beams,
                              temperature=temperature,
                              max_new_tokens=300,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
-    return chatbot, chat_state, img_list
+    return chatbot, chat_state, emb_list


title = """<h1 align="center">Demo of BindGPT-4</h1>"""
@@ -146,17 +148,17 @@ with gr.Blocks() as demo:

    with gr.Column():
        chat_state = gr.State()
-        img_list = gr.State()
+        emb_list = gr.State()
        chatbot = gr.Chatbot(label='BindGPT-4')
        text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state],
-                        [image, text_input, upload_button, chat_state, img_list])
+                        [image, text_input, upload_button, chat_state, emb_list])

    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+        gradio_answer, [chatbot, chat_state, emb_list, num_beams, temperature], [chatbot, chat_state, emb_list]
    )
-    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+    clear.click(gradio_reset, [chat_state, emb_list], [chatbot, image, text_input, upload_button, chat_state, emb_list],
                queue=False)

demo.launch(share=True, enable_queue=True)
eval_scripts/quantitative_eval.py (new file, 82 lines)
@@ -0,0 +1,82 @@
```python
import argparse
import json
import os

import shortuuid
from tqdm import tqdm

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
# TODO: check the circular import problem in `eval_scripts.conversation`
from eval_scripts.conversation import Chat, CONV_VISION
from imagebind.models.image_bind import ModalityType


def parse_args():
    parser = argparse.ArgumentParser(description="Quantitative")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    parser.add_argument("--question-path", required=True, help="path to question file.")
    parser.add_argument("--answer-path", required=True, help="path to answer result file.")
    parser.add_argument("--image-folder", required=True, help="path to the image queries (COCO 2014 val).")
    args = parser.parse_args()
    return args


# ========================================
# Model Initialization
# ========================================
print('Initializing Chat')
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

# TODO: fix hard-coding `cc12m`
vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
processors = {ModalityType.VISION: vis_processor}
chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

# ========================================
# Prompt Setting
# ========================================
prompt = "Give the following image: <Vision>ImageContent</Vision>. " \
         "You will be able to see the image once I provide it to you. Please answer my question."

# ========================================
# Question Loading
# ========================================
import pdb; pdb.set_trace()
questions = [json.loads(q) for q in open(args.question_path, "r")]
answer_file = open(args.answer_path, "w")
for i, line in enumerate(tqdm(questions)):
    idx = line["question_id"]
    image_file = os.path.join(args.image_folder, "COCO_val2014_" + line["image"])
    question = line["text"]
    state = CONV_VISION.copy()
    emb_list = []
    chat.upload_img(image_file, state, emb_list)
    chat.ask(question, state)
    answer, _ = chat.answer(state, emb_list)
    ans_id = shortuuid.uuid()
    answer_file.write(json.dumps({"question_id": idx,
                                  "prompt": question,
                                  "text": answer,
                                  "answer_id": ans_id,
                                  "model_id": model_config.arch,
                                  "metadata": {}}) + "\n")
    answer_file.flush()
answer_file.close()
```
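The question file is read as JSON Lines, one object per line; here is a hypothetical entry and the matching output record, using only the field names the loop above actually reads and writes (the values are made up):

```python
# One line of the --question-path file: "image" is prefixed with "COCO_val2014_"
# and joined to --image-folder; "text" is the question passed to chat.ask().
example_question = {"question_id": 0, "image": "000000391895.jpg", "text": "What is the man doing?"}

# One line of the --answer-path file, mirroring the json.dumps call above.
example_answer = {"question_id": 0, "prompt": "What is the man doing?", "text": "<generated answer>",
                  "answer_id": "<shortuuid>", "model_id": "bind_gpt4", "metadata": {}}
```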
@@ -1,211 +0,0 @@ (file deleted)
```python
import dataclasses
from enum import auto, Enum
from typing import List, Any

import torch
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList

from imagebind.models.image_bind import ModalityType


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="Give the following image: <Vision>ImageContent</Vision>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# TODO: If needed and possible, rewrite this file and re-organize the definition of components.


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        # NOTE: the hard code for postfix is removed.
        # TODO: Need to be compatible with more modalities.
        end_token = '</Vision>'
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-len(end_token):] == end_token:  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        # Generate an answer written by LLaMA
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        # Upload Image, Encode Image and Create a new message from human.
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
        image_emb = all_embeddings[ModalityType.VISION]
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Vision><VisionHere></Vision>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def get_context_emb(self, conv, img_list):
        # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor]
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<VisionHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
```
@@ -45,15 +45,6 @@ class BindGPT4(BaseModel):
        self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind)
        print('Loading ImageBind Done')

-        print('Loading Q-Former and Adapter/Projector')
-        self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
-                                                 vision_qformer_frozen=freeze_qformer,
-                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
-                                                 # vision_qformer_model=q_former_model,
-                                                 # vision_pre_dims=(1280, 1408)
-                                                 )
-        print('Loading Q-Former and Adapter/Projector Done')
-
        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
@@ -67,6 +58,15 @@ class BindGPT4(BaseModel):
                param.requires_grad = False
        print('Loading LLAMA Done')

+        print('Loading Q-Former and Adapter/Projector')
+        self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
+                                                 vision_qformer_frozen=freeze_qformer,
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
+                                                 # vision_qformer_model=q_former_model,
+                                                 # vision_pre_dims=(1280, 1408)
+                                                 )
+        print('Loading Q-Former and Adapter/Projector Done')
+
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

@@ -93,7 +93,7 @@ class BindGPT4(BaseModel):
        attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
        if prompt:
            batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
+            p_before, p_after = prompt.split('<ModalityHere>')
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
            p_after_tokens = self.llama_tokenizer(
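The two relocation hunks above presumably exist because the joiner's projector width is read from `self.llama_model.config.hidden_size` (see `vision_post_dims`), so the `ImageBindJoiner` can only be constructed after the LLaMA model has been loaded.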
@@ -1,4 +1,4 @@
-<Vision><VisionHere></Vision> Describe this image in detail.
-<Vision><VisionHere></Vision> Take a look at this image and describe what you notice.
-<Vision><VisionHere></Vision> Please provide a detailed description of the picture.
-<Vision><VisionHere></Vision> Could you describe the contents of this image for me?
+<Vision><ModalityHere></Vision> Describe this image in detail.
+<Vision><ModalityHere></Vision> Take a look at this image and describe what you notice.
+<Vision><ModalityHere></Vision> Please provide a detailed description of the picture.
+<Vision><ModalityHere></Vision> Could you describe the contents of this image for me?
@@ -1,7 +1,7 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
-  freeze_vit: True
+  freeze_imagebind: True
  freeze_qformer: False