update v2 demo

Author: Deyao Zhu
Date:   2023-10-13 03:14:35 +03:00
Parent: 7a575af639
Commit: 0eba23ce3b
14 changed files with 84 additions and 160 deletions

Binary file changed (content not shown).

demo.py (16 changed lines)

@@ -7,10 +7,12 @@ import torch
 import torch.backends.cudnn as cudnn
 import gradio as gr
+from transformers import StoppingCriteriaList
 from minigpt4.common.config import Config
 from minigpt4.common.dist_utils import get_rank
 from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
 # imports modules for registration
 from minigpt4.datasets.builders import *
@@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]
 vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+stop_words_ids = [[835], [2277, 29937]]
+stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
+stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
 print('Initialization Finished')
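For reference, the stop-word criterion wired in here is typically a small custom StoppingCriteria. A minimal sketch, assuming StoppingCriteriaSub lives in minigpt4/conversation/conversation.py and subclasses transformers.StoppingCriteria as the import suggests:

    import torch
    from transformers import StoppingCriteria

    class StoppingCriteriaSub(StoppingCriteria):
        """Stop generation once the output ends with any of the given stop-token sequences."""

        def __init__(self, stops=[]):
            super().__init__()
            self.stops = stops  # list of 1-D tensors of token ids, e.g. [835] or [2277, 29937]

        def __call__(self, input_ids, scores, **kwargs):
            for stop in self.stops:
                # compare the tail of the generated sequence against the stop sequence
                if torch.all(input_ids[0][-len(stop):] == stop).item():
                    return True
            return False

The token ids [835] and [2277, 29937] are the two ways the LLaMA tokenizer can encode the '###' separator (per the comment in the old Chat constructor below), while [2] is the LLaMA end-of-sequence token used as the default stop.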
@@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
     chat_state = CONV_VISION.copy()
     img_list = []
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
+    chat.encode_img(img_list)
     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
@@ -124,7 +132,7 @@ with gr.Blocks() as demo:
     gr.Markdown(article)
     with gr.Row():
-        with gr.Column(scale=0.5):
+        with gr.Column(scale=1):
             image = gr.Image(type="pil")
             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
             clear = gr.Button("Restart")
@@ -147,7 +155,7 @@ with gr.Blocks() as demo:
                 label="Temperature",
             )
-        with gr.Column():
+        with gr.Column(scale=2):
             chat_state = gr.State()
             img_list = gr.State()
             chatbot = gr.Chatbot(label='MiniGPT-4')

demo_v2.py

@@ -1,37 +1,23 @@
 import argparse
 import os
 import random
-import requests
-from io import BytesIO
-from threading import Thread
 from collections import defaultdict
 import cv2
-from termcolor import colored
-from textwrap import wrap
-from torchvision.transforms import functional as F
 import re
 import numpy as np
 from PIL import Image
 import torch
+import torch.backends.cudnn as cudnn
 import html
 import gradio as gr
-from transformers import TextIteratorStreamer
-import torch.backends.cudnn as cudnn
-import minigpt4.tasks as tasks
 from minigpt4.common.config import Config
-from minigpt4.common.dist_utils import get_rank, init_distributed_mode
-from minigpt4.common.logger import setup_logger
-from minigpt4.common.optims import (
-    LinearWarmupCosineLRScheduler,
-    LinearWarmupStepLRScheduler,
-)
 from minigpt4.common.registry import registry
-from minigpt4.common.utils import now
-from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
+from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
 # imports modules for registration
 from minigpt4.datasets.builders import *
@@ -40,17 +26,22 @@ from minigpt4.processors import *
 from minigpt4.runners import *
 from minigpt4.tasks import *
-parser = argparse.ArgumentParser(description="Demo")
-parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
-parser.add_argument(
-    "--options",
-    nargs="+",
-    help="override some settings in the used config, the key-value pair "
-         "in xxx=yyy format will be merged into config file (deprecate), "
-         "change to --cfg-options instead.",
-)
-import torch.backends.cudnn as cudnn
+def parse_args():
+    parser = argparse.ArgumentParser(description="Demo")
+    parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
+                        help="path to configuration file.")
+    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+    parser.add_argument(
+        "--options",
+        nargs="+",
+        help="override some settings in the used config, the key-value pair "
+             "in xxx=yyy format will be merged into config file (deprecate), "
+             "change to --cfg-options instead.",
+    )
+    args = parser.parse_args()
+    return args
 random.seed(42)
 np.random.seed(42)
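With the CLI factored into parse_args, the v2 demo can presumably be launched as python demo_v2.py --gpu-id 0, the config path defaulting to eval_configs/minigptv2_eval.yaml.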
@@ -60,19 +51,18 @@ cudnn.benchmark = False
 cudnn.deterministic = True
 print('Initializing Chat')
-cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
-cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
-cfg.model_cfg.lora_r = 64
-cfg.model_cfg.lora_alpha = 16
+args = parse_args()
+cfg = Config(args)
-device = 'cuda'
+device = 'cuda:{}'.format(args.gpu_id)
 model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
 model_cls = registry.get_model_class(model_config.arch)
 model = model_cls.from_config(model_config).to(device)
 bounding_box_size = 100
-vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
+vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
 model = model.eval()
@@ -484,6 +474,7 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
 def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+    print('chat state', chat_state)
     if not isinstance(img_list[0], torch.Tensor):
         chat.encode_img(img_list)
     streamer = chat.stream_answer(conv=chat_state,
@@ -498,7 +489,7 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
         chatbot[-1][1] = output
         yield chatbot, chat_state
     # print('message: ', chat_state.messages)
-    chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
+    chat_state.messages[-1][1] = '</s>'
     return chatbot, chat_state
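The two context lines at the top of this hunk are the tail of the streaming loop; the full pattern for consuming a TextIteratorStreamer is roughly the following (a sketch, with variable names as in the surrounding code):

    output = ''
    for new_text in streamer:        # decoded chunks arrive while generate() runs in a background thread
        output += new_text
        chatbot[-1][1] = output      # update the last chat turn in place
        yield chatbot, chat_state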
@@ -538,102 +529,6 @@ def gradio_taskselect(idx):
     return prompt_list[idx], instruct_list[idx]

-class Chat:
-    def __init__(self, model, vis_processor, device='cuda:0'):
-        self.device = device
-        self.model = model
-        self.vis_processor = vis_processor
-        stop_words_ids = [torch.tensor([2]).to(self.device)]
-        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
-    def ask(self, text, conv):
-        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
-            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
-        else:
-            conv.append_message(conv.roles[0], text)
-
-    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
-                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
-        conv.append_message(conv.roles[1], None)
-        embs = self.get_context_emb(conv, img_list)
-
-        current_max_len = embs.shape[1] + max_new_tokens
-        if current_max_len - max_length > 0:
-            print('Warning: The number of tokens in current conversation exceeds the max length. '
-                  'The model will not see the contexts outside the range.')
-        begin_idx = max(0, current_max_len - max_length)
-        embs = embs[:, begin_idx:]
-
-        generation_kwargs = dict(
-            inputs_embeds=embs,
-            max_new_tokens=max_new_tokens,
-            stopping_criteria=self.stopping_criteria,
-            num_beams=num_beams,
-            do_sample=True,
-            min_length=min_length,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            temperature=temperature,
-        )
-        return generation_kwargs
-
-    def answer(self, conv, img_list, **kargs):
-        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
-        output_token = self.model.llama_model.generate(**generation_dict)[0]
-        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
-        conv.messages[-1][1] = output_text
-        return output_text, output_token.cpu().numpy()
-
-    def stream_answer(self, conv, img_list, **kargs):
-        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
-        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
-        generation_kwargs['streamer'] = streamer
-        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        return streamer
-
-    def encode_img(self, img_list):
-        image = img_list[0]
-        img_list.pop(0)
-        if isinstance(image, str):  # is a image path
-            raw_image = Image.open(image).convert('RGB')
-            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
-        elif isinstance(image, Image.Image):
-            raw_image = image
-            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
-        elif isinstance(image, torch.Tensor):
-            if len(image.shape) == 3:
-                image = image.unsqueeze(0)
-            image = image.to(self.device)
-
-        image_emb, _ = self.model.encode_img(image)
-        img_list.append(image_emb)
-
-    def upload_img(self, image, conv, img_list):
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
-        img_list.append(image)
-        msg = "Received."
-        return msg
-
-    def get_context_emb(self, conv, img_list):
-        prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
-        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
-        seg_tokens = [
-            self.model.llama_tokenizer(
-                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
-            # only add bos to the first seg
-            for i, seg in enumerate(prompt_segs)
-        ]
-        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
-        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
-        mixed_embs = torch.cat(mixed_embs, dim=1)
-        return mixed_embs
-
 chat = Chat(model, vis_processor, device=device)
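With the local class removed, the v2 demo leans on the shared Chat implementation from minigpt4.conversation.conversation. The typical call sequence it drives is roughly the following (a sketch of the flow, not verbatim demo code; conv_template stands for whichever Conversation preset the demo builds):

    chat = Chat(model, vis_processor, device=device)

    chat_state = conv_template.copy()              # fresh conversation state
    img_list = []
    chat.upload_img(image, chat_state, img_list)   # inserts the <Img><ImageHere></Img> placeholder
    chat.encode_img(img_list)                      # replaces the raw image with its embedding
    chat.ask('Describe this image.', chat_state)

    # blocking answer...
    text, _ = chat.answer(conv=chat_state, img_list=img_list, temperature=0.6)
    # ...or token-by-token streaming
    for chunk in chat.stream_answer(conv=chat_state, img_list=img_list, temperature=0.6):
        print(chunk, end='')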

environment.yml

@@ -7,12 +7,12 @@ dependencies:
   - python=3.9
   - cudatoolkit
   - pip
-  - pytorch=1.12.1
+  - pytorch=2.0.0
   - pytorch-mutex=1.0=cuda
   - torchaudio=0.12.1
   - torchvision=0.13.1
   - pip:
-    - accelerate==0.16.0
+    - accelerate==0.20.3
     - aiohttp==3.8.4
     - aiosignal==1.3.1
     - async-timeout==4.0.2
@@ -25,7 +25,7 @@ dependencies:
     - filelock==3.9.0
     - fonttools==4.38.0
     - frozenlist==1.3.3
-    - huggingface-hub==0.13.4
+    - huggingface-hub==0.18.0
     - importlib-resources==5.12.0
     - kiwisolver==1.4.4
     - matplotlib==3.7.0
@@ -40,7 +40,7 @@ dependencies:
     - regex==2022.10.31
     - tokenizers==0.13.2
     - tqdm==4.64.1
-    - transformers==4.28.0
+    - transformers==4.32.0
     - timm==0.6.13
     - spacy==3.5.1
     - webdataset==0.2.48
@@ -53,11 +53,10 @@ dependencies:
     - iopath==0.1.10
     - decord==0.6.0
     - tenacity==8.2.2
-    - peft
+    - peft==0.2.0
     - pycocoevalcap
     - sentence-transformers
     - umap-learn
     - notebook
-    - gradio==3.24.1
-    - gradio-client==0.0.8
+    - gradio==3.47.1
     - wandb

New binary files (content not shown):

  (file name not shown)        91 KiB
  (file name not shown)        83 KiB
  examples_v2/cockdial.png     1.5 MiB, executable
  examples_v2/float.png        1.2 MiB, executable
  examples_v2/glip_test.jpg    92 KiB, executable
  examples_v2/office.jpg       25 KiB, executable
  examples_v2/sofa.jpg         116 KiB, executable
  examples_v2/thief.png        865 KiB, executable

minigpt4/conversation/conversation.py

@@ -1,10 +1,11 @@
 import argparse
 import time
+from threading import Thread
 from PIL import Image
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 import dataclasses
 from enum import auto, Enum
@@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(
 class Chat:
-    def __init__(self, model, vis_processor, device='cuda:0'):
+    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
         self.device = device
         self.model = model
         self.vis_processor = vis_processor
-        stop_words_ids = [torch.tensor([835]).to(self.device),
-                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
-        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+        if stopping_criteria is not None:
+            self.stopping_criteria = stopping_criteria
+        else:
+            stop_words_ids = [torch.tensor([2]).to(self.device)]
+            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

     def ask(self, text, conv):
         if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@@ -144,8 +148,8 @@ class Chat:
         else:
             conv.append_message(conv.roles[0], text)

-    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
@@ -154,10 +158,9 @@ class Chat:
             print('Warning: The number of tokens in current conversation exceeds the max length. '
                   'The model will not see the contexts outside the range.')
         begin_idx = max(0, current_max_len - max_length)
         embs = embs[:, begin_idx:]
-        outputs = self.model.llama_model.generate(
+        generation_kwargs = dict(
             inputs_embeds=embs,
             max_new_tokens=max_new_tokens,
             stopping_criteria=self.stopping_criteria,
@@ -169,18 +172,31 @@ class Chat:
             length_penalty=length_penalty,
             temperature=temperature,
         )
-        output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
-            output_token = output_token[1:]
-        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
-            output_token = output_token[1:]
-        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+        return generation_kwargs
+
+    def answer(self, conv, img_list, **kargs):
+        generation_dict = self.answer_prepare(conv, img_list, **kargs)
+        output_token = self.model.llama_model.generate(**generation_dict)[0]
+        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
         output_text = output_text.split('###')[0]  # remove the stop sign '###'
         output_text = output_text.split('Assistant:')[-1].strip()
         conv.messages[-1][1] = output_text
         return output_text, output_token.cpu().numpy()

-    def upload_img(self, image, conv, img_list):
+    def stream_answer(self, conv, img_list, **kargs):
+        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
+        generation_kwargs['streamer'] = streamer
+        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        return streamer
+
+    def encode_img(self, img_list):
+        image = img_list[0]
+        img_list.pop(0)
         if isinstance(image, str):  # is a image path
             raw_image = Image.open(image).convert('RGB')
             image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -194,9 +210,12 @@ class Chat:
         image_emb, _ = self.model.encode_img(image)
         img_list.append(image_emb)

+    def upload_img(self, image, conv, img_list):
         conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+        img_list.append(image)
         msg = "Received."
-        # self.conv.append_message(self.conv.roles[1], msg)
         return msg

     def get_context_emb(self, conv, img_list):
@@ -209,7 +228,9 @@ class Chat:
             # only add bos to the first seg
             for i, seg in enumerate(prompt_segs)
         ]
-        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+        print('debug device: ', self.device)
+        print('debug model device: ', self.model.device)
+        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
         mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
         mixed_embs = torch.cat(mixed_embs, dim=1)
         return mixed_embs
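To make the interleaving in get_context_emb concrete: for a prompt containing one <ImageHere> placeholder, the shapes work out roughly as follows (an illustrative sketch; m is whatever length model.encode_img returns):

    # prompt_segs: ['<text before image>', '<text after image>']
    # seg_embs:    [(1, n0, d), (1, n1, d)]   embeddings of the tokenized text segments
    # img_list:    [(1, m, d)]                image embedding from model.encode_img
    # zip(seg_embs[:-1], img_list) interleaves text and image embeddings,
    # the last text segment is appended, and torch.cat(dim=1) yields (1, n0 + m + n1, d)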

minigpt4/models/base_model.py

@@ -169,7 +169,7 @@ class BaseModel(nn.Module):
         return visual_encoder, ln_vision

     def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
-                 **lora_kargs):
+                 lora_target_modules=["q_proj","v_proj"], **lora_kargs):
         logging.info('Loading LLAMA')
         llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
         llama_tokenizer.pad_token = "$$"
@@ -193,6 +193,7 @@ class BaseModel(nn.Module):
             r=lora_r,
             bias="none",
             task_type="CAUSAL_LM",
+            target_modules=lora_target_modules,
             **lora_kargs
         )
         llama_model = get_peft_model(llama_model, loraconfig)
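Put together, the LoRA wrapping that init_llm performs amounts to roughly the following (a sketch; only the arguments visible in this diff come from the source, the rest are placeholders):

    from peft import LoraConfig, get_peft_model

    # Attach LoRA adapters to the attention projections of the (otherwise frozen) LLaMA model.
    loraconfig = LoraConfig(
        r=lora_r,                               # LoRA rank passed in by the caller
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],    # new default supplied via lora_target_modules
        **lora_kargs,                           # e.g. lora_alpha, lora_dropout
    )
    llama_model = get_peft_model(llama_model, loraconfig)
    llama_model.print_trainable_parameters()    # optional: reports how few parameters LoRA trains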