Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 10:30:45 +00:00)

Merge 8f738f0966 into d94738a762
commit ec067f0daf
README.md
@@ -15,7 +15,7 @@ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny

 *equal contribution

-<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
+<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>

 *King Abdullah University of Science and Technology*
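The added badge links to the hosted Replicate demo of this model. As a minimal sketch (not part of this commit), the same deployment can be queried from Python with the replicate client; the version hash and image path below are placeholders:

import replicate

# daanelson/minigpt-4 is the deployment the badge points at; the version hash
# after ":" is a placeholder -- copy the current one from the model page.
output = replicate.run(
    "daanelson/minigpt-4:<version-hash>",
    input={
        "image": open("example.jpg", "rb"),  # placeholder image path
        "prompt": "What is unusual about this image?",
    },
)
print(output)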
cog.yaml (new file, 27 lines)
@@ -0,0 +1,27 @@
build:
  gpu: true
  cuda: "11.3"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_version: "3.8"
  python_packages:
    - "torch==1.12.1"
    - "torchvision"
    - "transformers==4.28.1"
    - "gradio==3.28.1"
    - "omegaconf==2.1.2"
    - "iopath"
    - "timm==0.6.13"
    - "webdataset==0.2.48"
    - "opencv-python==4.7.0.72"
    - "tensorizer"
    - "decord==0.6.0"
    - "sentencepiece"

  run:
    - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
    - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
    - "apt-get update && apt-get install google-cloud-cli"

predict: "predict.py:Predictor"
predict.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from cog import BasePredictor, Input, Path
from minigpt4.common.config import Config
import torch
import argparse
from PIL import Image

from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

import os

# setting cache directory (keeps downloaded weights inside the image at /src/model_cache)
os.environ["TORCH_HOME"] = "/src/model_cache"


class Predictor(BasePredictor):
    def setup(self):
        # Mimic the CLI arguments that MiniGPT-4's Config helper expects.
        args = argparse.Namespace()
        args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
        args.gpu_id = 0
        args.options = []

        config = Config(args)

        # Load the model and its matching image preprocessor once at startup.
        model = MiniGPT4.from_config(config.model_cfg).to("cuda")
        vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)
        self.chat = Chat(model, vis_processor, device="cuda")

    def predict(
        self,
        image: Path = Input(description="Image to discuss"),
        prompt: str = Input(description="Prompt for MiniGPT-4 regarding the input image"),
        num_beams: int = Input(
            description="Number of beams for beam search decoding",
            default=3,
            ge=1,
            le=10,
        ),
        temperature: float = Input(
            description="Temperature for generating tokens; lower = more predictable results",
            default=1.0,
            ge=0.01,
            le=2.0,
        ),
        top_p: float = Input(
            description="Sample from the smallest set of tokens whose cumulative probability exceeds p",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        repetition_penalty: float = Input(
            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it",
            default=1.0,
            ge=0.01,
            le=5,
        ),
        max_new_tokens: int = Input(
            description="Maximum number of new tokens to generate", ge=1, default=3000
        ),
        max_length: int = Input(
            description="Total length of prompt and output in tokens",
            ge=1,
            default=4000,
        ),
    ) -> str:
        img_list = []
        image = Image.open(image).convert("RGB")
        with torch.inference_mode():
            # Each request gets a fresh conversation state seeded with the image.
            chat_state = CONV_VISION.copy()
            self.chat.upload_img(image, chat_state, img_list)
            self.chat.ask(prompt, chat_state)
            answer = self.chat.answer(
                conv=chat_state,
                img_list=img_list,
                num_beams=num_beams,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        # Chat.answer returns (text, output_tokens); return just the text.
        return answer[0]
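For local debugging, the Predictor can also be driven directly, bypassing the Cog HTTP server. A minimal sketch, assuming a CUDA GPU and that the config and checkpoint paths hardcoded in setup() exist; note that the cog Input defaults do not apply on a direct call, so every argument is passed explicitly:

from predict import Predictor

predictor = Predictor()
predictor.setup()  # one-off model load, as Cog would do at container start
caption = predictor.predict(
    image="example.jpg",  # placeholder image path
    prompt="Describe this image in detail.",
    num_beams=3,
    temperature=1.0,
    top_p=0.9,
    repetition_penalty=1.0,
    max_new_tokens=300,
    max_length=2000,
)
print(caption)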