From d2db151df2aa4f54d173969d2d690457e4b489e8 Mon Sep 17 00:00:00 2001
From: unknown <913556700@qq.com>
Date: Thu, 18 May 2023 13:52:05 +0800
Subject: [PATCH] Initial version of BindGPT4 support (may still have bugs)

---
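Notes:

The ModalityType values in imagebind/models/image_bind.py are lower-cased here
("Vision" -> "vision", etc.), so BindGPT4.prompt_wrap() now title-cases the
modality name when splitting the prompt on its '<{Modality}Here>' placeholder
(e.g. '<VisionHere>'), and the forward pass checks for the lower-case "vision"
key. The patch also registers three ImageBind processors ("imagebind_caption",
"imagebind_vision_train", "imagebind_vision_eval") whose configs mirror the
BLIP2 / MiniGPT4 preprocessing.

A minimal usage sketch for the new processors, for illustration only:
"example.jpg" is a hypothetical input file, and it assumes the registry
exposes get_processor_class() in the usual LAVIS style:

    from PIL import Image

    import minigpt4.processors  # noqa: F401  (importing registers the processors)
    from minigpt4.common.registry import registry

    pil_image = Image.open("example.jpg").convert("RGB")  # hypothetical file

    # Vision: resize / center-crop to 224, convert to tensor, CLIP-normalize.
    vis_cls = registry.get_processor_class("imagebind_vision_eval")
    vis_processor = vis_cls.from_config()          # defaults: image_size=224
    image_tensor = vis_processor(pil_image)        # -> (3, 224, 224) float tensor

    # Text: lower-case, strip punctuation, truncate to max_words (default 50).
    cap_cls = registry.get_processor_class("imagebind_caption")
    text_processor = cap_cls.from_config()         # defaults: prompt="", max_words=50
    caption = text_processor("A Dog! Running...")  # -> "a dog running"
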
 eval_configs/bindgpt4_eval.yaml            |   1 +
 imagebind/models/image_bind.py             |  12 +-
 minigpt4/configs/models/bindgpt4.yaml      |  29 ++++
 minigpt4/models/__init__.py                |   2 +
 minigpt4/models/bind_gpt4.py               |   4 +-
 minigpt4/processors/__init__.py            |  22 ++-
 minigpt4/processors/imagebind_processor.py | 148 +++++++++++++++++++++
 7 files changed, 203 insertions(+), 15 deletions(-)
 create mode 100644 minigpt4/processors/imagebind_processor.py

diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml
index e69de29..3a28b04 100644
--- a/eval_configs/bindgpt4_eval.yaml
+++ b/eval_configs/bindgpt4_eval.yaml
@@ -0,0 +1 @@
+# TODO: Finish the eval config of ImageBindGPT4
\ No newline at end of file
diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py
index 83cd641..6c6431c 100644
--- a/imagebind/models/image_bind.py
+++ b/imagebind/models/image_bind.py
@@ -38,12 +38,12 @@ from imagebind.models.multimodal_projectors import create_projectors, create_pre
 from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
 
 ModalityType = SimpleNamespace(
-    VISION="Vision",
-    TEXT="Text",
-    AUDIO="Audio",
-    THERMAL="Thermal",
-    DEPTH="Depth",
-    IMU="Imu",
+    VISION="vision",
+    TEXT="text",
+    AUDIO="audio",
+    THERMAL="thermal",
+    DEPTH="depth",
+    IMU="imu",
 )
 
 
diff --git a/minigpt4/configs/models/bindgpt4.yaml b/minigpt4/configs/models/bindgpt4.yaml
index e69de29..3d436ec 100644
--- a/minigpt4/configs/models/bindgpt4.yaml
+++ b/minigpt4/configs/models/bindgpt4.yaml
@@ -0,0 +1,29 @@
+model:
+  arch: bind_gpt4
+
+  # Imagebind
+  freeze_imagebind: True
+
+  # Q-Former
+  freeze_qformer: True
+  num_query_token: 32
+
+  # Vicuna
+  llama_model: "/path/to/vicuna/weights/"
+
+  # generation configs
+  prompt: ""
+
+preprocess:
+    vis_processor:
+        train:
+          name: "imagebind_vision_train"
+          image_size: 224
+        eval:
+          name: "imagebind_vision_eval"
+          image_size: 224
+    text_processor:
+        train:
+          name: "imagebind_caption"
+        eval:
+          name: "imagebind_caption"
diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py
index 54acd24..0b15c29 100644
--- a/minigpt4/models/__init__.py
+++ b/minigpt4/models/__init__.py
@@ -14,6 +14,7 @@ from minigpt4.models.base_model import BaseModel
 from minigpt4.models.blip2 import Blip2Base
 from minigpt4.models.mini_gpt4 import MiniGPT4
 from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.models.bind_gpt4 import BindGPT4
 
 
 __all__ = [
@@ -21,6 +22,7 @@ __all__ = [
     "BaseModel",
     "Blip2Base",
     "MiniGPT4",
+    "BindGPT4"
 ]
 
 
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
index 605d2f2..8a56b98 100644
--- a/minigpt4/models/bind_gpt4.py
+++ b/minigpt4/models/bind_gpt4.py
@@ -90,7 +90,7 @@ class BindGPT4(BaseModel):
         attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
         if prompt:
             batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<{}Here>'.format(modality_name))
+            p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
             p_after_tokens = self.llama_tokenizer(
@@ -110,7 +110,7 @@ class BindGPT4(BaseModel):
             Other modalities will conflict with the pre-defined prompt and wrapping strategy.
         """
         embeds = self.encode_inputs(inputs)
-        assert "Vision" in embeds
+        assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
         prompt = random.choice(self.prompt_list)
         img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
 
diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py
index e560eaa..3d12f3b 100644
--- a/minigpt4/processors/__init__.py
+++ b/minigpt4/processors/__init__.py
@@ -6,19 +6,27 @@
 """
 
 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.blip_processors import (
-    Blip2ImageTrainProcessor,
-    Blip2ImageEvalProcessor,
-    BlipCaptionProcessor,
+# from minigpt4.processors.blip_processors import (
+#     Blip2ImageTrainProcessor,
+#     Blip2ImageEvalProcessor,
+#     BlipCaptionProcessor,
+# )
+from minigpt4.processors.imagebind_processor import (
+    ImageBindCaptionProcessor,
+    ImageBindVisionTrainProcessor,
+    ImageBindVisionEvalProcessor,
 )
 
 from minigpt4.common.registry import registry
 
 __all__ = [
     "BaseProcessor",
-    "Blip2ImageTrainProcessor",
-    "Blip2ImageEvalProcessor",
-    "BlipCaptionProcessor",
+    # "Blip2ImageTrainProcessor",
+    # "Blip2ImageEvalProcessor",
+    # "BlipCaptionProcessor",
+    "ImageBindCaptionProcessor",
+    "ImageBindVisionTrainProcessor",
+    "ImageBindVisionEvalProcessor"
 ]
 
 
diff --git a/minigpt4/processors/imagebind_processor.py b/minigpt4/processors/imagebind_processor.py
new file mode 100644
index 0000000..4e13560
--- /dev/null
+++ b/minigpt4/processors/imagebind_processor.py
@@ -0,0 +1,148 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from minigpt4.common.registry import registry
+from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.processors.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class ImageBindVisionBaseProcessor(BaseProcessor):
+    def __init__(self, mean=None, std=None):
+        super().__init__()
+        if mean is None:
+            mean = (0.48145466, 0.4578275, 0.40821073)
+        if std is None:
+            std = (0.26862954, 0.26130258, 0.27577711)
+
+        self.normalize = transforms.Normalize(mean, std)
+
+
+# Note: the caption processor config is kept the same as in BLIP2 / MiniGPT4.
+@registry.register_processor("imagebind_caption")
+class ImageBindCaptionProcessor(BaseProcessor):
+    def __init__(self, prompt="", max_words=50):
+        # Note: no prompt is actually used here (self.prompt defaults to "").
+        super().__init__()
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        caption = re.sub(
+            r"([.!\"()*#:;~])",
+            " ",
+            caption.lower(),
+        )
+        caption = re.sub(
+            r"\s{2,}",
+            " ",
+            caption,
+        )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # truncate caption
+        caption_words = caption.split(" ")
+        if len(caption_words) > self.max_words:
+            caption = " ".join(caption_words[: self.max_words])
+
+        return caption
+
+
+# Note: the vision train processor config is kept the same as in BLIP2 / MiniGPT4.
+@registry.register_processor("imagebind_vision_train")
+class ImageBindVisionTrainProcessor(ImageBindVisionBaseProcessor):
+    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    image_size,
+                    scale=(min_scale, max_scale),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 224)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        min_scale = cfg.get("min_scale", 0.5)
+        max_scale = cfg.get("max_scale", 1.0)
+
+        return cls(
+            image_size=image_size,
+            mean=mean,
+            std=std,
+            min_scale=min_scale,
+            max_scale=max_scale,
+        )
+
+
+# Note: changed from the BLIP2 / MiniGPT4 eval transform; this resizes the
+# shorter side and center-crops, matching ImageBind's eval preprocessing.
+@registry.register_processor("imagebind_vision_eval")
+class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
+    def __init__(self, image_size=224, mean=None, std=None):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(
+                    image_size, interpolation=InterpolationMode.BICUBIC
+                ),
+                transforms.CenterCrop(image_size),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 224)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        return cls(image_size=image_size, mean=mean, std=std)
\ No newline at end of file
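
For reviewers, a standalone sketch of the placeholder handling after this
change. The prompt string below is hypothetical (it mirrors MiniGPT4's
'<Img><ImageHere></Img>' convention); only the split logic comes from
bind_gpt4.py:

    # ModalityType values are now lower-case; prompt placeholders stay title-case.
    modality_name = "vision"
    prompt = "###Human: <Vision><VisionHere></Vision> Describe this scene. ###Assistant:"
    p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
    assert p_before.endswith("<Vision>")
    assert p_after.startswith("</Vision>")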