Merge pull request #404 from TsuTikgiau/main

Fix the compatibility issue when `low_resource` is False during inference
This commit is contained in:
Jun Chen 2023-10-30 13:26:42 -07:00 committed by GitHub
commit 751069f1c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 23 deletions

View File

@ -30,4 +30,6 @@ dependencies:
- gradio==3.47.1
- accelerate==0.20.3
- bitsandbytes==0.37.0
- scikit-image
- visual-genome
- wandb

View File

@ -151,7 +151,8 @@ class Chat:
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
prompt = conv.get_prompt()
embs = self.model.get_context_emb(prompt, img_list)
current_max_len = embs.shape[1] + max_new_tokens
if current_max_len - max_length > 0:
@ -176,8 +177,7 @@ class Chat:
def answer(self, conv, img_list, **kargs):
generation_dict = self.answer_prepare(conv, img_list, **kargs)
output_token = self.model.llama_model.generate(**generation_dict)[0]
output_token = self.model_generate(**generation_dict)[0]
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_text = output_text.split('###')[0] # remove the stop sign '###'
@ -190,10 +190,16 @@ class Chat:
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
generation_kwargs['streamer'] = streamer
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
thread.start()
return streamer
def model_generate(self, *args, **kwargs):
    """Proxy to ``llama_model.generate`` wrapped in the model's autocast.

    Running generation inside ``self.model.maybe_autocast()`` keeps the
    call compatible with both 8-bit and 16-bit model loading modes.
    All positional and keyword arguments are forwarded unchanged.
    """
    with self.model.maybe_autocast():
        return self.model.llama_model.generate(*args, **kwargs)
def encode_img(self, img_list):
image = img_list[0]
img_list.pop(0)
@ -218,21 +224,3 @@ class Chat:
return msg
def get_context_emb(self, conv, img_list):
    """Build the mixed text/image embedding sequence for generation.

    The conversation prompt is split on the ``<ImageHere>`` placeholder;
    each text segment is tokenized and embedded, then interleaved with the
    image embeddings from ``img_list`` and concatenated along the sequence
    dimension (dim 1).

    Args:
        conv: Conversation object providing ``get_prompt()``.
        img_list: Image embedding tensors, one per ``<ImageHere>``
            placeholder. Assumed to share batch and embedding dims with
            the text embeddings — TODO confirm against caller.

    Returns:
        A tensor of shape (1, total_seq_len, emb_dim).

    Raises:
        AssertionError: If the number of placeholders does not match
            ``len(img_list)``.
    """
    prompt = conv.get_prompt()
    prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    seg_tokens = [
        self.model.llama_tokenizer(
            seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
        # only add bos to the first seg
        for i, seg in enumerate(prompt_segs)
    ]
    # Fix: removed leftover debug print() calls that spammed stdout on
    # every generation step.
    seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
    # Interleave: seg0, img0, seg1, img1, ..., segN-1, imgN-1, segN.
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs

View File

@ -9,6 +9,8 @@ from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList
from minigpt4.conversation.conversation import StoppingCriteriaSub
class MiniGPTBase(BaseModel):
@ -314,7 +316,6 @@ class MiniGPTBase(BaseModel):
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
@torch.no_grad()
def generate(
self,