mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 18:10:47 +00:00
update merging
This commit is contained in:
commit
bf605d5d74
@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||||
for stop in self.stops:
|
for stop in self.stops:
|
||||||
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
if torch.all(input_ids[:, -len(stop):] == stop).item():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@ -158,7 +158,8 @@ class Chat:
|
|||||||
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
||||||
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
|
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
|
||||||
conv.append_message(conv.roles[1], None)
|
conv.append_message(conv.roles[1], None)
|
||||||
embs = self.get_context_emb(conv, img_list)
|
prompt = conv.get_prompt()
|
||||||
|
embs = self.model.get_context_emb(prompt, img_list)
|
||||||
|
|
||||||
current_max_len = embs.shape[1] + max_new_tokens
|
current_max_len = embs.shape[1] + max_new_tokens
|
||||||
if current_max_len - max_length > 0:
|
if current_max_len - max_length > 0:
|
||||||
@ -183,8 +184,7 @@ class Chat:
|
|||||||
|
|
||||||
def answer(self, conv, img_list, **kargs):
|
def answer(self, conv, img_list, **kargs):
|
||||||
generation_dict = self.answer_prepare(conv, img_list, **kargs)
|
generation_dict = self.answer_prepare(conv, img_list, **kargs)
|
||||||
|
output_token = self.model_generate(**generation_dict)[0]
|
||||||
output_token = self.model.llama_model.generate(**generation_dict)[0]
|
|
||||||
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
|
||||||
|
|
||||||
output_text = output_text.split('###')[0] # remove the stop sign '###'
|
output_text = output_text.split('###')[0] # remove the stop sign '###'
|
||||||
@ -197,10 +197,16 @@ class Chat:
|
|||||||
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
|
||||||
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
|
||||||
generation_kwargs['streamer'] = streamer
|
generation_kwargs['streamer'] = streamer
|
||||||
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
|
thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
|
||||||
thread.start()
|
thread.start()
|
||||||
return streamer
|
return streamer
|
||||||
|
|
||||||
|
def model_generate(self, *args, **kwargs):
|
||||||
|
# for 8 bit and 16 bit compatibility
|
||||||
|
with self.model.maybe_autocast():
|
||||||
|
output = self.model.llama_model.generate(*args, **kwargs)
|
||||||
|
return output
|
||||||
|
|
||||||
def encode_img(self, img_list):
|
def encode_img(self, img_list):
|
||||||
image = img_list[0]
|
image = img_list[0]
|
||||||
img_list.pop(0)
|
img_list.pop(0)
|
||||||
@ -225,21 +231,3 @@ class Chat:
|
|||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def get_context_emb(self, conv, img_list):
|
|
||||||
prompt = conv.get_prompt()
|
|
||||||
prompt_segs = prompt.split('<ImageHere>')
|
|
||||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
|
||||||
seg_tokens = [
|
|
||||||
self.model.llama_tokenizer(
|
|
||||||
seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
|
||||||
# only add bos to the first seg
|
|
||||||
for i, seg in enumerate(prompt_segs)
|
|
||||||
]
|
|
||||||
print('debug device: ', self.device)
|
|
||||||
print('debug model device: ', self.model.device)
|
|
||||||
seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
|
||||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
|
||||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
|
||||||
return mixed_embs
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +10,11 @@ from minigpt4.models.base_model import BaseModel
|
|||||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||||
|
|
||||||
from minigpt4.conversation.conversation import StoppingCriteriaSub
|
from minigpt4.conversation.conversation import StoppingCriteriaSub
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
|
||||||
|
|
||||||
|
>>>>>>> upstream/main
|
||||||
|
|
||||||
class MiniGPTBase(BaseModel):
|
class MiniGPTBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -314,7 +319,6 @@ class MiniGPTBase(BaseModel):
|
|||||||
embeds = self.llama_model.base_model.embed_tokens(token_ids)
|
embeds = self.llama_model.base_model.embed_tokens(token_ids)
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user