diff --git a/demo.py b/demo.py index 8ceea3f..5445bfb 100644 --- a/demo.py +++ b/demo.py @@ -113,8 +113,8 @@ with gr.Blocks() as demo: num_beams = gr.Slider( minimum=1, - maximum=16, - value=5, + maximum=10, + value=1, step=1, interactive=True, label="beam search numbers)", diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 1a1221e..84aea9a 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -134,7 +134,7 @@ class Chat: else: conv.append_message(conv.roles[0], text) - def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9, + def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) @@ -142,7 +142,7 @@ class Chat: inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, - #num_beams=num_beams, + num_beams=num_beams, do_sample=True, min_length=min_length, top_p=top_p,