diff --git a/demo_v2.py b/demo_v2.py
new file mode 100644
index 0000000..1ea87fc
--- /dev/null
+++ b/demo_v2.py
@@ -0,0 +1,765 @@
+import argparse
+import os
+import random
+import requests
+from io import BytesIO
+from threading import Thread
+from collections import defaultdict
+
+import cv2
+from termcolor import colored
+from textwrap import wrap
+from torchvision.transforms import functional as F
+import re
+
+import numpy as np
+from PIL import Image
+import torch
+import torch.backends.cudnn as cudnn
+import html
+import gradio as gr
+from transformers import TextIteratorStreamer
+
+
+import minigpt4.tasks as tasks
+from minigpt4.common.config import Config
+from minigpt4.common.dist_utils import get_rank, init_distributed_mode
+from minigpt4.common.logger import setup_logger
+from minigpt4.common.optims import (
+ LinearWarmupCosineLRScheduler,
+ LinearWarmupStepLRScheduler,
+)
+from minigpt4.common.registry import registry
+from minigpt4.common.utils import now
+from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
+
+# imports modules for registration
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.runners import *
+from minigpt4.tasks import *
+
+parser = argparse.ArgumentParser(description="Demo")
+parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+)
+
+
+random.seed(42)
+np.random.seed(42)
+torch.manual_seed(42)
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+print('Initializing Chat')
+cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
+cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
+cfg.model_cfg.lora_r = 64
+cfg.model_cfg.lora_alpha = 16
+
+device = 'cuda'
+
+model_config = cfg.model_cfg
+model_cls = registry.get_model_class(model_config.arch)
+model = model_cls.from_config(model_config).to(device)
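+# Boxes in prompts and model outputs are written as {<x0><y0><x1><y1>} on a grid
+# normalized to `bounding_box_size`; they are rescaled to pixel coordinates when drawn.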
+bounding_box_size = 100
+
+vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
+vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+model = model.eval()
+
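+# Single-turn LLaMA-2 chat format: user messages are wrapped in [INST] ... [/INST],
+# matching the prompt template used at training time.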
+CONV_VISION = Conversation(
+ system="",
+ roles=(r"[INST] ", r" [/INST]"),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="",
+)
+
+
+def extract_substrings(string):
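+    # Pull out each '<p>phrase</p> {<x0><y0><x1><y1>}' span from the generation,
+    # e.g. '<p>a cat</p> {<10><21><52><78>}' -> ['a cat</p> {<10><21><52><78>'].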
+    # drop any unfinished trailing bracket by cutting at the last closed '}'
+ index = string.rfind('}')
+ if index != -1:
+ string = string[:index + 1]
+
+    pattern = r'<p>(.*?)\}(?!<)'
+ matches = re.findall(pattern, string)
+ substrings = [match for match in matches]
+
+ return substrings
+
+
+def is_overlapping(rect1, rect2):
+ x1, y1, x2, y2 = rect1
+ x3, y3, x4, y4 = rect2
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+def computeIoU(bbox1, bbox2):
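+    # Intersection-over-union with a pixel-inclusive (+1) convention,
+    # e.g. computeIoU((0, 0, 10, 10), (5, 5, 15, 15)) ~= 0.17.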
+ x1, y1, x2, y2 = bbox1
+ x3, y3, x4, y4 = bbox2
+ intersection_x1 = max(x1, x3)
+ intersection_y1 = max(y1, y3)
+ intersection_x2 = min(x2, x4)
+ intersection_y2 = min(y2, y4)
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+ union_area = bbox1_area + bbox2_area - intersection_area
+ iou = intersection_area / union_area
+ return iou
+
+
+def save_tmp_img(visual_img):
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+ file_path = "/tmp/" + file_name
+ visual_img.save(file_path)
+ return file_path
+
+
+def mask2bbox(mask):
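+    # Convert a sketched mask into the textual box format on a 100x100 grid,
+    # e.g. a mask covering the top-left quarter gives roughly '{<0><0><49><49>}'.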
+ if mask is None:
+ return ''
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
+ mask = np.array(mask)[:, :, 0]
+
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if rows.sum():
+ # Get the top, bottom, left, and right boundaries
+ rmin, rmax = np.where(rows)[0][[0, -1]]
+ cmin, cmax = np.where(cols)[0][[0, -1]]
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+ else:
+ bbox = ''
+
+ return bbox
+
+
+def escape_markdown(text):
+ # List of Markdown special characters that need to be escaped
+ md_chars = ['<', '>']
+
+ # Escape each special character
+ for char in md_chars:
+ text = text.replace(char, '\\' + char)
+
+ return text
+
+
+def reverse_escape(text):
+ md_chars = ['\\<', '\\>']
+
+ for char in md_chars:
+ text = text.replace(char, char[1:])
+
+ return text
+
+
+colors = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (210, 210, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (114, 128, 250),
+ (0, 165, 255),
+ (0, 128, 0),
+ (144, 238, 144),
+ (238, 238, 175),
+ (255, 191, 0),
+ (0, 128, 0),
+ (226, 43, 138),
+ (255, 0, 255),
+ (0, 215, 255),
+]
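+# The tuples above are BGR for OpenCV drawing; `color_map` below stores the
+# matching RGB hex strings, e.g. (255, 0, 0) -> "#0000ff".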
+
+color_map = {
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+ color_id, color in enumerate(colors)
+}
+
+used_colors = colors
+
+
+def visualize_all_bbox_together(image, generation):
+ if image is None:
+ return None, ''
+
+ generation = html.unescape(generation)
+ print('gen begin', generation)
+
+ image_width, image_height = image.size
+ image = image.resize([500, int(500 / image_width * image_height)])
+ image_width, image_height = image.size
+
+ string_list = extract_substrings(generation)
+ if string_list: # it is grounding or detection
+ mode = 'all'
+ entities = defaultdict(list)
+ i = 0
+ j = 0
+ for string in string_list:
+ try:
+                obj, string = string.split('</p>')
+ except ValueError:
+ print('wrong string: ', string)
+ continue
+            bbox_list = string.split('<delim>')
+ flag = False
+ for bbox_string in bbox_list:
+ integers = re.findall(r'-?\d+', bbox_string)
+ if len(integers) == 4:
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+
+ entities[obj].append([left, bottom, right, top])
+
+ j += 1
+ flag = True
+ if flag:
+ i += 1
+ else:
+ integers = re.findall(r'-?\d+', generation)
+
+ if len(integers) == 4: # it is refer
+ mode = 'single'
+
+ entities = list()
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+ entities.append([left, bottom, right, top])
+ else:
+ # don't detect any valid bbox to visualize
+ return None, ''
+
+ if len(entities) == 0:
+ return None, ''
+
+ if isinstance(image, Image.Image):
+ image_h = image.height
+ image_w = image.width
+ image = np.array(image)
+
+ elif isinstance(image, str):
+ if os.path.exists(image):
+ pil_img = Image.open(image).convert("RGB")
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ image_h = pil_img.height
+ image_w = pil_img.width
+ else:
+ raise ValueError(f"invaild image path, {image}")
+ elif isinstance(image, torch.Tensor):
+
+ image_tensor = image.cpu()
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+        pil_img = F.to_pil_image(image_tensor)
+ image_h = pil_img.height
+ image_w = pil_img.width
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ else:
+ raise ValueError(f"invaild image format, {type(image)} for {image}")
+
+ indices = list(range(len(entities)))
+
+ new_image = image.copy()
+
+ previous_bboxes = []
+ # size of text
+ text_size = 0.5
+ # thickness of text
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
+ box_line = 2
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+ base_height = int(text_height * 0.675)
+ text_offset_original = text_height - base_height
+ text_spaces = 2
+
+ # num_bboxes = sum(len(x[-1]) for x in entities)
+ used_colors = colors # random.sample(colors, k=num_bboxes)
+
+ color_id = -1
+ for entity_idx, entity_name in enumerate(entities):
+ if mode == 'single' or mode == 'identify':
+ bboxes = entity_name
+ bboxes = [bboxes]
+ else:
+ bboxes = entities[entity_name]
+ color_id += 1
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
+ skip_flag = False
+ orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
+
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
+ new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+ if mode == 'all':
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+ x1 = orig_x1 - l_o
+ y1 = orig_y1 - l_o
+
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+ x1 = orig_x1 + r_o
+
+ # add text background
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+ text_line)
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+ for prev_bbox in previous_bboxes:
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+ prev_bbox['phrase'] == entity_name:
+ skip_flag = True
+ break
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+ if text_bg_y2 >= image_h:
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+ text_bg_y2 = image_h
+ y1 = image_h
+ break
+ if not skip_flag:
+ alpha = 0.5
+ for i in range(text_bg_y1, text_bg_y2):
+ for j in range(text_bg_x1, text_bg_x2):
+ if i < image_h and j < image_w:
+ if j < text_bg_x1 + 1.35 * c_width:
+ # original color
+ bg_color = color
+ else:
+ # white
+ bg_color = [255, 255, 255]
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+ np.uint8)
+
+ cv2.putText(
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+ )
+
+ previous_bboxes.append(
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+ if mode == 'all':
+ def color_iterator(colors):
+ while True:
+ for color in colors:
+ yield color
+
+ color_gen = color_iterator(colors)
+
+        # Add colors to phrases and remove <p></p>
+ def colored_phrases(match):
+ phrase = match.group(1)
+ color = next(color_gen)
+            return f'<span style="color:rgb{color}">{phrase}</span>'
+
+ print('gen before', generation)
+        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+ print('gen after', generation)
+        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
+ else:
+ generation_colored = ''
+
+ pil_image = Image.fromarray(new_image)
+ return pil_image, generation_colored
+
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
+ interactive=True), chat_state, img_list
+
+
+def image_upload_trigger(upload_flag, replace_flag, img_list):
+    # Set the upload flag when a new image is received.
+    # If there is an old image (and an old conversation), set the replace flag so the conversation is reset later.
+ print('flag', upload_flag, replace_flag)
+ print("SET UPLOAD FLAG!")
+ upload_flag = 1
+ if img_list:
+ print("SET REPLACE FLAG!")
+ replace_flag = 1
+ print('flag', upload_flag, replace_flag)
+ return upload_flag, replace_flag
+
+
+def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+    # Set the upload flag when a new image is received.
+    # If there is an old image (and an old conversation), set the replace flag so the conversation is reset later.
+ print('flag', upload_flag, replace_flag)
+ print("SET UPLOAD FLAG!")
+ upload_flag = 1
+ if img_list or replace_flag == 1:
+ print("SET REPLACE FLAG!")
+ replace_flag = 1
+
+ print('flag', upload_flag, replace_flag)
+ return upload_flag, replace_flag
+
+
+def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+ else:
+ mask = None
+
+ if '[identify]' in user_message:
+ # check if user provide bbox in the text input
+ integers = re.findall(r'-?\d+', user_message)
+ if len(integers) != 4: # no bbox in text
+ bbox = mask2bbox(mask)
+ user_message = user_message + bbox
+
+ if len(user_message) == 0:
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+
+ if chat_state is None:
+ chat_state = CONV_VISION.copy()
+
+ print('upload flag: {}'.format(upload_flag))
+ if upload_flag:
+ if replace_flag:
+ print('RESET!!!!!!!')
+ chat_state = CONV_VISION.copy() # new image, reset everything
+ replace_flag = 0
+ chatbot = []
+ print('UPLOAD IMAGE!!')
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ upload_flag = 0
+
+ chat.ask(user_message, chat_state)
+
+ chatbot = chatbot + [[user_message, None]]
+
+ if '[identify]' in user_message:
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+ if visual_img is not None:
+ print('Visualizing the input')
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[(file_path,), None]]
+
+ return '', chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
+def gradio_answer(chatbot, chat_state, img_list, temperature):
+ llm_message = chat.answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)[0]
+ chatbot[-1][1] = llm_message
+ return chatbot, chat_state
+
+
+def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+ if not isinstance(img_list[0], torch.Tensor):
+ chat.encode_img(img_list)
+ streamer = chat.stream_answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)
+ output = ''
+ for new_output in streamer:
+ escapped = escape_markdown(new_output)
+ output += escapped
+ chatbot[-1][1] = output
+ yield chatbot, chat_state
+ # print('message: ', chat_state.messages)
+    chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
+ return chatbot, chat_state
+
+
+def gradio_visualize(chatbot, gr_img):
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+
+ unescaped = reverse_escape(chatbot[-1][1])
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+ if visual_img is not None:
+ print('Visualizing the output')
+ if len(generation_color):
+ chatbot[-1][1] = generation_color
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[None, (file_path,)]]
+
+ return chatbot
+
+
+def gradio_taskselect(idx):
+ prompt_list = [
+ '',
+ '[grounding] describe this image in detail',
+ '[refer] ',
+ '[detection] ',
+ '[identify] what is this ',
+ '[vqa] '
+ ]
+ instruct_list = [
+ '**Hint:** Type in whatever you want',
+ '**Hint:** Send the command to generate a grounded image description',
+ '**Hint:** Type in a phrase about an object in the image and send the command',
+ '**Hint:** Type in a caption or phrase, and see object locations in the image',
+        '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button at the top right of the image before redrawing',
+ '**Hint:** Send a question to get a short answer',
+ ]
+ return prompt_list[idx], instruct_list[idx]
+
+
+class Chat:
+ def __init__(self, model, vis_processor, device='cuda:0'):
+ self.device = device
+ self.model = model
+ self.vis_processor = vis_processor
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+ def ask(self, text, conv):
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
+                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
+ else:
+ conv.append_message(conv.roles[0], text)
+
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
+ conv.append_message(conv.roles[1], None)
+ embs = self.get_context_emb(conv, img_list)
+
+ current_max_len = embs.shape[1] + max_new_tokens
+ if current_max_len - max_length > 0:
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
+ 'The model will not see the contexts outside the range.')
+ begin_idx = max(0, current_max_len - max_length)
+ embs = embs[:, begin_idx:]
+
+ generation_kwargs = dict(
+ inputs_embeds=embs,
+ max_new_tokens=max_new_tokens,
+ stopping_criteria=self.stopping_criteria,
+ num_beams=num_beams,
+ do_sample=True,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ temperature=temperature,
+ )
+ return generation_kwargs
+
+ def answer(self, conv, img_list, **kargs):
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+
+        output_token = self.model.llama_model.generate(**generation_kwargs)[0]
+ output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+ conv.messages[-1][1] = output_text
+ return output_text, output_token.cpu().numpy()
+
+ def stream_answer(self, conv, img_list, **kargs):
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+ streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
+ generation_kwargs['streamer'] = streamer
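+        # generate() runs in a background thread so the caller can iterate over
+        # the streamer and receive decoded text incrementally.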
+ thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
+ thread.start()
+ return streamer
+
+ def encode_img(self, img_list):
+ image = img_list[0]
+ img_list.pop(0)
+        if isinstance(image, str):  # is an image path
+ raw_image = Image.open(image).convert('RGB')
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
+ elif isinstance(image, Image.Image):
+ raw_image = image
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
+ elif isinstance(image, torch.Tensor):
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ image = image.to(self.device)
+
+ image_emb, _ = self.model.encode_img(image)
+ img_list.append(image_emb)
+
+ def upload_img(self, image, conv, img_list):
+        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+ img_list.append(image)
+ msg = "Received."
+
+ return msg
+
+ def get_context_emb(self, conv, img_list):
+ prompt = conv.get_prompt()
+        prompt_segs = prompt.split('<ImageHere>')
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
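+        # Embed each text segment separately and interleave the precomputed image
+        # embeddings from img_list between them (one per <ImageHere> placeholder).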
+ seg_tokens = [
+ self.model.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
+ # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+
+
+chat = Chat(model, vis_processor, device=device)
+
+title = '**MiniGPT-v2 Demo**'
+description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
+article = 'demo'
+
+introduction = '''
+For Abilities Involving Visual Grounding:
+1. Grounding: CLICK **Send** to generate a grounded image description.
+2. Refer: Input a referring object and CLICK **Send**.
+3. Detection: Write a caption or phrase, and CLICK **Send**.
+4. Identify: Draw a bounding box on the uploaded image and CLICK **Send** to identify the object inside it. (CLICK the "clear" button before drawing a new box.)
+5. VQA: Input a visual question and CLICK **Send**.
+6. No Tag: Input whatever you want and CLICK **Send** without any tagging.
+
+You can also simply chat in free form!
+'''
+
+text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
+ scale=8)
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ gr.Markdown(description)
+ gr.Markdown(article)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=2.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ clear = gr.Button("Restart")
+
+ gr.Markdown(introduction)
+
+ with gr.Column():
+ chat_state = gr.State(value=None)
+ img_list = gr.State(value=[])
+ chatbot = gr.Chatbot(label='MiniGPT-v2')
+
+ dataset = gr.Dataset(
+ components=[gr.Textbox(visible=False)],
+ samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
+ type="index",
+ label='Task Shortcuts',
+ )
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
+ with gr.Row():
+ text_input.render()
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+ upload_flag = gr.State(value=0)
+ replace_flag = gr.State(value=0)
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+
+ with gr.Row():
+ with gr.Column():
+ gr.Examples(examples=[
+ ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
+ img_list],
+ ["examples_v2/sofa.jpg", "[detection] sofa", upload_flag, replace_flag, img_list],
+ ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
+ img_list],
+ ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+ with gr.Column():
+ gr.Examples(examples=[
+ ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
+ upload_flag, replace_flag, img_list],
+ ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
+ ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
+ ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+
+ dataset.click(
+ gradio_taskselect,
+ inputs=[dataset],
+ outputs=[text_input, task_inst],
+ show_progress="hidden",
+ postprocess=False,
+ queue=False,
+ )
+
+ text_input.submit(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ send.click(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
+
+demo.launch(share=True, enable_queue=True)
diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index 15af837..1828cef 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml
index 62709e1..f743497 100644
--- a/eval_configs/minigpt4_llama2_eval.yaml
+++ b/eval_configs/minigpt4_llama2_eval.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_llama2
max_txt_len: 160
  end_sym: "</s>"
diff --git a/eval_configs/minigptv2_eval.yaml b/eval_configs/minigptv2_eval.yaml
new file mode 100644
index 0000000..0980178
--- /dev/null
+++ b/eval_configs/minigptv2_eval.yaml
@@ -0,0 +1,25 @@
+model:
+ arch: minigpt_v2
+ model_type: pretrain
+ max_txt_len: 160
+  end_sym: "</s>"
+ low_resource: True
+ prompt_template: '[INST] {} [/INST]'
+ ckpt: '/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth'
+ llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update"
+ lora_r: 64
+ lora_alpha: 16
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_eval"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: image_text_pretrain
diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt4_llama2.yaml
index c201bdc..af73f1d 100644
--- a/minigpt4/configs/models/minigpt4_llama2.yaml
+++ b/minigpt4/configs/models/minigpt4_llama2.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
# vit encoder
image_size: 224
diff --git a/minigpt4/configs/models/minigpt4_vicuna0.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml
index 34bd2ed..783ec4e 100644
--- a/minigpt4/configs/models/minigpt4_vicuna0.yaml
+++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
# vit encoder
image_size: 224
diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt_v2.yaml
new file mode 100755
index 0000000..018a18a
--- /dev/null
+++ b/minigpt4/configs/models/minigpt_v2.yaml
@@ -0,0 +1,31 @@
+model:
+ arch: minigpt_v2
+
+ # vit encoder
+ image_size: 448
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+
+ # generation configs
+ prompt: ""
+
+ llama_model: "/path/to/llama2/weight"
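+  # LoRA adapter rank and scaling, applied to the LLaMA attention projections
+  # (q_proj and v_proj by default in MiniGPTBase)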
+ lora_r: 64
+ lora_alpha: 16
+
+
+preprocess:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ eval:
+ name: "blip2_image_eval"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py
index 74626be..bc01b56 100644
--- a/minigpt4/models/__init__.py
+++ b/minigpt4/models/__init__.py
@@ -11,14 +11,18 @@ from omegaconf import OmegaConf
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
-from minigpt4.models.minigpt_4 import MiniGPT4
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.minigpt4 import MiniGPT4
+from minigpt4.models.minigpt_v2 import MiniGPTv2
from minigpt4.processors.base_processor import BaseProcessor
__all__ = [
"load_model",
"BaseModel",
+ "MiniGPTBase",
"MiniGPT4",
+ "MiniGPTv2"
]
diff --git a/minigpt4/models/minigpt_4.py b/minigpt4/models/minigpt4.py
similarity index 99%
rename from minigpt4/models/minigpt_4.py
rename to minigpt4/models/minigpt4.py
index f73a800..a2e4798 100644
--- a/minigpt4/models/minigpt_4.py
+++ b/minigpt4/models/minigpt4.py
@@ -11,8 +11,7 @@ from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
-
-@registry.register_model("mini_gpt4")
+@registry.register_model("minigpt4")
class MiniGPT4(MiniGPTBase):
"""
MiniGPT-4 model
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index aa91bf3..77c919a 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -25,10 +25,12 @@ class MiniGPTBase(BaseModel):
freeze_vit=True,
llama_model="",
max_txt_len=32,
+ max_context_len=3800,
+ prompt_template="",
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
- lora_r=0, # lora_r means lora is not used
+ lora_r=0, # lora_r means lora is not used
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=16,
lora_dropout=0.05,
@@ -50,8 +52,10 @@ class MiniGPTBase(BaseModel):
)
self.max_txt_len = max_txt_len
+ self.max_context_len = max_context_len
self.end_sym = end_sym
+ self.prompt_template = prompt_template
self.prompt_list = []
def vit_to_cpu(self):
@@ -129,7 +133,6 @@ class MiniGPTBase(BaseModel):
wrapped_atts[i, :length] = 1
return wrapped_embs, wrapped_atts
-
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
"""
Concatenate the batched input embedding and batched output embedding together.
@@ -219,7 +222,7 @@ class MiniGPTBase(BaseModel):
conv_q = [q.split(connect_sym)for q in conv_q]
conv_a = [a.split(connect_sym) for a in conv_a]
- conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
+ conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
@@ -233,7 +236,7 @@ class MiniGPTBase(BaseModel):
instruction = None
if self.chat_template:
- instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
+ instruction = [self.prompt_template.format(instruct) for instruct in instruction]
if 'length' in samples:
# the input is a image train (like videos)
diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py
new file mode 100644
index 0000000..a046b0b
--- /dev/null
+++ b/minigpt4/models/minigpt_v2.py
@@ -0,0 +1,139 @@
+import logging
+import random
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+
+from minigpt4.common.registry import registry
+from minigpt4.models.base_model import disabled_train
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
+
+
+@registry.register_model("minigpt_v2")
+class MiniGPTv2(MiniGPTBase):
+ """
+ MiniGPT-v2 model
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain": "configs/models/minigpt_v2.yaml",
+ }
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ img_size=448,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ llama_model="",
+ prompt_template='[INST] {} [/INST]',
+ max_txt_len=300,
+ end_sym='\n',
+ lora_r=64,
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.05,
+ chat_template=False,
+ use_grad_checkpoint_llm=False,
+ max_context_len=3800,
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ ):
+ super().__init__(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ llama_model=llama_model,
+ max_txt_len=max_txt_len,
+ max_context_len=max_context_len,
+ end_sym=end_sym,
+ prompt_template=prompt_template,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ lora_r=lora_r,
+ lora_target_modules=lora_target_modules,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ img_f_dim = self.visual_encoder.num_features * 4
+ self.llama_proj = nn.Linear(
+ img_f_dim, self.llama_model.config.hidden_size
+ )
+ self.chat_template = chat_template
+
+ if use_grad_checkpoint_llm:
+ self.llama_model.gradient_checkpointing_enable()
+
+ def encode_img(self, image):
+ device = image.device
+
+ if len(image.shape) > 4:
+ image = image.reshape(-1, *image.shape[-3:])
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+ image_embeds = image_embeds[:, 1:, :]
+ bs, pn, hs = image_embeds.shape
+ image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
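+            # With the default 448x448 input and a patch size of 14, this regroups
+            # the 1024 patch tokens (32x32 grid, CLS dropped above) into 256 tokens
+            # of 4x the feature width before projecting them into the LLaMA space.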
+
+ inputs_llama = self.llama_proj(image_embeds)
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+ return inputs_llama, atts_llama
+
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ img_size = cfg.get("image_size")
+ llama_model = cfg.get("llama_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ low_resource = cfg.get("low_resource", False)
+
+ prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
+ max_txt_len = cfg.get("max_txt_len", 300)
+ end_sym = cfg.get("end_sym", '\n')
+
+ lora_r = cfg.get("lora_r", 64)
+ lora_alpha = cfg.get("lora_alpha", 16)
+ chat_template = cfg.get("chat_template", False)
+
+ use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
+ max_context_len = cfg.get("max_context_len", 3800)
+
+ model = cls(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ llama_model=llama_model,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ low_resource=low_resource,
+ end_sym=end_sym,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ chat_template=chat_template,
+ use_grad_checkpoint_llm=use_grad_checkpoint_llm,
+ max_context_len=max_context_len,
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(ckpt['model'], strict=False)
+
+ return model
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
index 6920aab..f3981b8 100644
--- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_llama2
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
index 9a6ac2d..fa2b578 100644
--- a/train_configs/minigpt4_llama2_stage2_finetune.yaml
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_llama2
max_txt_len: 160
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index 4ec1597..be87b77 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_vicuna0
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
index 54cedb4..404dfd6 100644
--- a/train_configs/minigpt4_stage2_finetune.yaml
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -1,5 +1,5 @@
model:
- arch: mini_gpt4
+ arch: minigpt4
model_type: pretrain_vicuna0
max_txt_len: 160