Initial Version (might have some bugs)

Author: unknown
Date:   2023-05-18 13:52:05 +08:00
Parent: ec6d62eb32
Commit: d2db151df2

7 changed files with 203 additions and 15 deletions

@@ -0,0 +1 @@
# TODO: Finish the eval config of ImageBindGPT4

@@ -38,12 +38,12 @@ from imagebind.models.multimodal_projectors import create_projectors, create_pre
 from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

 ModalityType = SimpleNamespace(
-    VISION="Vision",
-    TEXT="Text",
-    AUDIO="Audio",
-    THERMAL="Thermal",
-    DEPTH="Depth",
-    IMU="Imu",
+    VISION="vision",
+    TEXT="text",
+    AUDIO="audio",
+    THERMAL="thermal",
+    DEPTH="depth",
+    IMU="imu",
 )
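The lowercase values match upstream ImageBind, which keys its modality dicts by these strings. A minimal sketch of why that matters downstream (the dict below is a stand-in, not code from this repo):

from types import SimpleNamespace

# Same lowercase values as upstream ImageBind after this change.
ModalityType = SimpleNamespace(VISION="vision", TEXT="text")

# ImageBind-style embedding dicts are keyed by these strings, so checks such as
# `"vision" in embeds` in BindGPT4 line up with ModalityType.VISION.
embeds = {ModalityType.VISION: None}  # stand-in for an embedding tensor
assert ModalityType.VISION in embeds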

@@ -0,0 +1,29 @@
model:
  arch: bind_gpt4

  # Imagebind
  freeze_imagebind: True

  # Q-Former
  freeze_qformer: True
  num_query_token: 32

  # Vicuna
  llama_model: "/path/to/vicuna/weights/"

  # generation configs
  prompt: ""

preprocess:
  vis_processor:
    train:
      name: "imagebind_vision_train"
      image_size: 224
    eval:
      name: "imagebind_vision_eval"
      image_size: 224
  text_processor:
    train:
      name: "imagebind_caption"
    eval:
      name: "imagebind_caption"

@@ -14,6 +14,7 @@ from minigpt4.models.base_model import BaseModel
 from minigpt4.models.blip2 import Blip2Base
 from minigpt4.models.mini_gpt4 import MiniGPT4
 from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.models.bind_gpt4 import BindGPT4

 __all__ = [
@@ -21,6 +22,7 @@ __all__ = [
     "BaseModel",
     "Blip2Base",
     "MiniGPT4",
+    "BindGPT4"
 ]
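Exporting BindGPT4 from the package goes hand in hand with the registry lookup that resolves the config's `arch: bind_gpt4`. A sketch under the assumption that the class registers itself under that name (the registration itself is not shown in this diff):

from minigpt4.common.registry import registry
from minigpt4.models import BindGPT4  # re-exported by this change

# Assumption: bind_gpt4.py registers the class as "bind_gpt4",
# matching `arch` in the new model config.
model_cls = registry.get_model_class("bind_gpt4")
assert model_cls is BindGPT4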

@@ -90,7 +90,7 @@ class BindGPT4(BaseModel):
         attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
         if prompt:
             batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<{}Here>'.format(modality_name))
+            p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
             p_after_tokens = self.llama_tokenizer(
@@ -110,7 +110,7 @@ class BindGPT4(BaseModel):
         Other modalities will conflict with the pre-defined prompt and wrapping strategy.
         """
         embeds = self.encode_inputs(inputs)
-        assert "Vision" in embeds
+        assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
         prompt = random.choice(self.prompt_list)
         img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
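The `.title()` call bridges the now-lowercase ModalityType values and the title-cased `<...Here>` placeholder in the prompt. A standalone sketch of the split step, with a made-up prompt (the real prompt templates are not part of this diff):

# Made-up prompt; only the placeholder name matters for the split.
prompt = "###Human: <Vision><VisionHere></Vision> Describe this input. ###Assistant:"
modality_name = "vision"  # ModalityType.VISION after this commit

# prompt_wrap splits on "<VisionHere>" and later concatenates
# [p_before_embeds, modality_embeds, p_after_embeds] along the sequence dim.
p_before, p_after = prompt.split("<{}Here>".format(modality_name.title()))

assert p_before.endswith("<Vision>")
assert p_after.startswith("</Vision>")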

@@ -6,19 +6,27 @@
 """

 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.blip_processors import (
-    Blip2ImageTrainProcessor,
-    Blip2ImageEvalProcessor,
-    BlipCaptionProcessor,
-)
+# from minigpt4.processors.blip_processors import (
+#     Blip2ImageTrainProcessor,
+#     Blip2ImageEvalProcessor,
+#     BlipCaptionProcessor,
+# )
+from minigpt4.processors.imagebind_processor import (
+    ImageBindCaptionProcessor,
+    ImageBindVisionTrainProcessor,
+    ImageBindVisionEvalProcessor
+)

 from minigpt4.common.registry import registry

 __all__ = [
     "BaseProcessor",
-    "Blip2ImageTrainProcessor",
-    "Blip2ImageEvalProcessor",
-    "BlipCaptionProcessor",
+    # "Blip2ImageTrainProcessor",
+    # "Blip2ImageEvalProcessor",
+    # "BlipCaptionProcessor",
+    "ImageBindCaptionProcessor",
+    "ImageBindVisionTrainProcessor",
+    "ImageBindVisionEvalProcessor"
 ]

@@ -0,0 +1,148 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import re

from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.randaugment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

class ImageBindVisionBaseProcessor(BaseProcessor):
    def __init__(self, mean=None, std=None):
        super().__init__()

        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

# Note: the caption processor config is kept the same as in BLIP-2 / MiniGPT-4
@registry.register_processor("imagebind_caption")
class ImageBindCaptionProcessor(BaseProcessor):
    def __init__(self, prompt="", max_words=50):
        # Note: no prompts are actually used here.
        super().__init__()
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        caption = self.prompt + self.pre_caption(caption)
        return caption

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        prompt = cfg.get("prompt", "")
        max_words = cfg.get("max_words", 50)

        return cls(prompt=prompt, max_words=max_words)

    def pre_caption(self, caption):
        caption = re.sub(
            r"([.!\"()*#:;~])",
            " ",
            caption.lower(),
        )
        caption = re.sub(
            r"\s{2,}",
            " ",
            caption,
        )
        caption = caption.rstrip("\n")
        caption = caption.strip(" ")

        # truncate caption
        caption_words = caption.split(" ")
        if len(caption_words) > self.max_words:
            caption = " ".join(caption_words[: self.max_words])

        return caption

# Note: the vision train processor config is kept the same as in BLIP-2 / MiniGPT-4
@registry.register_processor("imagebind_vision_train")
class ImageBindVisionTrainProcessor(ImageBindVisionBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)
        mean = cfg.get("mean", None)
        std = cfg.get("std", None)
        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )

# Changed: the eval transform resizes then center-crops to image_size.
@registry.register_processor("imagebind_vision_eval")
class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    image_size, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)
        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(image_size=image_size, mean=mean, std=std)
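
For completeness, a small usage sketch of the new processors; the dummy image and caption are illustrative only:

from omegaconf import OmegaConf
from PIL import Image

from minigpt4.processors.imagebind_processor import (
    ImageBindCaptionProcessor,
    ImageBindVisionTrainProcessor,
)

vis_processor = ImageBindVisionTrainProcessor.from_config(
    OmegaConf.create({"image_size": 224})
)
txt_processor = ImageBindCaptionProcessor.from_config()

image = Image.new("RGB", (256, 256))            # dummy image
tensor = vis_processor(image)                   # torch.Tensor of shape (3, 224, 224)
caption = txt_processor("A photo: of a CAT!!")  # -> "a photo of a cat"
print(tensor.shape, repr(caption))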