diff --git a/README.md b/README.md
index 9fca286..3fc9ba6 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
*equal contribution
-
[](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
+
[](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
*King Abdullah University of Science and Technology*
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..442f34d
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,27 @@
+build:
+ gpu: true
+ cuda: "11.3"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_version: "3.8"
+ python_packages:
+ - "torch==1.12.1"
+ - "torchvision"
+ - "transformers==4.28.1"
+ - "gradio==3.28.1"
+ - "omegaconf==2.1.2"
+ - "iopath"
+ - "timm==0.6.13"
+ - "webdataset==0.2.48"
+ - "opencv-python==4.7.0.72"
+ - "tensorizer"
+ - "decord==0.6.0"
+ - "sentencepiece"
+
+ run:
+ - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
+ - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
+ - "apt-get update && apt-get install google-cloud-cli"
+
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..8795d29
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,86 @@
+from cog import BasePredictor, Input, Path
+from minigpt4.common.config import Config
+import torch
+import argparse
+from PIL import Image
+
+from minigpt4.models import MiniGPT4
+from minigpt4.common.registry import registry
+from minigpt4.conversation.conversation import Chat, CONV_VISION
+
+import os
+
+# setting cache directory
+os.environ["TORCH_HOME"] = "/src/model_cache"
+
+
+class Predictor(BasePredictor):
+ def setup(self):
+ args = argparse.Namespace()
+ args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
+ args.gpu_id = 0
+ args.options = []
+
+ config = Config(args)
+
+ model = MiniGPT4.from_config(config.model_cfg).to("cuda")
+ vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
+ vis_processor = registry.get_processor_class(
+ vis_processor_cfg.name
+ ).from_config(vis_processor_cfg)
+ self.chat = Chat(model, vis_processor, device="cuda")
+
+ def predict(
+ self,
+ image: Path = Input(description="Image to discuss"),
+ prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
+ num_beams: int = Input(
+ description="Number of beams for beam search decoding",
+ default=3,
+ ge=1,
+ le=10,
+ ),
+ temperature: float = Input(
+ description="Temperature for generating tokens, lower = more predictable results",
+ default=1.0,
+ ge=0.01,
+ le=2.0,
+ ),
+ top_p: float = Input(
+ description="Sample from the top p percent most likely tokens",
+ default=0.9,
+ ge=0.0,
+ le=1.0,
+ ),
+ repetition_penalty: float = Input(
+ description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
+ default=1.0,
+ ge=0.01,
+ le=5,
+ ),
+ max_new_tokens: int = Input(
+ description="Maximum number of new tokens to generate", ge=1, default=3000
+ ),
+ max_length: int = Input(
+ description="Total length of prompt and output in tokens",
+ ge=1,
+ default=4000,
+ ),
+ ) -> str:
+ img_list = []
+ image = Image.open(image).convert("RGB")
+ with torch.inference_mode():
+ chat_state = CONV_VISION.copy()
+ self.chat.upload_img(image, chat_state, img_list)
+ self.chat.ask(prompt, chat_state)
+ answer = self.chat.answer(
+ conv=chat_state,
+ img_list=img_list,
+ num_beams=num_beams,
+ temperature=temperature,
+ max_new_tokens=max_new_tokens,
+ max_length=max_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ )
+ return answer[0]