Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05)
init version of v2

commit 7a575af639 (parent d1367e5e64)
demo_v2.py (new file, 765 lines)
@@ -0,0 +1,765 @@
import argparse
import os
import random
import requests
from io import BytesIO
from threading import Thread
from collections import defaultdict

import cv2
from termcolor import colored
from textwrap import wrap
import torchvision.transforms as T  # needed for T.ToPILImage() in visualize_all_bbox_together
from torchvision.transforms import functional as F
import re

import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import html
import gradio as gr

from transformers import TextIteratorStreamer


import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
         "in xxx=yyy format will be merged into config file (deprecated), "
         "change to --cfg-options instead.",
)

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
# NOTE: the config path, checkpoint and LoRA settings are hard-coded here, so any real
# command-line arguments are ignored by this init version of the demo.
cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
cfg.model_cfg.lora_r = 64
cfg.model_cfg.lora_alpha = 16

device = 'cuda'

model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100

vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

model = model.eval()

CONV_VISION = Conversation(
    system="",
    roles=(r"<s>[INST] ", r" [/INST]"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="</s>",
)
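
# Note: CONV_VISION mirrors the LLaMA-2 chat format used by MiniGPT-v2: user turns are wrapped in
# <s>[INST] ... [/INST] and assistant turns are terminated with </s>, which is also why the configs
# added in this commit set prompt_template: '[INST] {} [/INST]'.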


def extract_substrings(string):
    # drop anything after the last closed '}' so an unfinished trailing bracket is ignored
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings


def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    # intersection-over-union of two boxes given as (x1, y1, x2, y2)
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def save_tmp_img(visual_img):
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    file_path = "/tmp/" + file_name
    visual_img.save(file_path)
    return file_path


def mask2bbox(mask):
    # convert a sketched mask into the textual bbox format '{<x1><y1><x2><y2>}' on a 100x100 grid
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text


colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]

color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
    color_id, color in enumerate(colors)
}

used_colors = colors


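# Format note (inferred from extract_substrings() and the parsing below, with bounding_box_size = 100):
# grounding/detection outputs look like "<p>phrase</p>{<x1><y1><x2><y2>}", multiple boxes for the same
# phrase are joined with "<delim>", and the coordinates are integers on a 0-100 grid that
# visualize_all_bbox_together() rescales to the resized image size.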
def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)
    print('gen begin', generation)

    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)

        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top])
        else:
            # no valid bbox detected, nothing to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invalid image path, {image}")
    elif isinstance(image, torch.Tensor):

        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.5
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    # num_bboxes = sum(len(x[-1]) for x in entities)
    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)

            color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
            new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        print('gen before', generation)
        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        print('gen after', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored


def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    print('flag', upload_flag, replace_flag)
    print("SET UPLOAD FLAG!")
    upload_flag = 1
    if img_list:
        print("SET REPLACE FLAG!")
        replace_flag = 1
    print('flag', upload_flag, replace_flag)
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    print('flag', upload_flag, replace_flag)
    print("SET UPLOAD FLAG!")
    upload_flag = 1
    if img_list or replace_flag == 1:
        print("SET REPLACE FLAG!")
        replace_flag = 1

    print('flag', upload_flag, replace_flag)
    return upload_flag, replace_flag


def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if the user provided a bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in text
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    print('upload flag: {}'.format(upload_flag))
    if upload_flag:
        if replace_flag:
            print('RESET!!!!!!!')
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
        print('UPLOAD IMAGE!!')
        img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            print('Visualizing the input')
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return '', chatbot, chat_state, img_list, upload_flag, replace_flag


def gradio_answer(chatbot, chat_state, img_list, temperature):
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              temperature=temperature,
                              max_new_tokens=500,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    if not isinstance(img_list[0], torch.Tensor):
        chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
                                  img_list=img_list,
                                  temperature=temperature,
                                  max_new_tokens=500,
                                  max_length=2000)
    output = ''
    for new_output in streamer:
        escaped = escape_markdown(new_output)
        output += escaped
        chatbot[-1][1] = output
        yield chatbot, chat_state
    # print('message: ', chat_state.messages)
    chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
    return chatbot, chat_state


def gradio_visualize(chatbot, gr_img):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']

    unescaped = reverse_escape(chatbot[-1][1])
    visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
    if visual_img is not None:
        print('Visualizing the output')
        if len(generation_color):
            chatbot[-1][1] = generation_color
        file_path = save_tmp_img(visual_img)
        chatbot = chatbot + [[None, (file_path,)]]

    return chatbot


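# The tags below ([grounding], [refer], [detection], [identify], [vqa]) are the task identifiers
# MiniGPT-v2 prompts are prefixed with; gradio_ask() only treats "[identify]" specially (to attach
# a user-drawn box), everything else is passed to the model as-is.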
def gradio_taskselect(idx):
    prompt_list = [
        '',
        '[grounding] describe this image in detail',
        '[refer] ',
        '[detection] ',
        '[identify] what is this ',
        '[vqa] '
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Send the command to generate a grounded image description',
        '**Hint:** Type in a phrase about an object in the image and send the command',
        '**Hint:** Type in a caption or phrase, and see object locations in the image',
        '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
        '**Hint:** Send a question to get a short answer',
    ]
    return prompt_list[idx], instruct_list[idx]


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)

        output_token = self.model.llama_model.generate(**generation_kwargs)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def encode_img(self, img_list):
        image = img_list[0]
        img_list.pop(0)
        if isinstance(image, str):  # is an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."

        return msg

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # interleave the text segment embeddings with the image embeddings that replace <ImageHere>
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs


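# Minimal usage sketch of the Chat helper outside Gradio (illustrative only; `my_image` is a
# hypothetical PIL image, and the model/processor built above are assumed):
#
#   conv = CONV_VISION.copy()
#   imgs = []
#   chat.upload_img(my_image, conv, imgs)   # adds the <Img><ImageHere></Img> placeholder turn
#   chat.encode_img(imgs)                   # swaps the raw image for its projected embedding
#   chat.ask('[grounding] describe this image in detail', conv)
#   for piece in chat.stream_answer(conv=conv, img_list=imgs, temperature=0.6):
#       print(piece, end='', flush=True)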
chat = Chat(model, vis_processor, device=device)

title = '**MiniGPT-v2 Demo**'
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
article = 'demo'

introduction = '''
For Abilities Involving Visual Grounding:
1. Grounding: CLICK **Send** to generate a grounded image description.
2. Refer: Input a referring object and CLICK **Send**.
3. Detection: Write a caption or phrase, and CLICK **Send**.
4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
5. VQA: Input a visual question and CLICK **Send**.
6. No Tag: Input whatever you want and CLICK **Send** without any tagging

You can also simply chat in free form!
'''

text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=8)
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(introduction)

        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot(label='MiniGPT-v2')

            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    with gr.Row():
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/sofa.jpg", "[detection] sofa", upload_flag, replace_flag, img_list],
                ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
                 upload_flag, replace_flag, img_list],
                ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
                ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
                ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])

    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    send.click(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True)

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0
   max_txt_len: 160
   end_sym: "###"

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2
   max_txt_len: 160
   end_sym: "</s>"

eval_configs/minigptv2_eval.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST]'
  ckpt: '/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth'
  llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update"
  lora_r: 64
  lora_alpha: 16


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
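
# Note: the `ckpt` and `llama_model` values above are cluster-specific paths; replace them with a
# local demo checkpoint and LLaMA-2-7B-chat weight directory before running the demo.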

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224

minigpt4/configs/models/minigpt_v2.yaml (new executable file, 31 lines)
@@ -0,0 +1,31 @@
model:
  arch: minigpt_v2

  # vit encoder
  image_size: 448
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True

  # generation configs
  prompt: ""

  llama_model: "/path/to/llama2/weight"
  lora_r: 64
  lora_alpha: 16


preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 448
    eval:
      name: "blip2_image_eval"
      image_size: 448
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"

@@ -11,14 +11,18 @@ from omegaconf import OmegaConf

 from minigpt4.common.registry import registry
 from minigpt4.models.base_model import BaseModel
-from minigpt4.models.minigpt_4 import MiniGPT4
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.minigpt4 import MiniGPT4
+from minigpt4.models.minigpt_v2 import MiniGPTv2
 from minigpt4.processors.base_processor import BaseProcessor


 __all__ = [
     "load_model",
     "BaseModel",
+    "MiniGPTBase",
     "MiniGPT4",
+    "MiniGPTv2"
 ]


@@ -11,8 +11,7 @@ from minigpt4.models.minigpt_base import MiniGPTBase
 from minigpt4.models.Qformer import BertConfig, BertLMHeadModel


-@registry.register_model("mini_gpt4")
+@registry.register_model("minigpt4")
 class MiniGPT4(MiniGPTBase):
     """
     MiniGPT-4 model

@@ -25,6 +25,8 @@ class MiniGPTBase(BaseModel):
         freeze_vit=True,
         llama_model="",
         max_txt_len=32,
+        max_context_len=3800,
+        prompt_template="",
         end_sym='\n',
         low_resource=False,  # use 8 bit and put vit in cpu
         device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.

@@ -50,8 +52,10 @@ class MiniGPTBase(BaseModel):
         )

         self.max_txt_len = max_txt_len
+        self.max_context_len = max_context_len
         self.end_sym = end_sym

+        self.prompt_template = prompt_template
         self.prompt_list = []

     def vit_to_cpu(self):

@@ -129,7 +133,6 @@ class MiniGPTBase(BaseModel):
             wrapped_atts[i, :length] = 1
         return wrapped_embs, wrapped_atts

-
     def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
         """
         Concatenate the batched input embedding and batched output embedding together.

@@ -219,7 +222,7 @@ class MiniGPTBase(BaseModel):
         conv_q = [q.split(connect_sym)for q in conv_q]
         conv_a = [a.split(connect_sym) for a in conv_a]

-        conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
+        conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]

         cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
         regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

@@ -233,7 +236,7 @@ class MiniGPTBase(BaseModel):
             instruction = None

         if self.chat_template:
-            instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
+            instruction = [self.prompt_template.format(instruct) for instruct in instruction]

         if 'length' in samples:
             # the input is a image train (like videos)

minigpt4/models/minigpt_v2.py (new file, 139 lines)
@@ -0,0 +1,139 @@
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel


@registry.register_model("minigpt_v2")
class MiniGPTv2(MiniGPTBase):
    """
    MiniGPT-v2 model
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/minigpt_v2.yaml",
    }

    def __init__(
            self,
            vit_model="eva_clip_g",
            img_size=448,
            drop_path_rate=0,
            use_grad_checkpoint=False,
            vit_precision="fp16",
            freeze_vit=True,
            llama_model="",
            prompt_template='[INST] {} [/INST]',
            max_txt_len=300,
            end_sym='\n',
            lora_r=64,
            lora_target_modules=["q_proj", "v_proj"],
            lora_alpha=16,
            lora_dropout=0.05,
            chat_template=False,
            use_grad_checkpoint_llm=False,
            max_context_len=3800,
            low_resource=False,  # use 8 bit and put vit in cpu
            device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            max_context_len=max_context_len,
            end_sym=end_sym,
            prompt_template=prompt_template,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        img_f_dim = self.visual_encoder.num_features * 4
        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.chat_template = chat_template

        if use_grad_checkpoint_llm:
            self.llama_model.gradient_checkpointing_enable()

    def encode_img(self, image):
        device = image.device

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_embeds = image_embeds[:, 1:, :]
            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama
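
    # Shape sketch for encode_img (the concrete numbers are assumptions based on EVA-CLIP-g with a
    # 448x448 input and patch size 14, not values asserted in this file):
    #   visual_encoder -> (bs, 1 + 1024, 1408); dropping the CLS token leaves (bs, 1024, 1408)
    #   view(bs, 1024 // 4, 1408 * 4) -> (bs, 256, 5632): four neighbouring patch tokens are merged,
    #   so the LLM receives 4x fewer but 4x wider visual tokens
    #   llama_proj then maps the 5632-dim tokens to the LLaMA hidden size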

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        low_resource = cfg.get("low_resource", False)

        prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2


@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2

   max_txt_len: 160

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0


@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0

   max_txt_len: 160