diff --git a/MiniGPT_4.pdf b/MiniGPT_4.pdf deleted file mode 100644 index 5450815..0000000 Binary files a/MiniGPT_4.pdf and /dev/null differ diff --git a/demo.py b/demo.py index 483b56c..c7646c4 100644 --- a/demo.py +++ b/demo.py @@ -7,10 +7,12 @@ import torch import torch.backends.cudnn as cudnn import gradio as gr +from transformers import StoppingCriteriaList + from minigpt4.common.config import Config from minigpt4.common.dist_utils import get_rank from minigpt4.common.registry import registry -from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2 +from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub # imports modules for registration from minigpt4.datasets.builders import * @@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type] vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) -chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) + +stop_words_ids = [[835], [2277, 29937]] +stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids] +stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria) print('Initialization Finished') @@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state): chat_state = CONV_VISION.copy() img_list = [] llm_message = chat.upload_img(gr_img, chat_state, img_list) + chat.encode_img(img_list) return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list @@ -124,7 +132,7 @@ with gr.Blocks() as demo: gr.Markdown(article) with gr.Row(): - with gr.Column(scale=0.5): + with gr.Column(scale=1): image = gr.Image(type="pil") upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") clear = gr.Button("Restart") @@ -147,7 +155,7 @@ with gr.Blocks() as demo: label="Temperature", ) - with gr.Column(): + with gr.Column(scale=2): chat_state = gr.State() img_list = gr.State() chatbot = gr.Chatbot(label='MiniGPT-4') diff --git a/demo_v2.py b/demo_v2.py index 1ea87fc..52fb897 100644 --- a/demo_v2.py +++ b/demo_v2.py @@ -1,37 +1,23 @@ import argparse import os import random -import requests -from io import BytesIO -from threading import Thread from collections import defaultdict import cv2 -from termcolor import colored -from textwrap import wrap -from torchvision.transforms import functional as F import re import numpy as np from PIL import Image import torch -import torch.backends.cudnn as cudnn import html import gradio as gr -from transformers import TextIteratorStreamer +import torch.backends.cudnn as cudnn -import minigpt4.tasks as tasks from minigpt4.common.config import Config -from minigpt4.common.dist_utils import get_rank, init_distributed_mode -from minigpt4.common.logger import setup_logger -from minigpt4.common.optims import ( - LinearWarmupCosineLRScheduler, - LinearWarmupStepLRScheduler, -) + from minigpt4.common.registry import registry -from minigpt4.common.utils import now -from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat # imports modules for registration from minigpt4.datasets.builders import * @@ -40,17 +26,22 @@ from minigpt4.processors import * from minigpt4.runners import * from minigpt4.tasks import * -parser = argparse.ArgumentParser(description="Demo") -parser.add_argument("--cfg-path", required=True, help="path to configuration file.") -parser.add_argument( - "--options", - nargs="+", - help="override some settings in the used config, the key-value pair " - "in xxx=yyy format will be merged into config file (deprecate), " - "change to --cfg-options instead.", -) -import torch.backends.cudnn as cudnn +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml', + help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + random.seed(42) np.random.seed(42) @@ -60,19 +51,18 @@ cudnn.benchmark = False cudnn.deterministic = True print('Initializing Chat') -cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml'])) -cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth" -cfg.model_cfg.lora_r = 64 -cfg.model_cfg.lora_alpha = 16 +args = parse_args() +cfg = Config(args) -device = 'cuda' +device = 'cuda:{}'.format(args.gpu_id) model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) model = model_cls.from_config(model_config).to(device) bounding_box_size = 100 -vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train +vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) model = model.eval() @@ -484,6 +474,7 @@ def gradio_answer(chatbot, chat_state, img_list, temperature): def gradio_stream_answer(chatbot, chat_state, img_list, temperature): + print('chat state', chat_state) if not isinstance(img_list[0], torch.Tensor): chat.encode_img(img_list) streamer = chat.stream_answer(conv=chat_state, @@ -498,7 +489,7 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature): chatbot[-1][1] = output yield chatbot, chat_state # print('message: ', chat_state.messages) - chat_state.messages[-1][1] = reverse_escape(output) + '' + chat_state.messages[-1][1] = '' return chatbot, chat_state @@ -538,102 +529,6 @@ def gradio_taskselect(idx): return prompt_list[idx], instruct_list[idx] -class Chat: - def __init__(self, model, vis_processor, device='cuda:0'): - self.device = device - self.model = model - self.vis_processor = vis_processor - stop_words_ids = [torch.tensor([2]).to(self.device)] - self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) - - def ask(self, text, conv): - if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ - and conv.messages[-1][1][-6:] == '': # last message is image. - conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) - else: - conv.append_message(conv.roles[0], text) - - def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, - repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): - conv.append_message(conv.roles[1], None) - embs = self.get_context_emb(conv, img_list) - - current_max_len = embs.shape[1] + max_new_tokens - if current_max_len - max_length > 0: - print('Warning: The number of tokens in current conversation exceeds the max length. ' - 'The model will not see the contexts outside the range.') - begin_idx = max(0, current_max_len - max_length) - embs = embs[:, begin_idx:] - - generation_kwargs = dict( - inputs_embeds=embs, - max_new_tokens=max_new_tokens, - stopping_criteria=self.stopping_criteria, - num_beams=num_beams, - do_sample=True, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - length_penalty=length_penalty, - temperature=temperature, - ) - return generation_kwargs - - def answer(self, conv, img_list, **kargs): - generation_kwargs = self.answer_prepare(conv, img_list, **kargs) - - output_token = self.model.llama_model.generate(**generation_dict)[0] - output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) - conv.messages[-1][1] = output_text - return output_text, output_token.cpu().numpy() - - def stream_answer(self, conv, img_list, **kargs): - generation_kwargs = self.answer_prepare(conv, img_list, **kargs) - streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) - generation_kwargs['streamer'] = streamer - thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) - thread.start() - return streamer - - def encode_img(self, img_list): - image = img_list[0] - img_list.pop(0) - if isinstance(image, str): # is a image path - raw_image = Image.open(image).convert('RGB') - image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) - elif isinstance(image, Image.Image): - raw_image = image - image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) - elif isinstance(image, torch.Tensor): - if len(image.shape) == 3: - image = image.unsqueeze(0) - image = image.to(self.device) - - image_emb, _ = self.model.encode_img(image) - img_list.append(image_emb) - - def upload_img(self, image, conv, img_list): - conv.append_message(conv.roles[0], "") - img_list.append(image) - msg = "Received." - - return msg - - def get_context_emb(self, conv, img_list): - prompt = conv.get_prompt() - prompt_segs = prompt.split('') - assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." - seg_tokens = [ - self.model.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids - # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - chat = Chat(model, vis_processor, device=device) diff --git a/environment.yml b/environment.yml index d5cfcf8..c17288d 100644 --- a/environment.yml +++ b/environment.yml @@ -7,12 +7,12 @@ dependencies: - python=3.9 - cudatoolkit - pip - - pytorch=1.12.1 + - pytorch=2.0.0 - pytorch-mutex=1.0=cuda - torchaudio=0.12.1 - torchvision=0.13.1 - pip: - - accelerate==0.16.0 + - accelerate==0.20.3 - aiohttp==3.8.4 - aiosignal==1.3.1 - async-timeout==4.0.2 @@ -25,7 +25,7 @@ dependencies: - filelock==3.9.0 - fonttools==4.38.0 - frozenlist==1.3.3 - - huggingface-hub==0.13.4 + - huggingface-hub==0.18.0 - importlib-resources==5.12.0 - kiwisolver==1.4.4 - matplotlib==3.7.0 @@ -40,7 +40,7 @@ dependencies: - regex==2022.10.31 - tokenizers==0.13.2 - tqdm==4.64.1 - - transformers==4.28.0 + - transformers==4.32.0 - timm==0.6.13 - spacy==3.5.1 - webdataset==0.2.48 @@ -53,11 +53,10 @@ dependencies: - iopath==0.1.10 - decord==0.6.0 - tenacity==8.2.2 - - peft + - peft==0.2.0 - pycocoevalcap - sentence-transformers - umap-learn - notebook - - gradio==3.24.1 - - gradio-client==0.0.8 + - gradio==3.47.1 - wandb diff --git a/examples_v2/2000x1372_wmkn_0012149409555.jpg b/examples_v2/2000x1372_wmkn_0012149409555.jpg new file mode 100755 index 0000000..1250f7f Binary files /dev/null and b/examples_v2/2000x1372_wmkn_0012149409555.jpg differ diff --git a/examples_v2/KFC-20-for-20-Nuggets.jpg b/examples_v2/KFC-20-for-20-Nuggets.jpg new file mode 100755 index 0000000..0ec641c Binary files /dev/null and b/examples_v2/KFC-20-for-20-Nuggets.jpg differ diff --git a/examples_v2/cockdial.png b/examples_v2/cockdial.png new file mode 100755 index 0000000..935f98e Binary files /dev/null and b/examples_v2/cockdial.png differ diff --git a/examples_v2/float.png b/examples_v2/float.png new file mode 100755 index 0000000..900dcb0 Binary files /dev/null and b/examples_v2/float.png differ diff --git a/examples_v2/glip_test.jpg b/examples_v2/glip_test.jpg new file mode 100755 index 0000000..f9198f2 Binary files /dev/null and b/examples_v2/glip_test.jpg differ diff --git a/examples_v2/office.jpg b/examples_v2/office.jpg new file mode 100755 index 0000000..e35bdc2 Binary files /dev/null and b/examples_v2/office.jpg differ diff --git a/examples_v2/sofa.jpg b/examples_v2/sofa.jpg new file mode 100755 index 0000000..8610591 Binary files /dev/null and b/examples_v2/sofa.jpg differ diff --git a/examples_v2/thief.png b/examples_v2/thief.png new file mode 100755 index 0000000..579ee52 Binary files /dev/null and b/examples_v2/thief.png differ diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 7678814..9c27c78 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -1,10 +1,11 @@ import argparse import time +from threading import Thread from PIL import Image import torch from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer -from transformers import StoppingCriteria, StoppingCriteriaList +from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer import dataclasses from enum import auto, Enum @@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation( class Chat: - def __init__(self, model, vis_processor, device='cuda:0'): + def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None): self.device = device self.model = model self.vis_processor = vis_processor - stop_words_ids = [torch.tensor([835]).to(self.device), - torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. - self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + if stopping_criteria is not None: + self.stopping_criteria = stopping_criteria + else: + stop_words_ids = [torch.tensor([2]).to(self.device)] + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) def ask(self, text, conv): if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ @@ -144,8 +148,8 @@ class Chat: else: conv.append_message(conv.roles[0], text) - def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, - repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) @@ -154,10 +158,9 @@ class Chat: print('Warning: The number of tokens in current conversation exceeds the max length. ' 'The model will not see the contexts outside the range.') begin_idx = max(0, current_max_len - max_length) - embs = embs[:, begin_idx:] - outputs = self.model.llama_model.generate( + generation_kwargs = dict( inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, @@ -169,18 +172,31 @@ class Chat: length_penalty=length_penalty, temperature=temperature, ) - output_token = outputs[0] - if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it - output_token = output_token[1:] - if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it - output_token = output_token[1:] - output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + return generation_kwargs + + def answer(self, conv, img_list, **kargs): + generation_dict = self.answer_prepare(conv, img_list, **kargs) + + output_token = self.model.llama_model.generate(**generation_dict)[0] + output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_text = output_text.split('###')[0] # remove the stop sign '###' output_text = output_text.split('Assistant:')[-1].strip() + conv.messages[-1][1] = output_text return output_text, output_token.cpu().numpy() - def upload_img(self, image, conv, img_list): + def stream_answer(self, conv, img_list, **kargs): + generation_kwargs = self.answer_prepare(conv, img_list, **kargs) + streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) + generation_kwargs['streamer'] = streamer + thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + + def encode_img(self, img_list): + image = img_list[0] + img_list.pop(0) if isinstance(image, str): # is a image path raw_image = Image.open(image).convert('RGB') image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) @@ -194,9 +210,12 @@ class Chat: image_emb, _ = self.model.encode_img(image) img_list.append(image_emb) + + def upload_img(self, image, conv, img_list): conv.append_message(conv.roles[0], "") + img_list.append(image) msg = "Received." - # self.conv.append_message(self.conv.roles[1], msg) + return msg def get_context_emb(self, conv, img_list): @@ -209,7 +228,9 @@ class Chat: # only add bos to the first seg for i, seg in enumerate(prompt_segs) ] - seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + print('debug device: ', self.device) + print('debug model device: ', self.model.device) + seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] mixed_embs = torch.cat(mixed_embs, dim=1) return mixed_embs diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index ae0a3be..fd1d636 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -169,7 +169,7 @@ class BaseModel(nn.Module): return visual_encoder, ln_vision def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0, - **lora_kargs): + lora_target_modules=["q_proj","v_proj"], **lora_kargs): logging.info('Loading LLAMA') llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) llama_tokenizer.pad_token = "$$" @@ -193,6 +193,7 @@ class BaseModel(nn.Module): r=lora_r, bias="none", task_type="CAUSAL_LM", + target_modules=lora_target_modules, **lora_kargs ) llama_model = get_peft_model(llama_model, loraconfig)