From 8f738f096615c54fb5f3aa979865ab5ed65a4e12 Mon Sep 17 00:00:00 2001
From: Dan <dan.nelson8@gmail.com>
Date: Thu, 4 May 2023 20:40:18 +0000
Subject: [PATCH] working replicate demo

---
 README.md  |  2 +-
 cog.yaml   | 27 +++++++++++++++++
 predict.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/README.md b/README.md
index 7aa29f2..0df8da3 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 
 **King Abdullah University of Science and Technology**
 
-<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>  <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
+<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>  <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>
 
 
 ## News
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..442f34d
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,27 @@
+build:
+  gpu: true
+  cuda: "11.3"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_version: "3.8"
+  python_packages:
+    - "torch==1.12.1"
+    - "torchvision"
+    - "transformers==4.28.1"
+    - "gradio==3.28.1"
+    - "omegaconf==2.1.2" 
+    - "iopath"
+    - "timm==0.6.13"
+    - "webdataset==0.2.48"
+    - "opencv-python==4.7.0.72"
+    - "tensorizer"
+    - "decord==0.6.0"
+    - "sentencepiece"
+
+  run: 
+    - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
+    - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
+    - "apt-get update && apt-get install -y google-cloud-cli"
+
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..8795d29
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,86 @@
+from cog import BasePredictor, Input, Path
+from minigpt4.common.config import Config
+import torch
+import argparse
+from PIL import Image
+
+from minigpt4.models import MiniGPT4
+from minigpt4.common.registry import registry
+from minigpt4.conversation.conversation import Chat, CONV_VISION
+
+import os
+
+# setting cache directory
+os.environ["TORCH_HOME"] = "/src/model_cache"
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        args = argparse.Namespace()
+        args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
+        args.gpu_id = 0
+        args.options = []
+
+        config = Config(args)
+
+        model = MiniGPT4.from_config(config.model_cfg).to("cuda")
+        vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
+        vis_processor = registry.get_processor_class(
+            vis_processor_cfg.name
+        ).from_config(vis_processor_cfg)
+        self.chat = Chat(model, vis_processor, device="cuda")
+
+    def predict(
+        self,
+        image: Path = Input(description="Image to discuss"),
+        prompt: str = Input(description="Prompt for mini-gpt4 regarding input image"),
+        num_beams: int = Input(
+            description="Number of beams for beam search decoding",
+            default=3,
+            ge=1,
+            le=10,
+        ),
+        temperature: float = Input(
+            description="Temperature for generating tokens, lower = more predictable results",
+            default=1.0,
+            ge=0.01,
+            le=2.0,
+        ),
+        top_p: float = Input(
+            description="Sample from the top p percent most likely tokens",
+            default=0.9,
+            ge=0.0,
+            le=1.0,
+        ),
+        repetition_penalty: float = Input(
+            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
+            default=1.0,
+            ge=0.01,
+            le=5,
+        ),
+        max_new_tokens: int = Input(
+            description="Maximum number of new tokens to generate", ge=1, default=3000
+        ),
+        max_length: int = Input(
+            description="Total length of prompt and output in tokens",
+            ge=1,
+            default=4000,
+        ),
+    ) -> str:
+        img_list = []
+        image = Image.open(image).convert("RGB")
+        with torch.inference_mode():
+            chat_state = CONV_VISION.copy()
+            self.chat.upload_img(image, chat_state, img_list)
+            self.chat.ask(prompt, chat_state)
+            answer = self.chat.answer(
+                conv=chat_state,
+                img_list=img_list,
+                num_beams=num_beams,
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                max_length=max_length,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+            )
+        return answer[0]