mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 01:50:47 +00:00
fix the compatibility issue when low-resource is false in the inference mode. Small refator in the conversation.py to remove repeated codes
This commit is contained in:
parent
fe1d79c66a
commit
ada58a6f21
@ -30,4 +30,6 @@ dependencies:
|
||||
- gradio==3.47.1
|
||||
- accelerate==0.20.3
|
||||
- bitsandbytes==0.37.0
|
||||
- scikit-image
|
||||
- visual-genome
|
||||
- wandb
|
||||
|
@ -151,7 +151,8 @@ class Chat:
|
||||
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
||||
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
|
||||
conv.append_message(conv.roles[1], None)
|
||||
embs = self.get_context_emb(conv, img_list)
|
||||
prompt = conv.get_prompt()
|
||||
embs = self.model.get_context_emb(prompt, img_list)
|
||||
|
||||
current_max_len = embs.shape[1] + max_new_tokens
|
||||
if current_max_len - max_length > 0:
|
||||
@ -176,8 +177,7 @@ class Chat:
|
||||
|
||||
def answer(self, conv, img_list, **kargs):
|
||||
generation_dict = self.answer_prepare(conv, img_list, **kargs)
|
||||
|
||||
output_token = self.model.llama_model.generate(**generation_dict)[0]
|
||||
output_token = self.model_generate(**generation_dict)[0]
|
||||
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||
|
||||
output_text = output_text.split('###')[0] # remove the stop sign '###'
|
||||
@ -190,10 +190,16 @@ class Chat:
|
||||
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
||||
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
|
||||
generation_kwargs['streamer'] = streamer
|
||||
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
|
||||
thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
return streamer
|
||||
|
||||
def model_generate(self, *args, **kwargs):
|
||||
# for 8 bit and 16 bit compatibility
|
||||
with self.model.maybe_autocast():
|
||||
output = self.model.llama_model.generate(*args, **kwargs)
|
||||
return output
|
||||
|
||||
def encode_img(self, img_list):
|
||||
image = img_list[0]
|
||||
img_list.pop(0)
|
||||
@ -218,21 +224,3 @@ class Chat:
|
||||
|
||||
return msg
|
||||
|
||||
def get_context_emb(self, conv, img_list):
|
||||
prompt = conv.get_prompt()
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||
seg_tokens = [
|
||||
self.model.llama_tokenizer(
|
||||
seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
||||
# only add bos to the first seg
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
print('debug device: ', self.device)
|
||||
print('debug model device: ', self.model.device)
|
||||
seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
return mixed_embs
|
||||
|
||||
|
||||
|
@ -9,6 +9,8 @@ from minigpt4.common.registry import registry
|
||||
from minigpt4.models.base_model import BaseModel
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
from minigpt4.conversation.conversation import StoppingCriteriaSub
|
||||
|
||||
|
||||
|
||||
class MiniGPTBase(BaseModel):
|
||||
@ -314,7 +316,6 @@ class MiniGPTBase(BaseModel):
|
||||
embeds = self.llama_model.base_model.embed_tokens(token_ids)
|
||||
return embeds
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user