mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 18:10:47 +00:00
init version of v2
This commit is contained in:
parent
d1367e5e64
commit
7a575af639
765
demo_v2.py
Normal file
765
demo_v2.py
Normal file
@ -0,0 +1,765 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from threading import Thread
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
from termcolor import colored
|
||||
from textwrap import wrap
|
||||
from torchvision.transforms import functional as F
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import html
|
||||
import gradio as gr
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
|
||||
import minigpt4.tasks as tasks
|
||||
from minigpt4.common.config import Config
|
||||
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
|
||||
from minigpt4.common.logger import setup_logger
|
||||
from minigpt4.common.optims import (
|
||||
LinearWarmupCosineLRScheduler,
|
||||
LinearWarmupStepLRScheduler,
|
||||
)
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.common.utils import now
|
||||
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
|
||||
|
||||
# imports modules for registration
|
||||
from minigpt4.datasets.builders import *
|
||||
from minigpt4.models import *
|
||||
from minigpt4.processors import *
|
||||
from minigpt4.runners import *
|
||||
from minigpt4.tasks import *
|
||||
|
||||
parser = argparse.ArgumentParser(description="Demo")
|
||||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
||||
parser.add_argument(
|
||||
"--options",
|
||||
nargs="+",
|
||||
help="override some settings in the used config, the key-value pair "
|
||||
"in xxx=yyy format will be merged into config file (deprecate), "
|
||||
"change to --cfg-options instead.",
|
||||
)
|
||||
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
|
||||
print('Initializing Chat')
|
||||
cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
|
||||
cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
|
||||
cfg.model_cfg.lora_r = 64
|
||||
cfg.model_cfg.lora_alpha = 16
|
||||
|
||||
device = 'cuda'
|
||||
|
||||
model_config = cfg.model_cfg
|
||||
model_cls = registry.get_model_class(model_config.arch)
|
||||
model = model_cls.from_config(model_config).to(device)
|
||||
bounding_box_size = 100
|
||||
|
||||
vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
|
||||
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
CONV_VISION = Conversation(
|
||||
system="",
|
||||
roles=(r"<s>[INST] ", r" [/INST]"),
|
||||
messages=[],
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="</s>",
|
||||
)
|
||||
|
||||
|
||||
def extract_substrings(string):
|
||||
# first check if there is no-finished bracket
|
||||
index = string.rfind('}')
|
||||
if index != -1:
|
||||
string = string[:index + 1]
|
||||
|
||||
pattern = r'<p>(.*?)\}(?!<)'
|
||||
matches = re.findall(pattern, string)
|
||||
substrings = [match for match in matches]
|
||||
|
||||
return substrings
|
||||
|
||||
|
||||
def is_overlapping(rect1, rect2):
|
||||
x1, y1, x2, y2 = rect1
|
||||
x3, y3, x4, y4 = rect2
|
||||
return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
|
||||
|
||||
|
||||
def computeIoU(bbox1, bbox2):
|
||||
x1, y1, x2, y2 = bbox1
|
||||
x3, y3, x4, y4 = bbox2
|
||||
intersection_x1 = max(x1, x3)
|
||||
intersection_y1 = max(y1, y3)
|
||||
intersection_x2 = min(x2, x4)
|
||||
intersection_y2 = min(y2, y4)
|
||||
intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
|
||||
bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
|
||||
union_area = bbox1_area + bbox2_area - intersection_area
|
||||
iou = intersection_area / union_area
|
||||
return iou
|
||||
|
||||
|
||||
def save_tmp_img(visual_img):
|
||||
file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
|
||||
file_path = "/tmp/" + file_name
|
||||
visual_img.save(file_path)
|
||||
return file_path
|
||||
|
||||
|
||||
def mask2bbox(mask):
|
||||
if mask is None:
|
||||
return ''
|
||||
mask = mask.resize([100, 100], resample=Image.NEAREST)
|
||||
mask = np.array(mask)[:, :, 0]
|
||||
|
||||
rows = np.any(mask, axis=1)
|
||||
cols = np.any(mask, axis=0)
|
||||
|
||||
if rows.sum():
|
||||
# Get the top, bottom, left, and right boundaries
|
||||
rmin, rmax = np.where(rows)[0][[0, -1]]
|
||||
cmin, cmax = np.where(cols)[0][[0, -1]]
|
||||
bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
|
||||
else:
|
||||
bbox = ''
|
||||
|
||||
return bbox
|
||||
|
||||
|
||||
def escape_markdown(text):
|
||||
# List of Markdown special characters that need to be escaped
|
||||
md_chars = ['<', '>']
|
||||
|
||||
# Escape each special character
|
||||
for char in md_chars:
|
||||
text = text.replace(char, '\\' + char)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def reverse_escape(text):
|
||||
md_chars = ['\\<', '\\>']
|
||||
|
||||
for char in md_chars:
|
||||
text = text.replace(char, char[1:])
|
||||
|
||||
return text
|
||||
|
||||
|
||||
colors = [
|
||||
(255, 0, 0),
|
||||
(0, 255, 0),
|
||||
(0, 0, 255),
|
||||
(210, 210, 0),
|
||||
(255, 0, 255),
|
||||
(0, 255, 255),
|
||||
(114, 128, 250),
|
||||
(0, 165, 255),
|
||||
(0, 128, 0),
|
||||
(144, 238, 144),
|
||||
(238, 238, 175),
|
||||
(255, 191, 0),
|
||||
(0, 128, 0),
|
||||
(226, 43, 138),
|
||||
(255, 0, 255),
|
||||
(0, 215, 255),
|
||||
]
|
||||
|
||||
color_map = {
|
||||
f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
|
||||
color_id, color in enumerate(colors)
|
||||
}
|
||||
|
||||
used_colors = colors
|
||||
|
||||
|
||||
def visualize_all_bbox_together(image, generation):
|
||||
if image is None:
|
||||
return None, ''
|
||||
|
||||
generation = html.unescape(generation)
|
||||
print('gen begin', generation)
|
||||
|
||||
image_width, image_height = image.size
|
||||
image = image.resize([500, int(500 / image_width * image_height)])
|
||||
image_width, image_height = image.size
|
||||
|
||||
string_list = extract_substrings(generation)
|
||||
if string_list: # it is grounding or detection
|
||||
mode = 'all'
|
||||
entities = defaultdict(list)
|
||||
i = 0
|
||||
j = 0
|
||||
for string in string_list:
|
||||
try:
|
||||
obj, string = string.split('</p>')
|
||||
except ValueError:
|
||||
print('wrong string: ', string)
|
||||
continue
|
||||
bbox_list = string.split('<delim>')
|
||||
flag = False
|
||||
for bbox_string in bbox_list:
|
||||
integers = re.findall(r'-?\d+', bbox_string)
|
||||
if len(integers) == 4:
|
||||
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
|
||||
left = x0 / bounding_box_size * image_width
|
||||
bottom = y0 / bounding_box_size * image_height
|
||||
right = x1 / bounding_box_size * image_width
|
||||
top = y1 / bounding_box_size * image_height
|
||||
|
||||
entities[obj].append([left, bottom, right, top])
|
||||
|
||||
j += 1
|
||||
flag = True
|
||||
if flag:
|
||||
i += 1
|
||||
else:
|
||||
integers = re.findall(r'-?\d+', generation)
|
||||
|
||||
if len(integers) == 4: # it is refer
|
||||
mode = 'single'
|
||||
|
||||
entities = list()
|
||||
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
|
||||
left = x0 / bounding_box_size * image_width
|
||||
bottom = y0 / bounding_box_size * image_height
|
||||
right = x1 / bounding_box_size * image_width
|
||||
top = y1 / bounding_box_size * image_height
|
||||
entities.append([left, bottom, right, top])
|
||||
else:
|
||||
# don't detect any valid bbox to visualize
|
||||
return None, ''
|
||||
|
||||
if len(entities) == 0:
|
||||
return None, ''
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
image_h = image.height
|
||||
image_w = image.width
|
||||
image = np.array(image)
|
||||
|
||||
elif isinstance(image, str):
|
||||
if os.path.exists(image):
|
||||
pil_img = Image.open(image).convert("RGB")
|
||||
image = np.array(pil_img)[:, :, [2, 1, 0]]
|
||||
image_h = pil_img.height
|
||||
image_w = pil_img.width
|
||||
else:
|
||||
raise ValueError(f"invaild image path, {image}")
|
||||
elif isinstance(image, torch.Tensor):
|
||||
|
||||
image_tensor = image.cpu()
|
||||
reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
|
||||
reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
|
||||
image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
|
||||
pil_img = T.ToPILImage()(image_tensor)
|
||||
image_h = pil_img.height
|
||||
image_w = pil_img.width
|
||||
image = np.array(pil_img)[:, :, [2, 1, 0]]
|
||||
else:
|
||||
raise ValueError(f"invaild image format, {type(image)} for {image}")
|
||||
|
||||
indices = list(range(len(entities)))
|
||||
|
||||
new_image = image.copy()
|
||||
|
||||
previous_bboxes = []
|
||||
# size of text
|
||||
text_size = 0.5
|
||||
# thickness of text
|
||||
text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
|
||||
box_line = 2
|
||||
(c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
|
||||
base_height = int(text_height * 0.675)
|
||||
text_offset_original = text_height - base_height
|
||||
text_spaces = 2
|
||||
|
||||
# num_bboxes = sum(len(x[-1]) for x in entities)
|
||||
used_colors = colors # random.sample(colors, k=num_bboxes)
|
||||
|
||||
color_id = -1
|
||||
for entity_idx, entity_name in enumerate(entities):
|
||||
if mode == 'single' or mode == 'identify':
|
||||
bboxes = entity_name
|
||||
bboxes = [bboxes]
|
||||
else:
|
||||
bboxes = entities[entity_name]
|
||||
color_id += 1
|
||||
for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
|
||||
skip_flag = False
|
||||
orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
|
||||
|
||||
color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
|
||||
new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
|
||||
|
||||
if mode == 'all':
|
||||
l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
|
||||
|
||||
x1 = orig_x1 - l_o
|
||||
y1 = orig_y1 - l_o
|
||||
|
||||
if y1 < text_height + text_offset_original + 2 * text_spaces:
|
||||
y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
|
||||
x1 = orig_x1 + r_o
|
||||
|
||||
# add text background
|
||||
(text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
|
||||
text_line)
|
||||
text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
|
||||
text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
|
||||
|
||||
for prev_bbox in previous_bboxes:
|
||||
if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
|
||||
prev_bbox['phrase'] == entity_name:
|
||||
skip_flag = True
|
||||
break
|
||||
while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
|
||||
text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
|
||||
text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
|
||||
y1 += (text_height + text_offset_original + 2 * text_spaces)
|
||||
|
||||
if text_bg_y2 >= image_h:
|
||||
text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
|
||||
text_bg_y2 = image_h
|
||||
y1 = image_h
|
||||
break
|
||||
if not skip_flag:
|
||||
alpha = 0.5
|
||||
for i in range(text_bg_y1, text_bg_y2):
|
||||
for j in range(text_bg_x1, text_bg_x2):
|
||||
if i < image_h and j < image_w:
|
||||
if j < text_bg_x1 + 1.35 * c_width:
|
||||
# original color
|
||||
bg_color = color
|
||||
else:
|
||||
# white
|
||||
bg_color = [255, 255, 255]
|
||||
new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
|
||||
np.uint8)
|
||||
|
||||
cv2.putText(
|
||||
new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
|
||||
cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
|
||||
)
|
||||
|
||||
previous_bboxes.append(
|
||||
{'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
|
||||
|
||||
if mode == 'all':
|
||||
def color_iterator(colors):
|
||||
while True:
|
||||
for color in colors:
|
||||
yield color
|
||||
|
||||
color_gen = color_iterator(colors)
|
||||
|
||||
# Add colors to phrases and remove <p></p>
|
||||
def colored_phrases(match):
|
||||
phrase = match.group(1)
|
||||
color = next(color_gen)
|
||||
return f'<span style="color:rgb{color}">{phrase}</span>'
|
||||
|
||||
print('gen before', generation)
|
||||
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
|
||||
print('gen after', generation)
|
||||
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
|
||||
else:
|
||||
generation_colored = ''
|
||||
|
||||
pil_image = Image.fromarray(new_image)
|
||||
return pil_image, generation_colored
|
||||
|
||||
|
||||
def gradio_reset(chat_state, img_list):
|
||||
if chat_state is not None:
|
||||
chat_state.messages = []
|
||||
if img_list is not None:
|
||||
img_list = []
|
||||
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
|
||||
interactive=True), chat_state, img_list
|
||||
|
||||
|
||||
def image_upload_trigger(upload_flag, replace_flag, img_list):
|
||||
# set the upload flag to true when receive a new image.
|
||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||
print('flag', upload_flag, replace_flag)
|
||||
print("SET UPLOAD FLAG!")
|
||||
upload_flag = 1
|
||||
if img_list:
|
||||
print("SET REPLACE FLAG!")
|
||||
replace_flag = 1
|
||||
print('flag', upload_flag, replace_flag)
|
||||
return upload_flag, replace_flag
|
||||
|
||||
|
||||
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
|
||||
# set the upload flag to true when receive a new image.
|
||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||
print('flag', upload_flag, replace_flag)
|
||||
print("SET UPLOAD FLAG!")
|
||||
upload_flag = 1
|
||||
if img_list or replace_flag == 1:
|
||||
print("SET REPLACE FLAG!")
|
||||
replace_flag = 1
|
||||
|
||||
print('flag', upload_flag, replace_flag)
|
||||
return upload_flag, replace_flag
|
||||
|
||||
|
||||
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
|
||||
if isinstance(gr_img, dict):
|
||||
gr_img, mask = gr_img['image'], gr_img['mask']
|
||||
else:
|
||||
mask = None
|
||||
|
||||
if '[identify]' in user_message:
|
||||
# check if user provide bbox in the text input
|
||||
integers = re.findall(r'-?\d+', user_message)
|
||||
if len(integers) != 4: # no bbox in text
|
||||
bbox = mask2bbox(mask)
|
||||
user_message = user_message + bbox
|
||||
|
||||
if len(user_message) == 0:
|
||||
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
||||
|
||||
if chat_state is None:
|
||||
chat_state = CONV_VISION.copy()
|
||||
|
||||
print('upload flag: {}'.format(upload_flag))
|
||||
if upload_flag:
|
||||
if replace_flag:
|
||||
print('RESET!!!!!!!')
|
||||
chat_state = CONV_VISION.copy() # new image, reset everything
|
||||
replace_flag = 0
|
||||
chatbot = []
|
||||
print('UPLOAD IMAGE!!')
|
||||
img_list = []
|
||||
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
||||
upload_flag = 0
|
||||
|
||||
chat.ask(user_message, chat_state)
|
||||
|
||||
chatbot = chatbot + [[user_message, None]]
|
||||
|
||||
if '[identify]' in user_message:
|
||||
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
|
||||
if visual_img is not None:
|
||||
print('Visualizing the input')
|
||||
file_path = save_tmp_img(visual_img)
|
||||
chatbot = chatbot + [[(file_path,), None]]
|
||||
|
||||
return '', chatbot, chat_state, img_list, upload_flag, replace_flag
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, temperature):
|
||||
llm_message = chat.answer(conv=chat_state,
|
||||
img_list=img_list,
|
||||
temperature=temperature,
|
||||
max_new_tokens=500,
|
||||
max_length=2000)[0]
|
||||
chatbot[-1][1] = llm_message
|
||||
return chatbot, chat_state
|
||||
|
||||
|
||||
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
||||
if not isinstance(img_list[0], torch.Tensor):
|
||||
chat.encode_img(img_list)
|
||||
streamer = chat.stream_answer(conv=chat_state,
|
||||
img_list=img_list,
|
||||
temperature=temperature,
|
||||
max_new_tokens=500,
|
||||
max_length=2000)
|
||||
output = ''
|
||||
for new_output in streamer:
|
||||
escapped = escape_markdown(new_output)
|
||||
output += escapped
|
||||
chatbot[-1][1] = output
|
||||
yield chatbot, chat_state
|
||||
# print('message: ', chat_state.messages)
|
||||
chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
|
||||
return chatbot, chat_state
|
||||
|
||||
|
||||
def gradio_visualize(chatbot, gr_img):
|
||||
if isinstance(gr_img, dict):
|
||||
gr_img, mask = gr_img['image'], gr_img['mask']
|
||||
|
||||
unescaped = reverse_escape(chatbot[-1][1])
|
||||
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
|
||||
if visual_img is not None:
|
||||
print('Visualizing the output')
|
||||
if len(generation_color):
|
||||
chatbot[-1][1] = generation_color
|
||||
file_path = save_tmp_img(visual_img)
|
||||
chatbot = chatbot + [[None, (file_path,)]]
|
||||
|
||||
return chatbot
|
||||
|
||||
|
||||
def gradio_taskselect(idx):
|
||||
prompt_list = [
|
||||
'',
|
||||
'[grounding] describe this image in detail',
|
||||
'[refer] ',
|
||||
'[detection] ',
|
||||
'[identify] what is this ',
|
||||
'[vqa] '
|
||||
]
|
||||
instruct_list = [
|
||||
'**Hint:** Type in whatever you want',
|
||||
'**Hint:** Send the command to generate a grounded image description',
|
||||
'**Hint:** Type in a phrase about an object in the image and send the command',
|
||||
'**Hint:** Type in a caption or phrase, and see object locations in the image',
|
||||
'**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw',
|
||||
'**Hint:** Send a question to get a short answer',
|
||||
]
|
||||
return prompt_list[idx], instruct_list[idx]
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(self, model, vis_processor, device='cuda:0'):
|
||||
self.device = device
|
||||
self.model = model
|
||||
self.vis_processor = vis_processor
|
||||
stop_words_ids = [torch.tensor([2]).to(self.device)]
|
||||
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
||||
|
||||
def ask(self, text, conv):
|
||||
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
|
||||
and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
|
||||
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
||||
else:
|
||||
conv.append_message(conv.roles[0], text)
|
||||
|
||||
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
||||
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
|
||||
conv.append_message(conv.roles[1], None)
|
||||
embs = self.get_context_emb(conv, img_list)
|
||||
|
||||
current_max_len = embs.shape[1] + max_new_tokens
|
||||
if current_max_len - max_length > 0:
|
||||
print('Warning: The number of tokens in current conversation exceeds the max length. '
|
||||
'The model will not see the contexts outside the range.')
|
||||
begin_idx = max(0, current_max_len - max_length)
|
||||
embs = embs[:, begin_idx:]
|
||||
|
||||
generation_kwargs = dict(
|
||||
inputs_embeds=embs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
num_beams=num_beams,
|
||||
do_sample=True,
|
||||
min_length=min_length,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
temperature=temperature,
|
||||
)
|
||||
return generation_kwargs
|
||||
|
||||
def answer(self, conv, img_list, **kargs):
|
||||
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
||||
|
||||
output_token = self.model.llama_model.generate(**generation_dict)[0]
|
||||
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||
conv.messages[-1][1] = output_text
|
||||
return output_text, output_token.cpu().numpy()
|
||||
|
||||
def stream_answer(self, conv, img_list, **kargs):
|
||||
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
||||
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
|
||||
generation_kwargs['streamer'] = streamer
|
||||
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
return streamer
|
||||
|
||||
def encode_img(self, img_list):
|
||||
image = img_list[0]
|
||||
img_list.pop(0)
|
||||
if isinstance(image, str): # is a image path
|
||||
raw_image = Image.open(image).convert('RGB')
|
||||
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
||||
elif isinstance(image, Image.Image):
|
||||
raw_image = image
|
||||
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
if len(image.shape) == 3:
|
||||
image = image.unsqueeze(0)
|
||||
image = image.to(self.device)
|
||||
|
||||
image_emb, _ = self.model.encode_img(image)
|
||||
img_list.append(image_emb)
|
||||
|
||||
def upload_img(self, image, conv, img_list):
|
||||
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
|
||||
img_list.append(image)
|
||||
msg = "Received."
|
||||
|
||||
return msg
|
||||
|
||||
def get_context_emb(self, conv, img_list):
|
||||
prompt = conv.get_prompt()
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||
seg_tokens = [
|
||||
self.model.llama_tokenizer(
|
||||
seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
||||
# only add bos to the first seg
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
return mixed_embs
|
||||
|
||||
|
||||
|
||||
chat = Chat(model, vis_processor, device=device)
|
||||
|
||||
title = '**MiniGPT-v2 Demo**'
|
||||
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
|
||||
article = 'demo'
|
||||
|
||||
introduction = '''
|
||||
For Abilities Involving Visual Grounding:
|
||||
1. Grounding: CLICK **Send** to generate a grounded image description.
|
||||
2. Refer: Input a referring object and CLICK **Send**.
|
||||
3. Detection: Write a caption or phrase, and CLICK **Send**.
|
||||
4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
|
||||
5. VQA: Input a visual question and CLICK **Send**.
|
||||
6. No Tag: Input whatever you want and CLICK **Send** without any tagging
|
||||
|
||||
You can also simply chat in free form!
|
||||
'''
|
||||
|
||||
text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
|
||||
scale=8)
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(title)
|
||||
gr.Markdown(description)
|
||||
gr.Markdown(article)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
image = gr.Image(type="pil", tool='sketch', brush_radius=20)
|
||||
|
||||
temperature = gr.Slider(
|
||||
minimum=0.1,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
interactive=True,
|
||||
label="Temperature",
|
||||
)
|
||||
|
||||
clear = gr.Button("Restart")
|
||||
|
||||
gr.Markdown(introduction)
|
||||
|
||||
with gr.Column():
|
||||
chat_state = gr.State(value=None)
|
||||
img_list = gr.State(value=[])
|
||||
chatbot = gr.Chatbot(label='MiniGPT-v2')
|
||||
|
||||
dataset = gr.Dataset(
|
||||
components=[gr.Textbox(visible=False)],
|
||||
samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
|
||||
type="index",
|
||||
label='Task Shortcuts',
|
||||
)
|
||||
task_inst = gr.Markdown('**Hint:** Upload your image and chat')
|
||||
with gr.Row():
|
||||
text_input.render()
|
||||
send = gr.Button("Send", variant='primary', size='sm', scale=1)
|
||||
|
||||
upload_flag = gr.State(value=0)
|
||||
replace_flag = gr.State(value=0)
|
||||
image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Examples(examples=[
|
||||
["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
|
||||
img_list],
|
||||
["examples_v2/sofa.jpg", "[detection] sofa", upload_flag, replace_flag, img_list],
|
||||
["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
|
||||
img_list],
|
||||
["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
|
||||
replace_flag, img_list],
|
||||
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
|
||||
outputs=[upload_flag, replace_flag])
|
||||
with gr.Column():
|
||||
gr.Examples(examples=[
|
||||
["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
|
||||
upload_flag, replace_flag, img_list],
|
||||
["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
|
||||
["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
|
||||
["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
|
||||
replace_flag, img_list],
|
||||
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
|
||||
outputs=[upload_flag, replace_flag])
|
||||
|
||||
dataset.click(
|
||||
gradio_taskselect,
|
||||
inputs=[dataset],
|
||||
outputs=[text_input, task_inst],
|
||||
show_progress="hidden",
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
text_input.submit(
|
||||
gradio_ask,
|
||||
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
|
||||
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
|
||||
).success(
|
||||
gradio_stream_answer,
|
||||
[chatbot, chat_state, img_list, temperature],
|
||||
[chatbot, chat_state]
|
||||
).success(
|
||||
gradio_visualize,
|
||||
[chatbot, image],
|
||||
[chatbot],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
send.click(
|
||||
gradio_ask,
|
||||
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
|
||||
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
|
||||
).success(
|
||||
gradio_stream_answer,
|
||||
[chatbot, chat_state, img_list, temperature],
|
||||
[chatbot, chat_state]
|
||||
).success(
|
||||
gradio_visualize,
|
||||
[chatbot, image],
|
||||
[chatbot],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
|
||||
|
||||
demo.launch(share=True, enable_queue=True)
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_vicuna0
|
||||
max_txt_len: 160
|
||||
end_sym: "###"
|
||||
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_llama2
|
||||
max_txt_len: 160
|
||||
end_sym: "</s>"
|
||||
|
25
eval_configs/minigptv2_eval.yaml
Normal file
25
eval_configs/minigptv2_eval.yaml
Normal file
@ -0,0 +1,25 @@
|
||||
model:
|
||||
arch: minigpt_v2
|
||||
model_type: pretrain
|
||||
max_txt_len: 160
|
||||
end_sym: "</s>"
|
||||
low_resource: True
|
||||
prompt_template: '[INST] {} [/INST]'
|
||||
ckpt: '/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth'
|
||||
llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 448
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
|
31
minigpt4/configs/models/minigpt_v2.yaml
Executable file
31
minigpt4/configs/models/minigpt_v2.yaml
Executable file
@ -0,0 +1,31 @@
|
||||
model:
|
||||
arch: minigpt_v2
|
||||
|
||||
# vit encoder
|
||||
image_size: 448
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
freeze_vit: True
|
||||
|
||||
# generation configs
|
||||
prompt: ""
|
||||
|
||||
llama_model: "/path/to/llama2/weight"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
|
||||
|
||||
preprocess:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
image_size: 448
|
||||
eval:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 448
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
eval:
|
||||
name: "blip_caption"
|
@ -11,14 +11,18 @@ from omegaconf import OmegaConf
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.base_model import BaseModel
|
||||
from minigpt4.models.minigpt_4 import MiniGPT4
|
||||
from minigpt4.models.minigpt_base import MiniGPTBase
|
||||
from minigpt4.models.minigpt4 import MiniGPT4
|
||||
from minigpt4.models.minigpt_v2 import MiniGPTv2
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_model",
|
||||
"BaseModel",
|
||||
"MiniGPTBase",
|
||||
"MiniGPT4",
|
||||
"MiniGPTv2"
|
||||
]
|
||||
|
||||
|
||||
|
@ -11,8 +11,7 @@ from minigpt4.models.minigpt_base import MiniGPTBase
|
||||
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
|
||||
|
||||
|
||||
|
||||
@registry.register_model("mini_gpt4")
|
||||
@registry.register_model("minigpt4")
|
||||
class MiniGPT4(MiniGPTBase):
|
||||
"""
|
||||
MiniGPT-4 model
|
@ -25,6 +25,8 @@ class MiniGPTBase(BaseModel):
|
||||
freeze_vit=True,
|
||||
llama_model="",
|
||||
max_txt_len=32,
|
||||
max_context_len=3800,
|
||||
prompt_template="",
|
||||
end_sym='\n',
|
||||
low_resource=False, # use 8 bit and put vit in cpu
|
||||
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
|
||||
@ -50,8 +52,10 @@ class MiniGPTBase(BaseModel):
|
||||
)
|
||||
|
||||
self.max_txt_len = max_txt_len
|
||||
self.max_context_len = max_context_len
|
||||
self.end_sym = end_sym
|
||||
|
||||
self.prompt_template = prompt_template
|
||||
self.prompt_list = []
|
||||
|
||||
def vit_to_cpu(self):
|
||||
@ -129,7 +133,6 @@ class MiniGPTBase(BaseModel):
|
||||
wrapped_atts[i, :length] = 1
|
||||
return wrapped_embs, wrapped_atts
|
||||
|
||||
|
||||
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
|
||||
"""
|
||||
Concatenate the batched input embedding and batched output embedding together.
|
||||
@ -219,7 +222,7 @@ class MiniGPTBase(BaseModel):
|
||||
conv_q = [q.split(connect_sym)for q in conv_q]
|
||||
conv_a = [a.split(connect_sym) for a in conv_a]
|
||||
|
||||
conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
|
||||
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
|
||||
|
||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
|
||||
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
|
||||
@ -233,7 +236,7 @@ class MiniGPTBase(BaseModel):
|
||||
instruction = None
|
||||
|
||||
if self.chat_template:
|
||||
instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
|
||||
instruction = [self.prompt_template.format(instruct) for instruct in instruction]
|
||||
|
||||
if 'length' in samples:
|
||||
# the input is a image train (like videos)
|
||||
|
139
minigpt4/models/minigpt_v2.py
Normal file
139
minigpt4/models/minigpt_v2.py
Normal file
@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast as autocast
|
||||
import torch.nn as nn
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.base_model import disabled_train
|
||||
from minigpt4.models.minigpt_base import MiniGPTBase
|
||||
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
|
||||
|
||||
|
||||
@registry.register_model("minigpt_v2")
|
||||
class MiniGPTv2(MiniGPTBase):
|
||||
"""
|
||||
MiniGPT-v2 model
|
||||
"""
|
||||
|
||||
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||
"pretrain": "configs/models/minigpt_v2.yaml",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model="eva_clip_g",
|
||||
img_size=448,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision="fp16",
|
||||
freeze_vit=True,
|
||||
llama_model="",
|
||||
prompt_template='[INST] {} [/INST]',
|
||||
max_txt_len=300,
|
||||
end_sym='\n',
|
||||
lora_r=64,
|
||||
lora_target_modules=["q_proj", "v_proj"],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
chat_template=False,
|
||||
use_grad_checkpoint_llm=False,
|
||||
max_context_len=3800,
|
||||
low_resource=False, # use 8 bit and put vit in cpu
|
||||
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
|
||||
):
|
||||
super().__init__(
|
||||
vit_model=vit_model,
|
||||
img_size=img_size,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_grad_checkpoint=use_grad_checkpoint,
|
||||
vit_precision=vit_precision,
|
||||
freeze_vit=freeze_vit,
|
||||
llama_model=llama_model,
|
||||
max_txt_len=max_txt_len,
|
||||
max_context_len=max_context_len,
|
||||
end_sym=end_sym,
|
||||
prompt_template=prompt_template,
|
||||
low_resource=low_resource,
|
||||
device_8bit=device_8bit,
|
||||
lora_r=lora_r,
|
||||
lora_target_modules=lora_target_modules,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
|
||||
img_f_dim = self.visual_encoder.num_features * 4
|
||||
self.llama_proj = nn.Linear(
|
||||
img_f_dim, self.llama_model.config.hidden_size
|
||||
)
|
||||
self.chat_template = chat_template
|
||||
|
||||
if use_grad_checkpoint_llm:
|
||||
self.llama_model.gradient_checkpointing_enable()
|
||||
|
||||
def encode_img(self, image):
|
||||
device = image.device
|
||||
|
||||
if len(image.shape) > 4:
|
||||
image = image.reshape(-1, *image.shape[-3:])
|
||||
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
|
||||
image_embeds = image_embeds[:, 1:, :]
|
||||
bs, pn, hs = image_embeds.shape
|
||||
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
|
||||
|
||||
inputs_llama = self.llama_proj(image_embeds)
|
||||
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||
return inputs_llama, atts_llama
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||
img_size = cfg.get("image_size")
|
||||
llama_model = cfg.get("llama_model")
|
||||
|
||||
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||
vit_precision = cfg.get("vit_precision", "fp16")
|
||||
freeze_vit = cfg.get("freeze_vit", True)
|
||||
low_resource = cfg.get("low_resource", False)
|
||||
|
||||
prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
|
||||
max_txt_len = cfg.get("max_txt_len", 300)
|
||||
end_sym = cfg.get("end_sym", '\n')
|
||||
|
||||
lora_r = cfg.get("lora_r", 64)
|
||||
lora_alpha = cfg.get("lora_alpha", 16)
|
||||
chat_template = cfg.get("chat_template", False)
|
||||
|
||||
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
|
||||
max_context_len = cfg.get("max_context_len", 3800)
|
||||
|
||||
model = cls(
|
||||
vit_model=vit_model,
|
||||
img_size=img_size,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_grad_checkpoint=use_grad_checkpoint,
|
||||
vit_precision=vit_precision,
|
||||
freeze_vit=freeze_vit,
|
||||
llama_model=llama_model,
|
||||
prompt_template=prompt_template,
|
||||
max_txt_len=max_txt_len,
|
||||
low_resource=low_resource,
|
||||
end_sym=end_sym,
|
||||
lora_r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
chat_template=chat_template,
|
||||
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
|
||||
max_context_len=max_context_len,
|
||||
)
|
||||
|
||||
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||
if ckpt_path:
|
||||
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||
|
||||
return model
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_llama2
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_llama2
|
||||
|
||||
max_txt_len: 160
|
||||
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_vicuna0
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
arch: minigpt4
|
||||
model_type: pretrain_vicuna0
|
||||
|
||||
max_txt_len: 160
|
||||
|
Loading…
Reference in New Issue
Block a user