MiniGPT-4/predict.py
2023-05-18 00:04:47 +00:00

87 lines
2.9 KiB
Python

from cog import BasePredictor, Input, Path
from minigpt4.common.config import Config
import torch
import argparse
from PIL import Image
from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
import os
# torch.hub resolves its cache root from the TORCH_HOME environment variable;
# point it at /src/model_cache so downloaded weights are stored there.
os.environ.update(TORCH_HOME="/src/model_cache")
class Predictor(BasePredictor):
    """Cog predictor that answers a text prompt about a single image with MiniGPT-4."""

    def setup(self):
        """Build the MiniGPT-4 model, its image processor and the chat wrapper.

        Cog calls ``setup`` once when the container starts, so the expensive
        model construction lives here rather than in ``predict``.
        """
        # Config expects an argparse-style namespace; build one by hand instead
        # of parsing the command line.
        args = argparse.Namespace()
        args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
        args.gpu_id = 0
        args.options = []
        config = Config(args)

        model = MiniGPT4.from_config(config.model_cfg).to("cuda")

        # Image preprocessing comes from the cc_sbu_align dataset's
        # vis_processor.train entry of the eval config.
        vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)

        self.chat = Chat(model, vis_processor, device="cuda")

    def predict(
        self,
        image: Path = Input(description="Image to discuss"),
        prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
        num_beams: int = Input(
            description="Number of beams for beam search decoding",
            default=3,
            ge=1,
            le=10,
        ),
        temperature: float = Input(
            description="Temperature for generating tokens, lower = more predictable results",
            default=1.0,
            ge=0.01,
            le=2.0,
        ),
        top_p: float = Input(
            description="Sample from the top p percent most likely tokens",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        repetition_penalty: float = Input(
            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
            default=1.0,
            ge=0.01,
            le=5,
        ),
        max_new_tokens: int = Input(
            description="Maximum number of new tokens to generate", ge=1, default=3000
        ),
        max_length: int = Input(
            description="Total length of prompt and output in tokens",
            ge=1,
            default=4000,
        ),
    ) -> str:
        """Run a single-turn MiniGPT-4 conversation about ``image``.

        Args:
            image: Path to the input image; it is converted to RGB before use.
            prompt: User question/instruction about the image.
            num_beams, temperature, top_p, repetition_penalty, max_new_tokens,
            max_length: Generation settings forwarded unchanged to
                ``Chat.answer``.

        Returns:
            The first element of ``Chat.answer``'s result (presumably the
            generated text, given the ``str`` return annotation).

        Raises:
            ValueError: If ``max_new_tokens >= max_length``. ``max_length`` is
                the total prompt + output budget, so such a value leaves no
                room for the prompt.
        """
        if max_new_tokens >= max_length:
            raise ValueError(
                f"max_new_tokens ({max_new_tokens}) must be smaller than "
                f"max_length ({max_length}), which also has to hold the prompt"
            )

        # Open the file in a context manager so its handle is closed.
        # convert() returns a new, fully loaded image that stays usable after
        # the source file is closed.
        with Image.open(image) as raw_image:
            rgb_image = raw_image.convert("RGB")

        img_list = []
        # inference_mode disables autograd tracking for the whole generation.
        with torch.inference_mode():
            # Start each request from a copy of the CONV_VISION template so one
            # request's turns do not leak into the next (assumes upload_img/ask
            # mutate chat_state in place — TODO confirm).
            chat_state = CONV_VISION.copy()
            self.chat.upload_img(rgb_image, chat_state, img_list)
            self.chat.ask(prompt, chat_state)
            answer = self.chat.answer(
                conv=chat_state,
                img_list=img_list,
                num_beams=num_beams,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        return answer[0]