diff --git a/README.md b/README.md index 9fca286..3fc9ba6 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny *equal contribution - [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) + [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) *King Abdullah University of Science and Technology* diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..442f34d --- /dev/null +++ b/cog.yaml @@ -0,0 +1,27 @@ +build: + gpu: true + cuda: "11.3" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_version: "3.8" + python_packages: + - "torch==1.12.1" + - "torchvision" + - "transformers==4.28.1" + - "gradio==3.28.1" + - "omegaconf==2.1.2" + - "iopath" + - "timm==0.6.13" + - "webdataset==0.2.48" + - "opencv-python==4.7.0.72" + - "tensorizer" + - "decord==0.6.0" + - "sentencepiece" + + run: + - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list" + - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -" + - "apt-get update && apt-get install google-cloud-cli" + +predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..8795d29 --- /dev/null +++ b/predict.py @@ -0,0 +1,86 @@ +from cog import BasePredictor, Input, Path +from minigpt4.common.config import Config +import torch +import argparse +from PIL import Image + 
from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

import os

# Keep downloaded torch weights inside the image so setup() is warm-start fast.
os.environ["TORCH_HOME"] = "/src/model_cache"


class Predictor(BasePredictor):
    """Cog predictor that answers free-form questions about an input image
    using MiniGPT-4 (vision encoder + LLM chat loop)."""

    def setup(self):
        """Build the MiniGPT-4 model and its image processor once, on the GPU.

        Config() expects an argparse-style namespace, so one is constructed
        by hand with the eval config shipped in the image.
        """
        cli_args = argparse.Namespace()
        cli_args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
        cli_args.gpu_id = 0
        cli_args.options = []

        cfg = Config(cli_args)

        model = MiniGPT4.from_config(cfg.model_cfg).to("cuda")
        # The cc_sbu_align train-time vis processor is what eval uses here.
        proc_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
        processor_cls = registry.get_processor_class(proc_cfg.name)
        vis_processor = processor_cls.from_config(proc_cfg)
        self.chat = Chat(model, vis_processor, device="cuda")

    def predict(
        self,
        image: Path = Input(description="Image to discuss"),
        prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
        num_beams: int = Input(
            description="Number of beams for beam search decoding",
            default=3,
            ge=1,
            le=10,
        ),
        temperature: float = Input(
            description="Temperature for generating tokens, lower = more predictable results",
            default=1.0,
            ge=0.01,
            le=2.0,
        ),
        top_p: float = Input(
            description="Sample from the top p percent most likely tokens",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        repetition_penalty: float = Input(
            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
            default=1.0,
            ge=0.01,
            le=5,
        ),
        max_new_tokens: int = Input(
            description="Maximum number of new tokens to generate", ge=1, default=3000
        ),
        max_length: int = Input(
            description="Total length of prompt and output in tokens",
            ge=1,
            default=4000,
        ),
    ) -> str:
        """Run one visual Q&A turn and return the model's text answer.

        A fresh conversation state is created per call, the image is uploaded
        into it, the prompt is asked, and the first element of Chat.answer()'s
        return tuple (the answer string) is returned.
        """
        uploaded_imgs = []
        pil_image = Image.open(image).convert("RGB")
        with torch.inference_mode():
            conv = CONV_VISION.copy()
            self.chat.upload_img(pil_image, conv, uploaded_imgs)
            self.chat.ask(prompt, conv)
            reply = self.chat.answer(
                conv=conv,
                img_list=uploaded_imgs,
                num_beams=num_beams,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        return reply[0]