diff --git a/MiniGPT_4.pdf b/MiniGPT_4.pdf
deleted file mode 100644
index 5450815..0000000
Binary files a/MiniGPT_4.pdf and /dev/null differ
diff --git a/demo.py b/demo.py
index 483b56c..c7646c4 100644
--- a/demo.py
+++ b/demo.py
@@ -7,10 +7,12 @@ import torch
import torch.backends.cudnn as cudnn
import gradio as gr
+from transformers import StoppingCriteriaList
+
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
# imports modules for registration
from minigpt4.datasets.builders import *
@@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+
+stop_words_ids = [[835], [2277, 29937]]
+stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
+stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')
@@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ chat.encode_img(img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
@@ -124,7 +132,7 @@ with gr.Blocks() as demo:
gr.Markdown(article)
with gr.Row():
- with gr.Column(scale=0.5):
+ with gr.Column(scale=1):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
@@ -147,7 +155,7 @@ with gr.Blocks() as demo:
label="Temperature",
)
- with gr.Column():
+ with gr.Column(scale=2):
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='MiniGPT-4')
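
Note: the stop-word ids wired in above, [835] and [2277, 29937], are the two ways the LLaMA tokenizer can encode the '###' conversation separator. For reference, a minimal sketch of a tail-matching stopping criterion in the style of StoppingCriteriaSub (the real class lives in minigpt4/conversation/conversation.py; this re-implementation is illustrative, not the repo's exact code):

```python
# Illustrative tail-matching stop criterion; mirrors how demo.py wires
# StoppingCriteriaSub, but is a sketch rather than the repo's exact class.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class TailStopCriteria(StoppingCriteria):
    def __init__(self, stops):
        super().__init__()
        self.stops = stops  # list of 1-D LongTensors of stop token ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Halt once the generated sequence ends with any stop pattern.
        for stop in self.stops:
            if input_ids.shape[1] >= len(stop) and torch.equal(input_ids[0, -len(stop):], stop.to(input_ids.device)):
                return True
        return False

# Same shape as the demo.py setup: both encodings of '###' act as stops.
stops = [torch.tensor([835]), torch.tensor([2277, 29937])]
stopping_criteria = StoppingCriteriaList([TailStopCriteria(stops)])
```
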
diff --git a/demo_v2.py b/demo_v2.py
index 1ea87fc..52fb897 100644
--- a/demo_v2.py
+++ b/demo_v2.py
@@ -1,37 +1,23 @@
import argparse
import os
import random
-import requests
-from io import BytesIO
-from threading import Thread
from collections import defaultdict
import cv2
-from termcolor import colored
-from textwrap import wrap
-from torchvision.transforms import functional as F
import re
import numpy as np
from PIL import Image
import torch
-import torch.backends.cudnn as cudnn
import html
import gradio as gr
-from transformers import TextIteratorStreamer
+import torch.backends.cudnn as cudnn
-import minigpt4.tasks as tasks
from minigpt4.common.config import Config
-from minigpt4.common.dist_utils import get_rank, init_distributed_mode
-from minigpt4.common.logger import setup_logger
-from minigpt4.common.optims import (
- LinearWarmupCosineLRScheduler,
- LinearWarmupStepLRScheduler,
-)
+
from minigpt4.common.registry import registry
-from minigpt4.common.utils import now
-from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
+from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
# imports modules for registration
from minigpt4.datasets.builders import *
@@ -40,17 +26,22 @@ from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
-parser = argparse.ArgumentParser(description="Demo")
-parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
-parser.add_argument(
- "--options",
- nargs="+",
- help="override some settings in the used config, the key-value pair "
- "in xxx=yyy format will be merged into config file (deprecate), "
- "change to --cfg-options instead.",
-)
-import torch.backends.cudnn as cudnn
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
+ help="path to configuration file.")
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ args = parser.parse_args()
+ return args
+
random.seed(42)
np.random.seed(42)
@@ -60,19 +51,18 @@ cudnn.benchmark = False
cudnn.deterministic = True
print('Initializing Chat')
-cfg = Config(parser.parse_args(['--cfg-path', 'eval_configs/minigpt4_object_detection_448x448_llama2.yaml']))
-cfg.model_cfg.ckpt = "/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth"
-cfg.model_cfg.lora_r = 64
-cfg.model_cfg.lora_alpha = 16
+args = parse_args()
+cfg = Config(args)
-device = 'cuda'
+device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100
-vis_processor_cfg = cfg.datasets_cfg.coco.vis_processor.train
+vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
model = model.eval()
@@ -484,6 +474,7 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+ print('chat state', chat_state)
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
streamer = chat.stream_answer(conv=chat_state,
@@ -498,7 +489,7 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
chatbot[-1][1] = output
yield chatbot, chat_state
# print('message: ', chat_state.messages)
- chat_state.messages[-1][1] = reverse_escape(output) + '</s>'
+ chat_state.messages[-1][1] = '</s>'
return chatbot, chat_state
@@ -538,102 +529,6 @@ def gradio_taskselect(idx):
return prompt_list[idx], instruct_list[idx]
-class Chat:
- def __init__(self, model, vis_processor, device='cuda:0'):
- self.device = device
- self.model = model
- self.vis_processor = vis_processor
- stop_words_ids = [torch.tensor([2]).to(self.device)]
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
- def ask(self, text, conv):
- if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
- and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
- conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
- else:
- conv.append_message(conv.roles[0], text)
-
- def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
- repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
- conv.append_message(conv.roles[1], None)
- embs = self.get_context_emb(conv, img_list)
-
- current_max_len = embs.shape[1] + max_new_tokens
- if current_max_len - max_length > 0:
- print('Warning: The number of tokens in current conversation exceeds the max length. '
- 'The model will not see the contexts outside the range.')
- begin_idx = max(0, current_max_len - max_length)
- embs = embs[:, begin_idx:]
-
- generation_kwargs = dict(
- inputs_embeds=embs,
- max_new_tokens=max_new_tokens,
- stopping_criteria=self.stopping_criteria,
- num_beams=num_beams,
- do_sample=True,
- min_length=min_length,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- length_penalty=length_penalty,
- temperature=temperature,
- )
- return generation_kwargs
-
- def answer(self, conv, img_list, **kargs):
- generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
-
- output_token = self.model.llama_model.generate(**generation_kwargs)[0]
- output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
- conv.messages[-1][1] = output_text
- return output_text, output_token.cpu().numpy()
-
- def stream_answer(self, conv, img_list, **kargs):
- generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
- streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
- generation_kwargs['streamer'] = streamer
- thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
- thread.start()
- return streamer
-
- def encode_img(self, img_list):
- image = img_list[0]
- img_list.pop(0)
- if isinstance(image, str): # is an image path
- raw_image = Image.open(image).convert('RGB')
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, Image.Image):
- raw_image = image
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, torch.Tensor):
- if len(image.shape) == 3:
- image = image.unsqueeze(0)
- image = image.to(self.device)
-
- image_emb, _ = self.model.encode_img(image)
- img_list.append(image_emb)
-
- def upload_img(self, image, conv, img_list):
- conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
- img_list.append(image)
- msg = "Received."
-
- return msg
-
- def get_context_emb(self, conv, img_list):
- prompt = conv.get_prompt()
- prompt_segs = prompt.split('<ImageHere>')
- assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
- seg_tokens = [
- self.model.llama_tokenizer(
- seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
- # only add bos to the first seg
- for i, seg in enumerate(prompt_segs)
- ]
- seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
- mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
- mixed_embs = torch.cat(mixed_embs, dim=1)
- return mixed_embs
-
chat = Chat(model, vis_processor, device=device)
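
Note: Chat.stream_answer, which demo_v2.py now imports from minigpt4/conversation/conversation.py, follows the standard transformers streaming pattern: generate() runs on a worker thread while the caller iterates a TextIteratorStreamer. A self-contained sketch under a placeholder model name (gpt2 stands in for the LLaMA backbone here):

```python
# Minimal sketch of the thread + TextIteratorStreamer pattern behind
# Chat.stream_answer; "gpt2" is a placeholder, not MiniGPT-4's backbone.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Describe the image:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks until done, so it runs off the main thread; the
# streamer then yields decoded text chunks as they are produced.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50))
thread.start()

output = ""
for chunk in streamer:
    output += chunk  # gradio_stream_answer escapes and yields these chunks
print(output)
```
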
diff --git a/environment.yml b/environment.yml
index d5cfcf8..c17288d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -7,12 +7,12 @@ dependencies:
- python=3.9
- cudatoolkit
- pip
- - pytorch=1.12.1
+ - pytorch=2.0.0
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
- - accelerate==0.16.0
+ - accelerate==0.20.3
- aiohttp==3.8.4
- aiosignal==1.3.1
- async-timeout==4.0.2
@@ -25,7 +25,7 @@ dependencies:
- filelock==3.9.0
- fonttools==4.38.0
- frozenlist==1.3.3
- - huggingface-hub==0.13.4
+ - huggingface-hub==0.18.0
- importlib-resources==5.12.0
- kiwisolver==1.4.4
- matplotlib==3.7.0
@@ -40,7 +40,7 @@ dependencies:
- regex==2022.10.31
- tokenizers==0.13.2
- tqdm==4.64.1
- - transformers==4.28.0
+ - transformers==4.32.0
- timm==0.6.13
- spacy==3.5.1
- webdataset==0.2.48
@@ -53,11 +53,10 @@ dependencies:
- iopath==0.1.10
- decord==0.6.0
- tenacity==8.2.2
- - peft
+ - peft==0.2.0
- pycocoevalcap
- sentence-transformers
- umap-learn
- notebook
- - gradio==3.24.1
- - gradio-client==0.0.8
+ - gradio==3.47.1
- wandb
diff --git a/examples_v2/2000x1372_wmkn_0012149409555.jpg b/examples_v2/2000x1372_wmkn_0012149409555.jpg
new file mode 100755
index 0000000..1250f7f
Binary files /dev/null and b/examples_v2/2000x1372_wmkn_0012149409555.jpg differ
diff --git a/examples_v2/KFC-20-for-20-Nuggets.jpg b/examples_v2/KFC-20-for-20-Nuggets.jpg
new file mode 100755
index 0000000..0ec641c
Binary files /dev/null and b/examples_v2/KFC-20-for-20-Nuggets.jpg differ
diff --git a/examples_v2/cockdial.png b/examples_v2/cockdial.png
new file mode 100755
index 0000000..935f98e
Binary files /dev/null and b/examples_v2/cockdial.png differ
diff --git a/examples_v2/float.png b/examples_v2/float.png
new file mode 100755
index 0000000..900dcb0
Binary files /dev/null and b/examples_v2/float.png differ
diff --git a/examples_v2/glip_test.jpg b/examples_v2/glip_test.jpg
new file mode 100755
index 0000000..f9198f2
Binary files /dev/null and b/examples_v2/glip_test.jpg differ
diff --git a/examples_v2/office.jpg b/examples_v2/office.jpg
new file mode 100755
index 0000000..e35bdc2
Binary files /dev/null and b/examples_v2/office.jpg differ
diff --git a/examples_v2/sofa.jpg b/examples_v2/sofa.jpg
new file mode 100755
index 0000000..8610591
Binary files /dev/null and b/examples_v2/sofa.jpg differ
diff --git a/examples_v2/thief.png b/examples_v2/thief.png
new file mode 100755
index 0000000..579ee52
Binary files /dev/null and b/examples_v2/thief.png differ
diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 7678814..9c27c78 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -1,10 +1,11 @@
import argparse
import time
+from threading import Thread
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import dataclasses
from enum import auto, Enum
@@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(
class Chat:
- def __init__(self, model, vis_processor, device='cuda:0'):
+ def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
self.device = device
self.model = model
self.vis_processor = vis_processor
- stop_words_ids = [torch.tensor([835]).to(self.device),
- torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+ if stopping_criteria is not None:
+ self.stopping_criteria = stopping_criteria
+ else:
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
def ask(self, text, conv):
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@@ -144,8 +148,8 @@ class Chat:
else:
conv.append_message(conv.roles[0], text)
- def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
- repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
@@ -154,10 +158,9 @@ class Chat:
print('Warning: The number of tokens in current conversation exceeds the max length. '
'The model will not see the contexts outside the range.')
begin_idx = max(0, current_max_len - max_length)
-
embs = embs[:, begin_idx:]
- outputs = self.model.llama_model.generate(
+ generation_kwargs = dict(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=self.stopping_criteria,
@@ -169,18 +172,31 @@ class Chat:
length_penalty=length_penalty,
temperature=temperature,
)
- output_token = outputs[0]
- if output_token[0] == 0: # the model might output an unknown token at the beginning. remove it
- output_token = output_token[1:]
- if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
- output_token = output_token[1:]
- output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+ return generation_kwargs
+
+ def answer(self, conv, img_list, **kargs):
+ generation_dict = self.answer_prepare(conv, img_list, **kargs)
+
+ output_token = self.model.llama_model.generate(**generation_dict)[0]
+ output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+
output_text = output_text.split('###')[0] # remove the stop sign '###'
output_text = output_text.split('Assistant:')[-1].strip()
+
conv.messages[-1][1] = output_text
return output_text, output_token.cpu().numpy()
- def upload_img(self, image, conv, img_list):
+ def stream_answer(self, conv, img_list, **kargs):
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+ streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
+ generation_kwargs['streamer'] = streamer
+ thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
+ thread.start()
+ return streamer
+
+ def encode_img(self, img_list):
+ image = img_list[0]
+ img_list.pop(0)
if isinstance(image, str): # is an image path
raw_image = Image.open(image).convert('RGB')
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -194,9 +210,12 @@ class Chat:
image_emb, _ = self.model.encode_img(image)
img_list.append(image_emb)
+
+ def upload_img(self, image, conv, img_list):
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+ img_list.append(image)
msg = "Received."
- # self.conv.append_message(self.conv.roles[1], msg)
+
return msg
def get_context_emb(self, conv, img_list):
@@ -209,7 +228,9 @@ class Chat:
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
- seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+ print('debug device: ', self.device)
+ print('debug model device: ', self.model.device)
+ seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
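
Note: the zip-based interleaving at the end of get_context_emb is the heart of the multimodal prompt assembly: splitting the prompt on the image placeholder yields n+1 text segments, and the n image embeddings are slotted between them. A shape-only illustration with invented sizes (in the repo, seg_embs come from the LLM's token embeddings and img_list holds outputs of model.encode_img()):

```python
# Shape-only illustration of get_context_emb's interleaving; the sizes
# below are invented for the example.
import torch

hidden = 8
seg_embs = [torch.randn(1, n, hidden) for n in (4, 5, 3)]  # 3 text segments
img_list = [torch.randn(1, 32, hidden) for _ in range(2)]  # 2 image embeddings

# Interleave as seg0, img0, seg1, img1, seg2; the repo's assert that
# len(prompt_segs) == len(img_list) + 1 guards exactly this pairing.
mixed = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed, dim=1)
print(mixed_embs.shape)  # torch.Size([1, 76, 8]) == 4 + 32 + 5 + 32 + 3 tokens
```
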
diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
index ae0a3be..fd1d636 100644
--- a/minigpt4/models/base_model.py
+++ b/minigpt4/models/base_model.py
@@ -169,7 +169,7 @@ class BaseModel(nn.Module):
return visual_encoder, ln_vision
def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
- **lora_kargs):
+ lora_target_modules=["q_proj","v_proj"], **lora_kargs):
logging.info('Loading LLAMA')
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
llama_tokenizer.pad_token = "$$"
@@ -193,6 +193,7 @@ class BaseModel(nn.Module):
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
+ target_modules=lora_target_modules,
**lora_kargs
)
llama_model = get_peft_model(llama_model, loraconfig)
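
Note: the new lora_target_modules default restricts LoRA adapters to the attention q/v projections. For context, a hedged sketch of the equivalent peft wiring that init_llm performs (the checkpoint name is a placeholder; r=64 and lora_alpha=16 match the values the old demo_v2.py hardcoded):

```python
# Sketch of the LoRA setup init_llm builds; the checkpoint name is a
# placeholder and the hyperparameters mirror values seen in this diff.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

loraconfig = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # adapters only on attention q/v
    bias="none",
    task_type="CAUSAL_LM",
)
llama_model = get_peft_model(llama_model, loraconfig)
llama_model.print_trainable_parameters()  # adapter params vs. total
```
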