working replicate demo

This commit is contained in:
Dan 2023-05-04 20:40:18 +00:00
parent 22d8888ca2
commit 8f738f0966
3 changed files with 114 additions and 1 deletions

View File

@ -3,7 +3,7 @@
**King Abdullah University of Science and Technology**
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>
## News

27
cog.yaml Normal file
View File

@ -0,0 +1,27 @@
build:
gpu: true
cuda: "11.3"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_version: "3.8"
python_packages:
- "torch==1.12.1"
- "torchvision"
- "transformers==4.28.1"
- "gradio==3.28.1"
- "omegaconf==2.1.2"
- "iopath"
- "timm==0.6.13"
- "webdataset==0.2.48"
- "opencv-python==4.7.0.72"
- "tensorizer"
- "decord==0.6.0"
- "sentencepiece"
run:
- "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
- "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
- "apt-get update && apt-get install google-cloud-cli"
predict: "predict.py:Predictor"

86
predict.py Normal file
View File

@ -0,0 +1,86 @@
from cog import BasePredictor, Input, Path
from minigpt4.common.config import Config
import torch
import argparse
from PIL import Image
from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
import os
# setting cache directory
os.environ["TORCH_HOME"] = "/src/model_cache"
class Predictor(BasePredictor):
def setup(self):
args = argparse.Namespace()
args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
args.gpu_id = 0
args.options = []
config = Config(args)
model = MiniGPT4.from_config(config.model_cfg).to("cuda")
vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(
vis_processor_cfg.name
).from_config(vis_processor_cfg)
self.chat = Chat(model, vis_processor, device="cuda")
def predict(
self,
image: Path = Input(description="Image to discuss"),
prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
num_beams: int = Input(
description="Number of beams for beam search decoding",
default=3,
ge=1,
le=10,
),
temperature: float = Input(
description="Temperature for generating tokens, lower = more predictable results",
default=1.0,
ge=0.01,
le=2.0,
),
top_p: float = Input(
description="Sample from the top p percent most likely tokens",
default=0.9,
ge=0.0,
le=1.0,
),
repetition_penalty: float = Input(
description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
default=1.0,
ge=0.01,
le=5,
),
max_new_tokens: int = Input(
description="Maximum number of new tokens to generate", ge=1, default=3000
),
max_length: int = Input(
description="Total length of prompt and output in tokens",
ge=1,
default=4000,
),
) -> str:
img_list = []
image = Image.open(image).convert("RGB")
with torch.inference_mode():
chat_state = CONV_VISION.copy()
self.chat.upload_img(image, chat_state, img_list)
self.chat.ask(prompt, chat_state)
answer = self.chat.answer(
conv=chat_state,
img_list=img_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=max_new_tokens,
max_length=max_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
return answer[0]