Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git, synced 2025-04-06 02:50:47 +00:00
Adding length control and changing the default conversation hyperparameters to avoid OOM on a 3090.
Parent: 3f59092f0c
Commit: dadc0d7e69
demo.py — 7 changed lines
@@ -89,7 +89,12 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
+    llm_message = chat.answer(conv=chat_state,
+                              img_list=img_list,
+                              num_beams=num_beams,
+                              temperature=temperature,
+                              max_new_tokens=300,
+                              max_length=2000)[0]
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list
 
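Together with the Chat.answer change in the next hunk, these defaults mean generation adds at most 300 new tokens inside a 2,000-position window, leaving roughly 1,700 positions for the prompt and conversation history before the oldest context gets dropped. A tiny sketch of that budget (the helper name is made up purely for illustration):

def history_budget(max_length: int = 2000, max_new_tokens: int = 300) -> int:
    # Positions left for prompt/history before Chat.answer starts
    # dropping the oldest context (see the truncation added below).
    return max_length - max_new_tokens

print(history_budget())  # 1700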
@@ -134,10 +134,19 @@ class Chat:
         else:
             conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0):
+    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
+
+        current_max_len = embs.shape[1] + max_new_tokens
+        if current_max_len - max_length > 0:
+            print('Warning: The number of tokens in current conversation exceeds the max length. '
+                  'The model will not see the contexts outside the range.')
+        begin_idx = max(0, current_max_len - max_length)
+
+        embs = embs[:, begin_idx:]
+
         outputs = self.model.llama_model.generate(
             inputs_embeds=embs,
             max_new_tokens=max_new_tokens,
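The second hunk gives Chat.answer a max_length cap and a simple sliding window over the prompt embeddings: if the current prompt plus the requested max_new_tokens would exceed max_length, the oldest positions are sliced off so the model only sees the most recent context. A minimal, self-contained sketch of the same rule on a dummy tensor (torch, the shapes, and the function name are assumptions for illustration, not part of the commit):

import torch

def truncate_context(embs: torch.Tensor, max_new_tokens: int, max_length: int) -> torch.Tensor:
    # embs: (batch, seq_len, hidden) prompt embeddings, oldest tokens first.
    current_max_len = embs.shape[1] + max_new_tokens
    if current_max_len > max_length:
        print('Warning: the conversation exceeds max_length; '
              'the oldest context will be dropped.')
    begin_idx = max(0, current_max_len - max_length)
    # Keep only the most recent positions so that
    # kept_prompt + max_new_tokens <= max_length.
    return embs[:, begin_idx:]

# Example: a 2,200-token prompt with the new defaults (300 new tokens, 2,000 total).
embs = torch.zeros(1, 2200, 4096)  # hidden size 4096 is only an illustrative value
print(truncate_context(embs, max_new_tokens=300, max_length=2000).shape[1])  # 1700 = 2000 - 300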