adding length control and change the default hyperparameter of the conversation to avoid OOM in 3090.

This commit is contained in:
Deyao Zhu 2023-04-18 22:04:50 +03:00
parent 3f59092f0c
commit dadc0d7e69
2 changed files with 17 additions and 3 deletions

View File

@ -89,7 +89,12 @@ def gradio_ask(user_message, chatbot, chat_state):
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
llm_message = chat.answer(conv=chat_state,
img_list=img_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=300,
max_length=2000)[0]
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list

View File

@ -134,10 +134,19 @@ class Chat:
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0):
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
current_max_len = embs.shape[1] + max_new_tokens
if current_max_len - max_length > 0:
print('Warning: The number of tokens in current conversation exceeds the max length. '
'The model will not see the contexts outside the range.')
begin_idx = max(0, current_max_len - max_length)
embs = embs[:, begin_idx:]
outputs = self.model.llama_model.generate(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,