update v2 demo

This commit is contained in:
Deyao Zhu 2023-10-13 03:14:35 +03:00
parent 7a575af639
commit 0eba23ce3b
14 changed files with 84 additions and 160 deletions

Binary file not shown.

16
demo.py
View File

@ -7,10 +7,12 @@ import torch
import torch.backends.cudnn as cudnn
import gradio as gr
from transformers import StoppingCriteriaList
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
# imports modules for registration
from minigpt4.datasets.builders import *
@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')
@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
chat.encode_img(img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
@ -124,7 +132,7 @@ with gr.Blocks() as demo:
gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Column(scale=1):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
@ -147,7 +155,7 @@ with gr.Blocks() as demo:
label="Temperature",
)
with gr.Column():
with gr.Column(scale=2):
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='MiniGPT-4')

View File

@ -1,37 +1,23 @@
import argparse
import os
import random
import requests
from io import BytesIO
from threading import Thread
from collections import defaultdict
import cv2
from termcolor import colored
from textwrap import wrap
from torchvision.transforms import functional as F
import re
import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import html
import gradio as gr
from transformers import TextIteratorStreamer
import torch.backends.cudnn as cudnn
import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
LinearWarmupCosineLRScheduler,
LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
# imports modules for registration
from minigpt4.datasets.builders import *
@ -40,17 +26,22 @@ from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
import torch.backends.cudnn as cudnn
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is exactly what the previous
            zero-argument version did, so existing callers are unaffected.

    Returns:
        argparse.Namespace with the attributes ``cfg_path`` (str),
        ``gpu_id`` (int) and ``options`` (list of str, or None when the
        flag is not given).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
                        help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    # Passing argv through keeps the function usable from tests and other
    # scripts without touching the global sys.argv.
    args = parser.parse_args(argv)
    return args
random.seed(42)
np.random.seed(42)
@ -60,19 +51,18 @@ cudnn.benchmark = False
cudnn.deterministic = True
print('Initializing Chat')
cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
cfg.model_cfg.lora_r = 64
cfg.model_cfg.lora_alpha = 16
args = parse_args()
cfg = Config(args)
device = 'cuda'
device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100
vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
model = model.eval()
@ -484,6 +474,7 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
print('chat state', chat_state)
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
streamer = chat.stream_answer(conv=chat_state,
@ -498,7 +489,7 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
chatbot[-1][1] = output
yield chatbot, chat_state
# print('message: ', chat_state.messages)
chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
chat_state.messages[-1][1] = '</s>'
return chatbot, chat_state
@ -538,102 +529,6 @@ def gradio_taskselect(idx):
return prompt_list[idx], instruct_list[idx]
class Chat:
    """Conversation helper that feeds text and image embeddings to the LLM.

    The instance holds the multimodal ``model``, the image pre-processor and
    the target device, and exposes helpers to record user turns, encode
    uploaded images and generate (optionally streamed) replies.
    """

    def __init__(self, model, vis_processor, device='cuda:0'):
        """Store the collaborators and build the default stopping criteria.

        Args:
            model: Object providing ``llama_model``, ``llama_tokenizer``,
                ``embed_tokens`` and ``encode_img`` (all used below).
            vis_processor: Callable applied to a PIL image before encoding.
            device: Device string that tensors are moved to.
        """
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        # Token id 2 is handed to StoppingCriteriaSub as the only stop
        # sequence; presumably this is the tokenizer's end-of-sequence id
        # -- TODO confirm against the tokenizer in use.
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        """Record ``text`` as input from the first role of ``conv``.

        When the latest message already belongs to the first role and ends
        with ``</Img>`` (an image placeholder turn), the text is appended to
        that message, separated by a space; otherwise a new message is added.
        """
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        """Build the keyword arguments for ``llama_model.generate``.

        Appends an empty message for the second role to ``conv`` (the slot
        the reply is written into later), embeds the whole conversation and
        keeps only the trailing part of the context when the context plus
        ``max_new_tokens`` would exceed ``max_length``.

        Returns:
            dict of keyword arguments for ``generate``.
        """
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        # Drop the oldest positions so that context + new tokens fits;
        # begin_idx is 0 when nothing has to be dropped.
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kargs):
        """Generate a complete reply and store it as the last message.

        Args:
            conv: Conversation to extend.
            img_list: Image embeddings matching the placeholders in ``conv``.
            **kargs: Forwarded unchanged to ``answer_prepare``.

        Returns:
            Tuple ``(output_text, output_token_array)`` where the second
            element is the generated token sequence moved to CPU and
            converted with ``.numpy()``.
        """
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        # The previous code expanded an undefined name (``generation_dict``)
        # here, which raised NameError on every call.
        output_token = self.model.llama_model.generate(**generation_kwargs)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kargs):
        """Start generation in a background thread and return its streamer.

        The caller iterates the returned ``TextIteratorStreamer`` to receive
        text as it is produced; the thread is started but never joined here.
        """
        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def encode_img(self, img_list):
        """Replace the first entry of ``img_list`` with its image embedding.

        The first entry may be a file path, a PIL image or a tensor; it is
        removed from the front of the list and the embedding returned by
        ``model.encode_img`` is appended at the end (mutates ``img_list``).
        """
        image = img_list.pop(0)
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            # A 3-dimensional tensor gets a leading dimension added
            # (presumably the batch axis -- confirm with the vision encoder).
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        """Queue ``image`` and add its placeholder turn to ``conv``.

        The image itself is stored un-encoded in ``img_list``; ``encode_img``
        turns it into an embedding later.

        Returns:
            The fixed acknowledgement string ``"Received."``.
        """
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."

        return msg

    def get_context_emb(self, conv, img_list):
        """Embed the prompt, splicing image embeddings at ``<ImageHere>``.

        The prompt is split on the placeholder, each text segment is
        tokenized and embedded, and segments are interleaved with the
        entries of ``img_list`` before concatenation along dimension 1.

        Raises:
            AssertionError: If the number of placeholders does not match
                the number of images.
        """
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]

        # text0, img0, text1, img1, ..., text_last
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
chat = Chat(model, vis_processor, device=device)

View File

@ -7,12 +7,12 @@ dependencies:
- python=3.9
- cudatoolkit
- pip
- pytorch=1.12.1
- pytorch=2.0.0
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
- accelerate==0.16.0
- accelerate==0.20.3
- aiohttp==3.8.4
- aiosignal==1.3.1
- async-timeout==4.0.2
@ -25,7 +25,7 @@ dependencies:
- filelock==3.9.0
- fonttools==4.38.0
- frozenlist==1.3.3
- huggingface-hub==0.13.4
- huggingface-hub==0.18.0
- importlib-resources==5.12.0
- kiwisolver==1.4.4
- matplotlib==3.7.0
@ -40,7 +40,7 @@ dependencies:
- regex==2022.10.31
- tokenizers==0.13.2
- tqdm==4.64.1
- transformers==4.28.0
- transformers==4.32.0
- timm==0.6.13
- spacy==3.5.1
- webdataset==0.2.48
@ -53,11 +53,10 @@ dependencies:
- iopath==0.1.10
- decord==0.6.0
- tenacity==8.2.2
- peft
- peft==0.2.0
- pycocoevalcap
- sentence-transformers
- umap-learn
- notebook
- gradio==3.24.1
- gradio-client==0.0.8
- gradio==3.47.1
- wandb

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 KiB

BIN
examples_v2/cockdial.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

BIN
examples_v2/float.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

BIN
examples_v2/glip_test.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

BIN
examples_v2/office.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
examples_v2/sofa.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

BIN
examples_v2/thief.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 865 KiB

View File

@ -1,10 +1,11 @@
import argparse
import time
from threading import Thread
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import dataclasses
from enum import auto, Enum
@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(
class Chat:
def __init__(self, model, vis_processor, device='cuda:0'):
def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
self.device = device
self.model = model
self.vis_processor = vis_processor
stop_words_ids = [torch.tensor([835]).to(self.device),
torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
if stopping_criteria is not None:
self.stopping_criteria = stopping_criteria
else:
stop_words_ids = [torch.tensor([2]).to(self.device)]
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
def ask(self, text, conv):
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@ -144,8 +148,8 @@ class Chat:
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
@ -154,10 +158,9 @@ class Chat:
print('Warning: The number of tokens in current conversation exceeds the max length. '
'The model will not see the contexts outside the range.')
begin_idx = max(0, current_max_len - max_length)
embs = embs[:, begin_idx:]
outputs = self.model.llama_model.generate(
generation_kwargs = dict(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=self.stopping_criteria,
@ -169,18 +172,31 @@ class Chat:
length_penalty=length_penalty,
temperature=temperature,
)
output_token = outputs[0]
if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
output_token = output_token[1:]
if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
output_token = output_token[1:]
output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
return generation_kwargs
def answer(self, conv, img_list, **kargs):
generation_dict = self.answer_prepare(conv, img_list, **kargs)
output_token = self.model.llama_model.generate(**generation_dict)[0]
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_text = output_text.split('###')[0] # remove the stop sign '###'
output_text = output_text.split('Assistant:')[-1].strip()
conv.messages[-1][1] = output_text
return output_text, output_token.cpu().numpy()
def upload_img(self, image, conv, img_list):
def stream_answer(self, conv, img_list, **kargs):
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
generation_kwargs['streamer'] = streamer
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
thread.start()
return streamer
def encode_img(self, img_list):
image = img_list[0]
img_list.pop(0)
if isinstance(image, str): # is a image path
raw_image = Image.open(image).convert('RGB')
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@ -194,9 +210,12 @@ class Chat:
image_emb, _ = self.model.encode_img(image)
img_list.append(image_emb)
def upload_img(self, image, conv, img_list):
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
img_list.append(image)
msg = "Received."
# self.conv.append_message(self.conv.roles[1], msg)
return msg
def get_context_emb(self, conv, img_list):
@ -209,7 +228,9 @@ class Chat:
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
print('debug device: ', self.device)
print('debug model device: ', self.model.device)
seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs

View File

@ -169,7 +169,7 @@ class BaseModel(nn.Module):
return visual_encoder, ln_vision
def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
**lora_kargs):
lora_target_modules=["q_proj","v_proj"], **lora_kargs):
logging.info('Loading LLAMA')
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
llama_tokenizer.pad_token = "$$"
@ -193,6 +193,7 @@ class BaseModel(nn.Module):
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_target_modules,
**lora_kargs
)
llama_model = get_peft_model(llama_model, loraconfig)