mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 02:20:47 +00:00
working replicate demo
This commit is contained in:
parent
22d8888ca2
commit
8f738f0966
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
**King Abdullah University of Science and Technology**
|
**King Abdullah University of Science and Technology**
|
||||||
|
|
||||||
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
|
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>
|
||||||
|
|
||||||
|
|
||||||
## News
|
## News
|
||||||
|
27
cog.yaml
Normal file
27
cog.yaml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
build:
|
||||||
|
gpu: true
|
||||||
|
cuda: "11.3"
|
||||||
|
system_packages:
|
||||||
|
- "libgl1-mesa-glx"
|
||||||
|
- "libglib2.0-0"
|
||||||
|
python_version: "3.8"
|
||||||
|
python_packages:
|
||||||
|
- "torch==1.12.1"
|
||||||
|
- "torchvision"
|
||||||
|
- "transformers==4.28.1"
|
||||||
|
- "gradio==3.28.1"
|
||||||
|
- "omegaconf==2.1.2"
|
||||||
|
- "iopath"
|
||||||
|
- "timm==0.6.13"
|
||||||
|
- "webdataset==0.2.48"
|
||||||
|
- "opencv-python==4.7.0.72"
|
||||||
|
- "tensorizer"
|
||||||
|
- "decord==0.6.0"
|
||||||
|
- "sentencepiece"
|
||||||
|
|
||||||
|
run:
|
||||||
|
- "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
|
||||||
|
- "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
|
||||||
|
- "apt-get update && apt-get install google-cloud-cli"
|
||||||
|
|
||||||
|
predict: "predict.py:Predictor"
|
86
predict.py
Normal file
86
predict.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from cog import BasePredictor, Input, Path
|
||||||
|
from minigpt4.common.config import Config
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from minigpt4.models import MiniGPT4
|
||||||
|
from minigpt4.common.registry import registry
|
||||||
|
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# setting cache directory
|
||||||
|
os.environ["TORCH_HOME"] = "/src/model_cache"
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor(BasePredictor):
|
||||||
|
def setup(self):
|
||||||
|
args = argparse.Namespace()
|
||||||
|
args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
|
||||||
|
args.gpu_id = 0
|
||||||
|
args.options = []
|
||||||
|
|
||||||
|
config = Config(args)
|
||||||
|
|
||||||
|
model = MiniGPT4.from_config(config.model_cfg).to("cuda")
|
||||||
|
vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
|
||||||
|
vis_processor = registry.get_processor_class(
|
||||||
|
vis_processor_cfg.name
|
||||||
|
).from_config(vis_processor_cfg)
|
||||||
|
self.chat = Chat(model, vis_processor, device="cuda")
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
image: Path = Input(description="Image to discuss"),
|
||||||
|
prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
|
||||||
|
num_beams: int = Input(
|
||||||
|
description="Number of beams for beam search decoding",
|
||||||
|
default=3,
|
||||||
|
ge=1,
|
||||||
|
le=10,
|
||||||
|
),
|
||||||
|
temperature: float = Input(
|
||||||
|
description="Temperature for generating tokens, lower = more predictable results",
|
||||||
|
default=1.0,
|
||||||
|
ge=0.01,
|
||||||
|
le=2.0,
|
||||||
|
),
|
||||||
|
top_p: float = Input(
|
||||||
|
description="Sample from the top p percent most likely tokens",
|
||||||
|
default=0.9,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
),
|
||||||
|
repetition_penalty: float = Input(
|
||||||
|
description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
|
||||||
|
default=1.0,
|
||||||
|
ge=0.01,
|
||||||
|
le=5,
|
||||||
|
),
|
||||||
|
max_new_tokens: int = Input(
|
||||||
|
description="Maximum number of new tokens to generate", ge=1, default=3000
|
||||||
|
),
|
||||||
|
max_length: int = Input(
|
||||||
|
description="Total length of prompt and output in tokens",
|
||||||
|
ge=1,
|
||||||
|
default=4000,
|
||||||
|
),
|
||||||
|
) -> str:
|
||||||
|
img_list = []
|
||||||
|
image = Image.open(image).convert("RGB")
|
||||||
|
with torch.inference_mode():
|
||||||
|
chat_state = CONV_VISION.copy()
|
||||||
|
self.chat.upload_img(image, chat_state, img_list)
|
||||||
|
self.chat.ask(prompt, chat_state)
|
||||||
|
answer = self.chat.answer(
|
||||||
|
conv=chat_state,
|
||||||
|
img_list=img_list,
|
||||||
|
num_beams=num_beams,
|
||||||
|
temperature=temperature,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
max_length=max_length,
|
||||||
|
top_p=top_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
)
|
||||||
|
return answer[0]
|
Loading…
Reference in New Issue
Block a user