diff --git a/demo_v2.py b/demo_v2.py index d35bddc..62c65f4 100644 --- a/demo_v2.py +++ b/demo_v2.py @@ -549,7 +549,7 @@ with gr.Blocks() as demo: temperature = gr.Slider( minimum=0.1, - maximum=2.0, + maximum=1.5, value=1.0, step=0.1, interactive=True, diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 9c27c78..6e4a5c1 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -170,7 +170,7 @@ class Chat: top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, - temperature=temperature, + temperature=float(temperature), ) return generation_kwargs diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py index 77c919a..c24c579 100644 --- a/minigpt4/models/minigpt_base.py +++ b/minigpt4/models/minigpt_base.py @@ -7,6 +7,7 @@ import torch.nn as nn from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel +from transformers import StoppingCriteria, StoppingCriteriaList @@ -365,8 +366,8 @@ class MiniGPTBase(BaseModel): do_sample=do_sample, min_length=min_length, top_p=top_p, - repetition_penalty=repetition_penalty - # stopping_criteria=stopping_criteria, + repetition_penalty=repetition_penalty, + stopping_criteria=stopping_criteria, ) answers = []