diff --git a/demo_v2.py b/demo_v2.py new file mode 100644 index 0000000..1ea87fc --- /dev/null +++ b/demo_v2.py @@ -0,0 +1,765 @@ +import argparse +import os +import random +import requests +from io import BytesIO +from threading import Thread +from collections import defaultdict + +import cv2 +from termcolor import colored +from textwrap import wrap +from torchvision.transforms import functional as F +import re + +import numpy as np +from PIL import Image +import torch +import torch.backends.cudnn as cudnn +import html +import gradio as gr +from transformers import TextIteratorStreamer + + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank, init_distributed_mode +from minigpt4.common.logger import setup_logger +from minigpt4.common.optims import ( + LinearWarmupCosineLRScheduler, + LinearWarmupStepLRScheduler, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import now +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + +parser = argparse.ArgumentParser(description="Demo") +parser.add_argument("--cfg-path", required=True, help="path to configuration file.") +parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", +) + +import torch.backends.cudnn as cudnn + +random.seed(42) +np.random.seed(42) +torch.manual_seed(42) + +cudnn.benchmark = False +cudnn.deterministic = True + +print('Initializing Chat') +cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml'])) +cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth" +cfg.model_cfg.lora_r = 64 +cfg.model_cfg.lora_alpha = 16 + +device = 'cuda' + +model_config = cfg.model_cfg +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to(device) +bounding_box_size = 100 + +vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + +model = model.eval() + +CONV_VISION = Conversation( + system="", + roles=(r"[INST] ", r" [/INST]"), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + + +def extract_substrings(string): + # first check if there is no-finished bracket + index = string.rfind('}') + if index != -1: + string = string[:index + 1] + + pattern = r'
<p>
(.*?)\}(?!<)' + matches = re.findall(pattern, string) + substrings = [match for match in matches] + + return substrings + + +def is_overlapping(rect1, rect2): + x1, y1, x2, y2 = rect1 + x3, y3, x4, y4 = rect2 + return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) + + +def computeIoU(bbox1, bbox2): + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou + + +def save_tmp_img(visual_img): + file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg" + file_path = "/tmp/" + file_name + visual_img.save(file_path) + return file_path + + +def mask2bbox(mask): + if mask is None: + return '' + mask = mask.resize([100, 100], resample=Image.NEAREST) + mask = np.array(mask)[:, :, 0] + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if rows.sum(): + # Get the top, bottom, left, and right boundaries + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax) + else: + bbox = '' + + return bbox + + +def escape_markdown(text): + # List of Markdown special characters that need to be escaped + md_chars = ['<', '>'] + + # Escape each special character + for char in md_chars: + text = text.replace(char, '\\' + char) + + return text + + +def reverse_escape(text): + md_chars = ['\\<', '\\>'] + + for char in md_chars: + text = text.replace(char, char[1:]) + + return text + + +colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (210, 210, 0), + (255, 0, 255), + (0, 255, 255), + (114, 128, 250), + (0, 165, 255), + (0, 128, 0), + (144, 238, 144), + (238, 238, 175), + (255, 191, 0), + (0, 128, 0), + (226, 43, 138), + (255, 0, 255), + (0, 215, 255), +] + +color_map = { + f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for + color_id, color in enumerate(colors) +} + +used_colors = colors + + +def visualize_all_bbox_together(image, generation): + if image is None: + return None, '' + + generation = html.unescape(generation) + print('gen begin', generation) + + image_width, image_height = image.size + image = image.resize([500, int(500 / image_width * image_height)]) + image_width, image_height = image.size + + string_list = extract_substrings(generation) + if string_list: # it is grounding or detection + mode = 'all' + entities = defaultdict(list) + i = 0 + j = 0 + for string in string_list: + try: + obj, string = string.split('
</p>
') + except ValueError: + print('wrong string: ', string) + continue + bbox_list = string.split('') + flag = False + for bbox_string in bbox_list: + integers = re.findall(r'-?\d+', bbox_string) + if len(integers) == 4: + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + + entities[obj].append([left, bottom, right, top]) + + j += 1 + flag = True + if flag: + i += 1 + else: + integers = re.findall(r'-?\d+', generation) + + if len(integers) == 4: # it is refer + mode = 'single' + + entities = list() + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + entities.append([left, bottom, right, top]) + else: + # don't detect any valid bbox to visualize + return None, '' + + if len(entities) == 0: + return None, '' + + if isinstance(image, Image.Image): + image_h = image.height + image_w = image.width + image = np.array(image) + + elif isinstance(image, str): + if os.path.exists(image): + pil_img = Image.open(image).convert("RGB") + image = np.array(pil_img)[:, :, [2, 1, 0]] + image_h = pil_img.height + image_w = pil_img.width + else: + raise ValueError(f"invaild image path, {image}") + elif isinstance(image, torch.Tensor): + + image_tensor = image.cpu() + reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None] + reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None] + image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean + pil_img = T.ToPILImage()(image_tensor) + image_h = pil_img.height + image_w = pil_img.width + image = np.array(pil_img)[:, :, [2, 1, 0]] + else: + raise ValueError(f"invaild image format, {type(image)} for {image}") + + indices = list(range(len(entities))) + + new_image = image.copy() + + previous_bboxes = [] + # size of text + text_size = 0.5 + # thickness of text + text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) + box_line = 2 + (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line) + base_height = int(text_height * 0.675) + text_offset_original = text_height - base_height + text_spaces = 2 + + # num_bboxes = sum(len(x[-1]) for x in entities) + used_colors = colors # random.sample(colors, k=num_bboxes) + + color_id = -1 + for entity_idx, entity_name in enumerate(entities): + if mode == 'single' or mode == 'identify': + bboxes = entity_name + bboxes = [bboxes] + else: + bboxes = entities[entity_name] + color_id += 1 + for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes): + skip_flag = False + orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm) + + color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist()) + new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line) + + if mode == 'all': + l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1 + + x1 = orig_x1 - l_o + y1 = orig_y1 - l_o + + if y1 < text_height + text_offset_original + 2 * text_spaces: + y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces + x1 = orig_x1 + r_o + + # add 
text background + (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, + text_line) + text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - ( + text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1 + + for prev_bbox in previous_bboxes: + if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \ + prev_bbox['phrase'] == entity_name: + skip_flag = True + break + while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']): + text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces) + text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces) + y1 += (text_height + text_offset_original + 2 * text_spaces) + + if text_bg_y2 >= image_h: + text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces)) + text_bg_y2 = image_h + y1 = image_h + break + if not skip_flag: + alpha = 0.5 + for i in range(text_bg_y1, text_bg_y2): + for j in range(text_bg_x1, text_bg_x2): + if i < image_h and j < image_w: + if j < text_bg_x1 + 1.35 * c_width: + # original color + bg_color = color + else: + # white + bg_color = [255, 255, 255] + new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype( + np.uint8) + + cv2.putText( + new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), + cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA + ) + + previous_bboxes.append( + {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name}) + + if mode == 'all': + def color_iterator(colors): + while True: + for color in colors: + yield color + + color_gen = color_iterator(colors) + + # Add colors to phrases and remove
<p></p>
+ def colored_phrases(match): + phrase = match.group(1) + color = next(color_gen) + return f'{phrase}' + + print('gen before', generation) + generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|', '', generation) + print('gen after', generation) + generation_colored = re.sub(r'
<p>(.*?)</p>
', colored_phrases, generation) + else: + generation_colored = '' + + pil_image = Image.fromarray(new_image) + return pil_image, generation_colored + + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat', + interactive=True), chat_state, img_list + + +def image_upload_trigger(upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. + print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list: + print("SET REPLACE FLAG!") + replace_flag = 1 + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def example_trigger(text_input, image, upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. + print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list or replace_flag == 1: + print("SET REPLACE FLAG!") + replace_flag = 1 + + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + else: + mask = None + + if '[identify]' in user_message: + # check if user provide bbox in the text input + integers = re.findall(r'-?\d+', user_message) + if len(integers) != 4: # no bbox in text + bbox = mask2bbox(mask) + user_message = user_message + bbox + + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + + if chat_state is None: + chat_state = CONV_VISION.copy() + + print('upload flag: {}'.format(upload_flag)) + if upload_flag: + if replace_flag: + print('RESET!!!!!!!') + chat_state = CONV_VISION.copy() # new image, reset everything + replace_flag = 0 + chatbot = [] + print('UPLOAD IMAGE!!') + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + upload_flag = 0 + + chat.ask(user_message, chat_state) + + chatbot = chatbot + [[user_message, None]] + + if '[identify]' in user_message: + visual_img, _ = visualize_all_bbox_together(gr_img, user_message) + if visual_img is not None: + print('Visualizing the input') + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[(file_path,), None]] + + return '', chatbot, chat_state, img_list, upload_flag, replace_flag + + +def gradio_answer(chatbot, chat_state, img_list, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000)[0] + chatbot[-1][1] = llm_message + return chatbot, chat_state + + +def gradio_stream_answer(chatbot, chat_state, img_list, temperature): + if not isinstance(img_list[0], torch.Tensor): + chat.encode_img(img_list) + streamer = chat.stream_answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000) + output = '' + for new_output in streamer: + escapped = escape_markdown(new_output) + output += escapped + chatbot[-1][1] = output + yield chatbot, chat_state + # print('message: ', chat_state.messages) + 
chat_state.messages[-1][1] = reverse_escape(output) + '' + return chatbot, chat_state + + +def gradio_visualize(chatbot, gr_img): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + + unescaped = reverse_escape(chatbot[-1][1]) + visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped) + if visual_img is not None: + print('Visualizing the output') + if len(generation_color): + chatbot[-1][1] = generation_color + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[None, (file_path,)]] + + return chatbot + + +def gradio_taskselect(idx): + prompt_list = [ + '', + '[grounding] describe this image in detail', + '[refer] ', + '[detection] ', + '[identify] what is this ', + '[vqa] ' + ] + instruct_list = [ + '**Hint:** Type in whatever you want', + '**Hint:** Send the command to generate a grounded image description', + '**Hint:** Type in a phrase about an object in the image and send the command', + '**Hint:** Type in a caption or phrase, and see object locations in the image', + '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw', + '**Hint:** Send a question to get a short answer', + ] + return prompt_list[idx], instruct_list[idx] + + +class Chat: + def __init__(self, model, vis_processor, device='cuda:0'): + self.device = device + self.model = model + self.vis_processor = vis_processor + stop_words_ids = [torch.tensor([2]).to(self.device)] + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + def ask(self, text, conv): + if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ + and conv.messages[-1][1][-6:] == '': # last message is image. + conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb(conv, img_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. 
' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + embs = embs[:, begin_idx:] + + generation_kwargs = dict( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + return generation_kwargs + + def answer(self, conv, img_list, **kargs): + generation_kwargs = self.answer_prepare(conv, img_list, **kargs) + + output_token = self.model.llama_model.generate(**generation_dict)[0] + output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) + conv.messages[-1][1] = output_text + return output_text, output_token.cpu().numpy() + + def stream_answer(self, conv, img_list, **kargs): + generation_kwargs = self.answer_prepare(conv, img_list, **kargs) + streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) + generation_kwargs['streamer'] = streamer + thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + + def encode_img(self, img_list): + image = img_list[0] + img_list.pop(0) + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, Image.Image): + raw_image = image + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = image.to(self.device) + + image_emb, _ = self.model.encode_img(image) + img_list.append(image_emb) + + def upload_img(self, image, conv, img_list): + conv.append_message(conv.roles[0], "") + img_list.append(image) + msg = "Received." + + return msg + + def get_context_emb(self, conv, img_list): + prompt = conv.get_prompt() + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + + +chat = Chat(model, vis_processor, device=device) + +title = '**MiniGPT-v2 Demo**' +description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!' +article = 'demo' + +introduction = ''' +For Abilities Involving Visual Grounding: +1. Grounding: CLICK **Send** to generate a grounded image description. +2. Refer: Input a referring object and CLICK **Send**. +3. Detection: Write a caption or phrase, and CLICK **Send**. +4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time). +5. VQA: Input a visual question and CLICK **Send**. +6. No Tag: Input whatever you want and CLICK **Send** without any tagging + +You can also simply chat in free form! 
+''' + +text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False, + scale=8) +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown(article) + + with gr.Row(): + with gr.Column(scale=0.5): + image = gr.Image(type="pil", tool='sketch', brush_radius=20) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + clear = gr.Button("Restart") + + gr.Markdown(introduction) + + with gr.Column(): + chat_state = gr.State(value=None) + img_list = gr.State(value=[]) + chatbot = gr.Chatbot(label='MiniGPT-v2') + + dataset = gr.Dataset( + components=[gr.Textbox(visible=False)], + samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']], + type="index", + label='Task Shortcuts', + ) + task_inst = gr.Markdown('**Hint:** Upload your image and chat') + with gr.Row(): + text_input.render() + send = gr.Button("Send", variant='primary', size='sm', scale=1) + + upload_flag = gr.State(value=0) + replace_flag = gr.State(value=0) + image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag]) + + with gr.Row(): + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag, + img_list], + ["examples_v2/sofa.jpg", "[detection] sofa", upload_flag, replace_flag, img_list], + ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag, + img_list], + ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek", + upload_flag, replace_flag, img_list], + ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list], + ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list], + ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + + dataset.click( + gradio_taskselect, + inputs=[dataset], + outputs=[text_input, task_inst], + show_progress="hidden", + postprocess=False, + queue=False, + ) + + text_input.submit( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + [chatbot, chat_state] + ).success( + gradio_visualize, + [chatbot, image], + [chatbot], + queue=False, + ) + + send.click( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + [chatbot, chat_state] + ).success( + gradio_visualize, + [chatbot, image], + [chatbot], + queue=False, + ) + + clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], 
queue=False) + +demo.launch(share=True, enable_queue=True) diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml index 15af837..1828cef 100644 --- a/eval_configs/minigpt4_eval.yaml +++ b/eval_configs/minigpt4_eval.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_vicuna0 max_txt_len: 160 end_sym: "###" diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml index 62709e1..f743497 100644 --- a/eval_configs/minigpt4_llama2_eval.yaml +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_llama2 max_txt_len: 160 end_sym: "" diff --git a/eval_configs/minigptv2_eval.yaml b/eval_configs/minigptv2_eval.yaml new file mode 100644 index 0000000..0980178 --- /dev/null +++ b/eval_configs/minigptv2_eval.yaml @@ -0,0 +1,25 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST]' + ckpt: '/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth' + llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update" + lora_r: 64 + lora_alpha: 16 + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt4_llama2.yaml index c201bdc..af73f1d 100644 --- a/minigpt4/configs/models/minigpt4_llama2.yaml +++ b/minigpt4/configs/models/minigpt4_llama2.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 # vit encoder image_size: 224 diff --git a/minigpt4/configs/models/minigpt4_vicuna0.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml index 34bd2ed..783ec4e 100644 --- a/minigpt4/configs/models/minigpt4_vicuna0.yaml +++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 # vit encoder image_size: 224 diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt_v2.yaml new file mode 100755 index 0000000..018a18a --- /dev/null +++ b/minigpt4/configs/models/minigpt_v2.yaml @@ -0,0 +1,31 @@ +model: + arch: minigpt_v2 + + # vit encoder + image_size: 448 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # generation configs + prompt: "" + + llama_model: "/path/to/llama2/weight" + lora_r: 64 + lora_alpha: 16 + + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + eval: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py index 74626be..bc01b56 100644 --- a/minigpt4/models/__init__.py +++ b/minigpt4/models/__init__.py @@ -11,14 +11,18 @@ from omegaconf import OmegaConf from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel -from minigpt4.models.minigpt_4 import MiniGPT4 +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.minigpt4 import MiniGPT4 +from minigpt4.models.minigpt_v2 import MiniGPTv2 from minigpt4.processors.base_processor import BaseProcessor __all__ = [ "load_model", "BaseModel", + "MiniGPTBase", "MiniGPT4", + "MiniGPTv2" ] diff --git a/minigpt4/models/minigpt_4.py b/minigpt4/models/minigpt4.py similarity index 99% 
rename from minigpt4/models/minigpt_4.py rename to minigpt4/models/minigpt4.py index f73a800..a2e4798 100644 --- a/minigpt4/models/minigpt_4.py +++ b/minigpt4/models/minigpt4.py @@ -11,8 +11,7 @@ from minigpt4.models.minigpt_base import MiniGPTBase from minigpt4.models.Qformer import BertConfig, BertLMHeadModel - -@registry.register_model("mini_gpt4") +@registry.register_model("minigpt4") class MiniGPT4(MiniGPTBase): """ MiniGPT-4 model diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py index aa91bf3..77c919a 100644 --- a/minigpt4/models/minigpt_base.py +++ b/minigpt4/models/minigpt_base.py @@ -25,10 +25,12 @@ class MiniGPTBase(BaseModel): freeze_vit=True, llama_model="", max_txt_len=32, + max_context_len=3800, + prompt_template="", end_sym='\n', low_resource=False, # use 8 bit and put vit in cpu device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. - lora_r=0, # lora_r means lora is not used + lora_r=0, # lora_r means lora is not used lora_target_modules=["q_proj", "v_proj"], lora_alpha=16, lora_dropout=0.05, @@ -50,8 +52,10 @@ class MiniGPTBase(BaseModel): ) self.max_txt_len = max_txt_len + self.max_context_len = max_context_len self.end_sym = end_sym + self.prompt_template = prompt_template self.prompt_list = [] def vit_to_cpu(self): @@ -129,7 +133,6 @@ class MiniGPTBase(BaseModel): wrapped_atts[i, :length] = 1 return wrapped_embs, wrapped_atts - def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): """ Concatenate the batched input embedding and batched output embedding together. @@ -219,7 +222,7 @@ class MiniGPTBase(BaseModel): conv_q = [q.split(connect_sym)for q in conv_q] conv_a = [a.split(connect_sym) for a in conv_a] - conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q] + conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q] cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) @@ -233,7 +236,7 @@ class MiniGPTBase(BaseModel): instruction = None if self.chat_template: - instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction] + instruction = [self.prompt_template.format(instruct) for instruct in instruction] if 'length' in samples: # the input is a image train (like videos) diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py new file mode 100644 index 0000000..a046b0b --- /dev/null +++ b/minigpt4/models/minigpt_v2.py @@ -0,0 +1,139 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +@registry.register_model("minigpt_v2") +class MiniGPTv2(MiniGPTBase): + """ + MiniGPT-v2 model + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain": "configs/models/minigpt_v2.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=448, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_template='[INST] {} [/INST]', + max_txt_len=300, + end_sym='\n', + lora_r=64, + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + chat_template=False, + 
use_grad_checkpoint_llm=False, + max_context_len=3800, + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + max_context_len=max_context_len, + end_sym=end_sym, + prompt_template=prompt_template, + low_resource=low_resource, + device_8bit=device_8bit, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + img_f_dim = self.visual_encoder.num_features * 4 + self.llama_proj = nn.Linear( + img_f_dim, self.llama_model.config.hidden_size + ) + self.chat_template = chat_template + + if use_grad_checkpoint_llm: + self.llama_model.gradient_checkpointing_enable() + + def encode_img(self, image): + device = image.device + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + low_resource = cfg.get("low_resource", False) + + prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]') + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r", 64) + lora_alpha = cfg.get("lora_alpha", 16) + chat_template = cfg.get("chat_template", False) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r=lora_r, + lora_alpha=lora_alpha, + chat_template=chat_template, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml index 6920aab..f3981b8 100644 --- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml +++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_llama2 diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml 
b/train_configs/minigpt4_llama2_stage2_finetune.yaml index 9a6ac2d..fa2b578 100644 --- a/train_configs/minigpt4_llama2_stage2_finetune.yaml +++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_llama2 max_txt_len: 160 diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml index 4ec1597..be87b77 100644 --- a/train_configs/minigpt4_stage1_pretrain.yaml +++ b/train_configs/minigpt4_stage1_pretrain.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_vicuna0 diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml index 54cedb4..404dfd6 100644 --- a/train_configs/minigpt4_stage2_finetune.yaml +++ b/train_configs/minigpt4_stage2_finetune.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_vicuna0 max_txt_len: 160