Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 18:40:46 +00:00)
Initial Version (might have some bugs)

This commit is contained in:
parent  ec6d62eb32
commit  d2db151df2
@@ -0,0 +1 @@
# TODO: Finish the eval config of ImageBindGPT4
@@ -38,12 +38,12 @@ from imagebind.models.multimodal_projectors import create_projectors, create_pre
 from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

 ModalityType = SimpleNamespace(
-    VISION="Vision",
-    TEXT="Text",
-    AUDIO="Audio",
-    THERMAL="Thermal",
-    DEPTH="Depth",
-    IMU="Imu",
+    VISION="vision",
+    TEXT="text",
+    AUDIO="audio",
+    THERMAL="thermal",
+    DEPTH="depth",
+    IMU="imu",
 )
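
This hunk lowercases the ModalityType values in the vendored ImageBind code. A minimal sketch of the assumed intent: the lowercase strings double as keys of the per-modality embeddings dict, while the title-cased form is still used for the prompt placeholders (see the prompt_wrap hunk further down). The dict layout returned by encode_inputs is an assumption, not shown in this diff.

from types import SimpleNamespace

# Sketch: lowercase modality strings used as dict keys (assumed encode_inputs layout).
ModalityType = SimpleNamespace(VISION="vision", TEXT="text", AUDIO="audio")

embeds = {ModalityType.VISION: "tensor of shape [B, N, D]"}   # placeholder value

print(ModalityType.VISION in embeds)                    # True: lowercase key lookup
print("<{}Here>".format(ModalityType.VISION.title()))   # <VisionHere>: prompt placeholder
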
@@ -0,0 +1,29 @@
model:
  arch: bind_gpt4

  # Imagebind
  freeze_imagebind: True

  # Q-Former
  freeze_qformer: True
  num_query_token: 32

  # Vicuna
  llama_model: "/path/to/vicuna/weights/"

  # generation configs
  prompt: ""

preprocess:
  vis_processor:
    train:
      name: "imagebind_vision_train"
      image_size: 224
    eval:
      name: "imagebind_vision_eval"
      image_size: 224
  text_processor:
    train:
      name: "imagebind_caption"
    eval:
      name: "imagebind_caption"
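
A minimal sketch of consuming this config; the actual file path and the training entry point that loads it are not shown in this diff.

from omegaconf import OmegaConf

# Hypothetical path to the YAML shown above.
cfg = OmegaConf.load("path/to/bindgpt4_train_config.yaml")

print(cfg.model.arch)                           # bind_gpt4
print(cfg.model.freeze_imagebind)               # True
print(cfg.preprocess.vis_processor.train.name)  # imagebind_vision_train
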
@@ -14,6 +14,7 @@ from minigpt4.models.base_model import BaseModel
 from minigpt4.models.blip2 import Blip2Base
 from minigpt4.models.mini_gpt4 import MiniGPT4
 from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.models.bind_gpt4 import BindGPT4


 __all__ = [
@@ -21,6 +22,7 @@ __all__ = [
     "BaseModel",
     "Blip2Base",
     "MiniGPT4",
+    "BindGPT4"
 ]
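
With BindGPT4 exported from the package, the class can be imported like the existing models; a small sketch, assuming it is also registered under the arch name used in the config above (the registration decorator is not shown in this hunk).

from minigpt4.common.registry import registry
from minigpt4.models import BindGPT4  # now importable from the package

# Assumed lookup path used by the training script: resolve the class by arch name.
model_cls = registry.get_model_class("bind_gpt4")
print(model_cls is BindGPT4)
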
@@ -90,7 +90,7 @@ class BindGPT4(BaseModel):
         attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
         if prompt:
             batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<{}Here>'.format(modality_name))
+            p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
             p_before_tokens = self.llama_tokenizer(
                 p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
             p_after_tokens = self.llama_tokenizer(
@@ -110,7 +110,7 @@ class BindGPT4(BaseModel):
         Other modalities will conflict with the pre-defined prompt and wrapping strategy.
         """
         embeds = self.encode_inputs(inputs)
-        assert "Vision" in embeds
+        assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
         prompt = random.choice(self.prompt_list)
         img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
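
Because the modality keys are now lowercase, the placeholder in the prompt is rebuilt with .title(). A small sketch of the split; the template text below is hypothetical, since the actual prompt list is not shown in this diff.

modality_name = "vision"   # ModalityType.VISION after this change
prompt = "###Human: <Vision><VisionHere></Vision> Describe this scene. ###Assistant:"

p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
print(p_before)   # ###Human: <Vision>
print(p_after)    # </Vision> Describe this scene. ###Assistant:
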
@@ -6,19 +6,27 @@
 """

 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.blip_processors import (
-    Blip2ImageTrainProcessor,
-    Blip2ImageEvalProcessor,
-    BlipCaptionProcessor,
+# from minigpt4.processors.blip_processors import (
+#     Blip2ImageTrainProcessor,
+#     Blip2ImageEvalProcessor,
+#     BlipCaptionProcessor,
+# )
+from minigpt4.processors.imagebind_processor import (
+    ImageBindCaptionProcessor,
+    ImageBindVisionTrainProcessor,
+    ImageBindVisionEvalProcessor
 )

 from minigpt4.common.registry import registry

 __all__ = [
     "BaseProcessor",
-    "Blip2ImageTrainProcessor",
-    "Blip2ImageEvalProcessor",
-    "BlipCaptionProcessor",
+    # "Blip2ImageTrainProcessor",
+    # "Blip2ImageEvalProcessor",
+    # "BlipCaptionProcessor",
+    "ImageBindCaptionProcessor",
+    "ImageBindVisionTrainProcessor",
+    "ImageBindVisionEvalProcessor"
 ]
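
With the package exports switched over, downstream code can resolve the new processors by their registered names; a minimal sketch assuming the usual LAVIS-style registry lookup used elsewhere in MiniGPT-4.

from omegaconf import OmegaConf
from minigpt4.common.registry import registry

# Assumed lookup path: fetch processor classes by the names registered in
# imagebind_processor.py and build them from the config shown above.
vis_cfg = OmegaConf.create({"name": "imagebind_vision_eval", "image_size": 224})
vis_processor = registry.get_processor_class(vis_cfg.name).from_config(vis_cfg)

txt_processor = registry.get_processor_class("imagebind_caption").from_config()
print(txt_processor("A Dog, Running  on the BEACH!!"))   # a dog, running on the beach
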
148  minigpt4/processors/imagebind_processor.py  Normal file
@@ -0,0 +1,148 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import re

from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.randaugment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class ImageBindVisionBaseProcessor(BaseProcessor):
    def __init__(self, mean=None, std=None):
        super().__init__()
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


# Note: The config of caption processor keeps the same as BLIP2 / MiniGPT4
@registry.register_processor("imagebind_caption")
class ImageBindCaptionProcessor(BaseProcessor):
    def __init__(self, prompt="", max_words=50):
        # Note: Actually no prompts are used here.
        super().__init__()
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        caption = self.prompt + self.pre_caption(caption)

        return caption

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        prompt = cfg.get("prompt", "")
        max_words = cfg.get("max_words", 50)

        return cls(prompt=prompt, max_words=max_words)

    def pre_caption(self, caption):
        caption = re.sub(
            r"([.!\"()*#:;~])",
            " ",
            caption.lower(),
        )
        caption = re.sub(
            r"\s{2,}",
            " ",
            caption,
        )
        caption = caption.rstrip("\n")
        caption = caption.strip(" ")

        # truncate caption
        caption_words = caption.split(" ")
        if len(caption_words) > self.max_words:
            caption = " ".join(caption_words[: self.max_words])

        return caption


# Note: The training config of vision processor keeps the same as BLIP2 / MiniGPT4
@registry.register_processor("imagebind_vision_train")
class ImageBindVisionTrainProcessor(ImageBindVisionBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )


# Changed.
@registry.register_processor("imagebind_vision_eval")
class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    image_size, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(image_size=image_size, mean=mean, std=std)
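
A short usage sketch of the new processors themselves (the sample image path is hypothetical): the eval processor resizes, center-crops, converts to a tensor, and normalizes with the CLIP mean/std defined above, and the caption processor lowercases, strips punctuation, and truncates to max_words.

from PIL import Image
from minigpt4.processors.imagebind_processor import (
    ImageBindCaptionProcessor,
    ImageBindVisionEvalProcessor,
)

vis_proc = ImageBindVisionEvalProcessor(image_size=224)
img = Image.open("example.jpg").convert("RGB")   # hypothetical sample image
pixel_values = vis_proc(img)                     # torch tensor of shape [3, 224, 224]

cap_proc = ImageBindCaptionProcessor(max_words=50)
print(cap_proc("An (unnecessarily) NOISY caption!!"))   # an unnecessarily noisy caption
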