# Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 18:40:46 +00:00).
import argparse
import time
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from minigpt4.common.registry import registry


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Stop generation once the most recently generated tokens match any of
        # the configured stop sequences.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
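        # Note (illustrative sketch, not part of the original file): the hard-coded
        # stop ids above could instead be derived from the tokenizer. The exact ids
        # depend on the LLaMA vocabulary, so treat this as an illustration rather
        # than a drop-in replacement:
        #
        #   ids = self.model.llama_tokenizer('###', add_special_tokens=False).input_ids
        #   stop_words_ids = [torch.tensor(ids).to(self.device)]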

    def answer(self, embs, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):

        # embs = self.get_context_emb(img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: the number of tokens in the current conversation exceeds the max length. '
                  'The model will not see the context outside this range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find a start token <s> at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        return output_text, output_token.cpu().numpy()
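
    # For illustration (hypothetical output, not from the original file): a raw
    # decoded string such as
    #     "Assistant: The image shows a red panda.###"
    # is cut at the first '###' and then reduced to the text after 'Assistant:',
    # leaving "The image shows a red panda.".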

    def upload_img(self, image):
        if isinstance(image, str):  # the input is an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)
        else:
            raise ValueError('upload_img expects an image path, a PIL.Image.Image, or a torch.Tensor.')

        image_emb, _ = self.model.encode_img(image)
        return image_emb
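
    # Accepted inputs, for illustration (the path below is a hypothetical placeholder):
    #
    #   emb = chat.upload_img('examples/cat.jpg')               # image path on disk
    #   emb = chat.upload_img(Image.open('examples/cat.jpg'))   # PIL image
    #   emb = chat.upload_img(img_tensor)                       # [3, H, W] or [1, 3, H, W] tensor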

    def get_context_emb(self, text_list, img_list):
        # Despite its name, `img_list` is used here as a single image-embedding tensor
        # of shape [1, num_img_tokens, hidden_size]; it is concatenated directly
        # between the two text segments.
        system = "Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions." + "###"
        prompt = "Human" + ": " + "<Img><ImageHere></Img> " + text_list + "###"
        prompt = system + prompt

        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add <bos> to the first segment
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # example shapes: seg_embs[0] is [1, 42, 4096], seg_embs[1] is [1, 13, 4096]
        # print(seg_embs[:-1].shape)
        # print(seg_embs[-1].shape)
        mixed_embs = torch.cat([seg_embs[0], img_list, seg_embs[1]], dim=1)
        # mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
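

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original file). It
# assumes `model` and `vis_processor` have already been built elsewhere, e.g.
# by the MiniGPT-4 demo script; both names are placeholders here.
#
#   chat = Chat(model, vis_processor, device='cuda:0')
#   img_emb = chat.upload_img('examples/cat.jpg')
#   context_emb = chat.get_context_emb('What is shown in this image?', img_emb)
#   answer_text, answer_tokens = chat.answer(context_emb, max_new_tokens=300)
#   print(answer_text)
# ---------------------------------------------------------------------------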