mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00
Add length control and change the default conversation hyperparameters to avoid out-of-memory (OOM) errors on a 3090.
This commit is contained in:
parent
3f59092f0c
commit
dadc0d7e69
7
demo.py
7
demo.py
@ -89,7 +89,12 @@ def gradio_ask(user_message, chatbot, chat_state):
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
|
||||
llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
|
||||
llm_message = chat.answer(conv=chat_state,
|
||||
img_list=img_list,
|
||||
num_beams=num_beams,
|
||||
temperature=temperature,
|
||||
max_new_tokens=300,
|
||||
max_length=2000)[0]
|
||||
chatbot[-1][1] = llm_message
|
||||
return chatbot, chat_state, img_list
|
||||
|
||||
|
@ -134,10 +134,19 @@ class Chat:
|
||||
else:
|
||||
conv.append_message(conv.roles[0], text)
|
||||
|
||||
def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
|
||||
repetition_penalty=1.0, length_penalty=1, temperature=1.0):
|
||||
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
||||
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
|
||||
conv.append_message(conv.roles[1], None)
|
||||
embs = self.get_context_emb(conv, img_list)
|
||||
|
||||
current_max_len = embs.shape[1] + max_new_tokens
|
||||
if current_max_len - max_length > 0:
|
||||
print('Warning: The number of tokens in current conversation exceeds the max length. '
|
||||
'The model will not see the contexts outside the range.')
|
||||
begin_idx = max(0, current_max_len - max_length)
|
||||
|
||||
embs = embs[:, begin_idx:]
|
||||
|
||||
outputs = self.model.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user