This commit is contained in:
Dan Nelson 2024-11-19 23:11:59 +00:00 committed by GitHub
commit ec067f0daf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 114 additions and 1 deletions

View File

@@ -15,7 +15,7 @@ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
*equal contribution
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>
*King Abdullah University of Science and Technology*

27
cog.yaml Normal file
View File

@@ -0,0 +1,27 @@
# Cog configuration: describes the container image to build and the predictor
# entry point to serve.
build:
  gpu: true
  cuda: "11.3"
  # NOTE(review): these look like the shared libraries opencv-python needs at
  # import time — confirm before removing.
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_version: "3.8"
  python_packages:
    - "torch==1.12.1"
    - "torchvision"
    - "transformers==4.28.1"
    - "gradio==3.28.1"
    - "omegaconf==2.1.2"
    - "iopath"
    - "timm==0.6.13"
    - "webdataset==0.2.48"
    - "opencv-python==4.7.0.72"
    - "tensorizer"
    - "decord==0.6.0"
    - "sentencepiece"
  run:
    # Register Google's apt repository and its signing key, then install the
    # Google Cloud CLI from it.
    - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
    - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
    # -y: image builds are non-interactive, so apt-get must not wait for a
    # confirmation prompt that nobody can answer.
    - "apt-get update && apt-get install -y google-cloud-cli"
predict: "predict.py:Predictor"

86
predict.py Normal file
View File

@@ -0,0 +1,86 @@
from cog import BasePredictor, Input, Path
from minigpt4.common.config import Config
import torch
import argparse
from PIL import Image
from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
import os
# Set the torch cache directory (TORCH_HOME) to a fixed path under /src.
# NOTE(review): presumably so weights fetched through torch are cached at a
# known location inside the image — confirm against the build/deploy setup.
os.environ["TORCH_HOME"] = "/src/model_cache"
class Predictor(BasePredictor):
    """Cog predictor that answers a text prompt about an image using MiniGPT-4."""

    def setup(self) -> None:
        """Build the model and image processor and store a ``Chat`` on ``self``.

        The configuration is read from ``/src/eval_configs/minigpt4_eval.yaml``
        and the model is moved to the ``"cuda"`` device.
        """
        # Config takes an argparse-style namespace; these are the only three
        # attributes this predictor supplies.
        args = argparse.Namespace(
            cfg_path="/src/eval_configs/minigpt4_eval.yaml",
            gpu_id=0,
            options=[],
        )
        config = Config(args)

        model = MiniGPT4.from_config(config.model_cfg).to("cuda")

        # The image processor is built from the ``train`` processor entry of
        # the cc_sbu_align dataset config.
        vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
        processor_cls = registry.get_processor_class(vis_processor_cfg.name)
        vis_processor = processor_cls.from_config(vis_processor_cfg)

        self.chat = Chat(model, vis_processor, device="cuda")

    def predict(
        self,
        image: Path = Input(description="Image to discuss"),
        prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
        num_beams: int = Input(
            description="Number of beams for beam search decoding",
            default=3,
            ge=1,
            le=10,
        ),
        temperature: float = Input(
            description="Temperature for generating tokens, lower = more predictable results",
            default=1.0,
            ge=0.01,
            le=2.0,
        ),
        top_p: float = Input(
            description="Sample from the top p percent most likely tokens",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        repetition_penalty: float = Input(
            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
            default=1.0,
            ge=0.01,
            le=5,
        ),
        max_new_tokens: int = Input(
            description="Maximum number of new tokens to generate", ge=1, default=3000
        ),
        max_length: int = Input(
            description="Total length of prompt and output in tokens",
            ge=1,
            default=4000,
        ),
    ) -> str:
        """Run one single-turn conversation about ``image`` and return the reply.

        A fresh copy of ``CONV_VISION`` is used for every call, so no
        conversation state carries over between predictions. The image is
        uploaded to the chat, ``prompt`` is asked, and the generation
        parameters are forwarded unchanged to ``Chat.answer``.

        Returns:
            The first element of the value returned by ``Chat.answer``
            (presumably the generated text; verify against ``Chat.answer``).
        """
        # Close the underlying file as soon as the pixels have been copied into
        # the RGB image, instead of leaving the handle to garbage collection in
        # a long-lived prediction server.
        with Image.open(image) as raw_image:
            rgb_image = raw_image.convert("RGB")

        img_list = []
        with torch.inference_mode():
            chat_state = CONV_VISION.copy()
            self.chat.upload_img(rgb_image, chat_state, img_list)
            self.chat.ask(prompt, chat_state)
            answer = self.chat.answer(
                conv=chat_state,
                img_list=img_list,
                num_beams=num_beams,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        return answer[0]