diff --git a/demo.py b/demo.py
index b074ed2..a79b3a8 100644
--- a/demo.py
+++ b/demo.py
@@ -89,7 +89,12 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
+    llm_message = chat.answer(conv=chat_state,
+                              img_list=img_list,
+                              num_beams=num_beams,
+                              temperature=temperature,
+                              max_new_tokens=300,
+                              max_length=2000)[0]
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list
 

diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 84aea9a..7cd50bb 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -134,10 +134,19 @@ class Chat:
         else:
             conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0):
+    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
+
+        current_max_len = embs.shape[1] + max_new_tokens
+        if current_max_len - max_length > 0:
+            print('Warning: The number of tokens in current conversation exceeds the max length. '
+                  'The model will not see the contexts outside the range.')
+        begin_idx = max(0, current_max_len - max_length)
+
+        embs = embs[:, begin_idx:]
+
         outputs = self.model.llama_model.generate(
             inputs_embeds=embs,
             max_new_tokens=max_new_tokens,
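
For reference, a minimal standalone sketch of the truncation arithmetic this patch adds to Chat.answer. The tensor below is a stand-in for the real context embeddings (get_context_emb returns a [batch, seq_len, hidden] tensor; the sizes here are made up for illustration), so this demonstrates only the index math, not the actual model call:

    import torch

    max_new_tokens, max_length = 300, 2000     # the new defaults from the patch
    embs = torch.randn(1, 2100, 4096)          # pretend conversation context, 2100 tokens long

    # prompt length plus the tokens we plan to generate
    current_max_len = embs.shape[1] + max_new_tokens   # 2100 + 300 = 2400
    # drop just enough of the oldest context so prompt + output fits in max_length
    begin_idx = max(0, current_max_len - max_length)   # 400
    embs = embs[:, begin_idx:]                         # keep the most recent 1700 tokens

    assert embs.shape[1] + max_new_tokens <= max_length

Truncating from the left (begin_idx onward) keeps the most recent turns of the conversation, which is why the warning says the model will not see the contexts outside the range.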