update v2 demo
BIN  MiniGPT_4.pdf

demo.py  (16 changed lines)
@@ -7,10 +7,12 @@ import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from transformers import StoppingCriteriaList

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

# imports modules for registration
from minigpt4.datasets.builders import *
@@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))

stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')

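For context, StoppingCriteriaSub (imported from minigpt4.conversation.conversation above) is a custom StoppingCriteria that ends generation once the tail of the generated ids matches one of the stop sequences, here the two tokenizations of '###'. A minimal sketch of such a criterion, not necessarily the repository's exact implementation:

import torch
from transformers import StoppingCriteria

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the last generated tokens match any stop sequence."""
    def __init__(self, stops=None):
        super().__init__()
        # list of 1-D id tensors, e.g. [tensor([835]), tensor([2277, 29937])]
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids, scores, **kwargs):
        for stop in self.stops:
            if len(input_ids[0]) >= len(stop) and torch.all(input_ids[0][-len(stop):] == stop):
                return True
        return False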
@@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
    chat_state = CONV_VISION.copy()
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    chat.encode_img(img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list

@@ -124,7 +132,7 @@ with gr.Blocks() as demo:
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
        with gr.Column(scale=1):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
@@ -147,7 +155,7 @@ with gr.Blocks() as demo:
                label="Temperature",
            )

        with gr.Column():
        with gr.Column(scale=2):
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='MiniGPT-4')
demo_v2.py  (155 changed lines)
@@ -1,37 +1,23 @@
import argparse
import os
import random
import requests
from io import BytesIO
from threading import Thread
from collections import defaultdict

import cv2
from termcolor import colored
from textwrap import wrap
from torchvision.transforms import functional as F
import re

import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import html
import gradio as gr
from transformers import TextIteratorStreamer

import torch.backends.cudnn as cudnn

import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)

from minigpt4.common.registry import registry
from minigpt4.common.utils import now
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat

# imports modules for registration
from minigpt4.datasets.builders import *
@@ -40,17 +26,22 @@ from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
         "in xxx=yyy format will be merged into config file (deprecate), "
         "change to --cfg-options instead.",
)

import torch.backends.cudnn as cudnn
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
                        help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


random.seed(42)
np.random.seed(42)
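For reference, with the new defaults the eval config no longer has to be passed on the command line; a quick sketch of how the parsed arguments feed Config (values shown are just the defaults from the diff above):

args = parse_args()       # no CLI flags required any more
print(args.cfg_path)      # 'eval_configs/minigptv2_eval.yaml'
print(args.gpu_id)        # 0
cfg = Config(args)        # Config reads the eval YAML referenced by args.cfg_path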
@@ -60,19 +51,18 @@ cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
cfg.model_cfg.lora_r = 64
cfg.model_cfg.lora_alpha = 16
args = parse_args()
cfg = Config(args)

device = 'cuda'
device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100

vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

model = model.eval()
@@ -484,6 +474,7 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    print('chat state', chat_state)
    if not isinstance(img_list[0], torch.Tensor):
        chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
@@ -498,7 +489,7 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
        chatbot[-1][1] = output
        yield chatbot, chat_state
    # print('message: ', chat_state.messages)
    chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
    chat_state.messages[-1][1] = '</s>'
    return chatbot, chat_state

@@ -538,102 +529,6 @@ def gradio_taskselect(idx):
    return prompt_list[idx], instruct_list[idx]


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)

        output_token = self.model.llama_model.generate(**generation_dict)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def encode_img(self, img_list):
        image = img_list[0]
        img_list.pop(0)
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."

        return msg

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs


chat = Chat(model, vis_processor, device=device)
environment.yml
@@ -7,12 +7,12 @@ dependencies:
  - python=3.9
  - cudatoolkit
  - pip
  - pytorch=1.12.1
  - pytorch=2.0.0
  - pytorch-mutex=1.0=cuda
  - torchaudio=0.12.1
  - torchvision=0.13.1
  - pip:
    - accelerate==0.16.0
    - accelerate==0.20.3
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - async-timeout==4.0.2
@@ -25,7 +25,7 @@ dependencies:
    - filelock==3.9.0
    - fonttools==4.38.0
    - frozenlist==1.3.3
    - huggingface-hub==0.13.4
    - huggingface-hub==0.18.0
    - importlib-resources==5.12.0
    - kiwisolver==1.4.4
    - matplotlib==3.7.0
@@ -40,7 +40,7 @@ dependencies:
    - regex==2022.10.31
    - tokenizers==0.13.2
    - tqdm==4.64.1
    - transformers==4.28.0
    - transformers==4.32.0
    - timm==0.6.13
    - spacy==3.5.1
    - webdataset==0.2.48
@@ -53,11 +53,10 @@ dependencies:
    - iopath==0.1.10
    - decord==0.6.0
    - tenacity==8.2.2
    - peft
    - peft==0.2.0
    - pycocoevalcap
    - sentence-transformers
    - umap-learn
    - notebook
    - gradio==3.24.1
    - gradio-client==0.0.8
    - gradio==3.47.1
    - wandb
BIN  examples_v2/2000x1372_wmkn_0012149409555.jpg  (new executable file, 91 KiB)
BIN  examples_v2/KFC-20-for-20-Nuggets.jpg  (new executable file, 83 KiB)
BIN  examples_v2/cockdial.png  (new executable file, 1.5 MiB)
BIN  examples_v2/float.png  (new executable file, 1.2 MiB)
BIN  examples_v2/glip_test.jpg  (new executable file, 92 KiB)
BIN  examples_v2/office.jpg  (new executable file, 25 KiB)
BIN  examples_v2/sofa.jpg  (new executable file, 116 KiB)
BIN  examples_v2/thief.png  (new executable file, 865 KiB)
minigpt4/conversation/conversation.py
@@ -1,10 +1,11 @@
import argparse
import time
from threading import Thread
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

import dataclasses
from enum import auto, Enum
@@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
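As a quick usage sketch (not part of the diff): the new keyword keeps the previous default behaviour when omitted, while callers such as demo.py can now pass their own criteria.

# Default: falls back to the EOS token id 2 as the only stop word.
chat = Chat(model, vis_processor, device='cuda:0')
# Custom: e.g. the '###' stop-word criteria built in demo.py above.
chat = Chat(model, vis_processor, device='cuda:0', stopping_criteria=stopping_criteria)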
@@ -144,8 +148,8 @@ class Chat:
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

@@ -154,10 +158,9 @@ class Chat:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
@@ -169,18 +172,31 @@ class Chat:
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        generation_dict = self.answer_prepare(conv, img_list, **kargs)

        output_token = self.model.llama_model.generate(**generation_dict)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)

        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()

        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
    def stream_answer(self, conv, img_list, **kargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def encode_img(self, img_list):
        image = img_list[0]
        img_list.pop(0)
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
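For orientation, the new stream_answer is meant to be consumed incrementally: generation runs in a background thread and the TextIteratorStreamer yields decoded text chunks as they are produced. A minimal consumer sketch (the names chat, chat_state and img_list are assumed from the demo code above):

streamer = chat.stream_answer(conv=chat_state, img_list=img_list, temperature=0.6)
output = ''
for chunk in streamer:
    output += chunk                   # accumulate the partial answer
    print(chunk, end='', flush=True)  # e.g. update a UI element instead of printing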
@@ -194,9 +210,12 @@ class Chat:

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)

        return msg

    def get_context_emb(self, conv, img_list):
@@ -209,7 +228,9 @@ class Chat:
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        print('debug device: ', self.device)
        print('debug model device: ', self.model.device)
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

minigpt4/models/base_model.py
@@ -169,7 +169,7 @@ class BaseModel(nn.Module):
        return visual_encoder, ln_vision

    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
                 **lora_kargs):
                 lora_target_modules=["q_proj","v_proj"], **lora_kargs):
        logging.info('Loading LLAMA')
        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
        llama_tokenizer.pad_token = "$$"
@@ -193,6 +193,7 @@ class BaseModel(nn.Module):
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=lora_target_modules,
            **lora_kargs
        )
        llama_model = get_peft_model(llama_model, loraconfig)
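Illustrative only: roughly the LoRA setup init_llm builds when lora_r > 0, shown with the r/alpha values used elsewhere in this diff (lora_r=64, lora_alpha=16); lora_alpha is passed through **lora_kargs, and the target_modules default of q_proj/v_proj is the new behaviour added here.

from peft import LoraConfig, get_peft_model  # peft==0.2.0 per the environment.yml change above

loraconfig = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # new default introduced by this change
    bias="none",
    task_type="CAUSAL_LM",
)
llama_model = get_peft_model(llama_model, loraconfig)
llama_model.print_trainable_parameters()  # only the LoRA adapters on q_proj/v_proj are trainable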