Build model sketches. Next: Dataset & Config

This commit is contained in:
unknown 2023-05-17 21:30:25 +08:00
parent 22d8888ca2
commit 926ee3e98a
72 changed files with 2451 additions and 4 deletions

2
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
# Default ignored files
/workspace.xml

15
.idea/MiniGPT-4.iml Normal file
View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="pytest" />
</component>
</module>

View File

@ -0,0 +1,7 @@
<component name="ProjectDictionaryState">
<dictionary name="Desmond">
<words>
<w>imagebind</w>
</words>
</dictionary>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/MiniGPT-4.iml" filepath="$PROJECT_DIR$/.idea/MiniGPT-4.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

Binary file not shown.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 380 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 457 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 538 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 586 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 679 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 555 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 468 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 658 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 690 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 586 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 713 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 597 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 190 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 603 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 634 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 249 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 305 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 588 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 805 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 853 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 567 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 712 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 519 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 565 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 380 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 457 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 538 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 586 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 679 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 555 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 468 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 658 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 690 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 586 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 713 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 597 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 190 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 603 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 634 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 249 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 305 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 588 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 805 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 853 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 567 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 712 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 519 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 565 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.4 MiB

0
imagebind/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,350 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torchaudio
import logging
from imagebind.models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
# Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
waveform -= waveform.mean()
fbank = torchaudio.compliance.kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=sample_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=num_mel_bins,
dither=0.0,
frame_length=25,
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
)
# Convert to [mel_bins, num_frames] shape
fbank = fbank.transpose(0, 1)
# Pad to target_length
n_frames = fbank.size(1)
p = target_length - n_frames
# if p is too large (say >20%), flash a warning
if abs(p) / n_frames > 0.2:
logging.warning(
"Large gap between audio n_frames(%d) and "
"target_length (%d). Is the audio_target_length "
"setting correct?",
n_frames,
target_length,
)
# cut and pad
if p > 0:
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
elif p < 0:
fbank = fbank[:, 0:target_length]
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
# channel image
fbank = fbank.unsqueeze(0)
return fbank
def get_clip_timepoints(clip_sampler, duration):
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
end = 0.0
while not is_last_clip:
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
all_clips_timepoints.append((start, end))
return all_clips_timepoints
def load_and_transform_vision_data(image_paths, device):
if image_paths is None:
return None
image_ouputs = []
for image_path in image_paths:
data_transform = transforms.Compose(
[
transforms.Resize(
224, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
with open(image_path, "rb") as fopen:
image = Image.open(fopen).convert("RGB")
image = data_transform(image).to(device)
image_ouputs.append(image)
return torch.stack(image_ouputs, dim=0)
def load_and_transform_text(text, device):
if text is None:
return None
tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
tokens = torch.cat(tokens, dim=0)
return tokens
def load_and_transform_audio_data(
audio_paths,
device,
num_mel_bins=128,
target_length=204,
sample_rate=16000,
clip_duration=2,
clips_per_video=3,
mean=-4.268,
std=9.138,
):
if audio_paths is None:
return None
audio_outputs = []
clip_sampler = ConstantClipsPerVideoSampler(
clip_duration=clip_duration, clips_per_video=clips_per_video
)
for audio_path in audio_paths:
waveform, sr = torchaudio.load(audio_path)
if sample_rate != sr:
waveform = torchaudio.functional.resample(
waveform, orig_freq=sr, new_freq=sample_rate
)
all_clips_timepoints = get_clip_timepoints(
clip_sampler, waveform.size(1) / sample_rate
)
all_clips = []
for clip_timepoints in all_clips_timepoints:
waveform_clip = waveform[
:,
int(clip_timepoints[0] * sample_rate) : int(
clip_timepoints[1] * sample_rate
),
]
waveform_melspec = waveform2melspec(
waveform_clip, sample_rate, num_mel_bins, target_length
)
all_clips.append(waveform_melspec)
normalize = transforms.Normalize(mean=mean, std=std)
all_clips = [normalize(ac).to(device) for ac in all_clips]
all_clips = torch.stack(all_clips, dim=0)
audio_outputs.append(all_clips)
return torch.stack(audio_outputs, dim=0)
def get_clip_timepoints(clip_sampler, duration):
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
end = 0.0
while not is_last_clip:
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
all_clips_timepoints.append((start, end))
return all_clips_timepoints
def crop_boxes(boxes, x_offset, y_offset):
"""
Perform crop on the bounding boxes given the offsets.
Args:
boxes (ndarray or None): bounding boxes to perform crop. The dimension
is `num boxes` x 4.
x_offset (int): cropping offset in the x axis.
y_offset (int): cropping offset in the y axis.
Returns:
cropped_boxes (ndarray or None): the cropped boxes with dimension of
`num boxes` x 4.
"""
cropped_boxes = boxes.copy()
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
"""
Perform uniform spatial sampling on the images and corresponding boxes.
Args:
images (tensor): images to perform uniform crop. The dimension is
`num frames` x `channel` x `height` x `width`.
size (int): size of height and weight to crop the images.
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
is larger than height. Or 0, 1, or 2 for top, center, and bottom
crop if height is larger than width.
boxes (ndarray or None): optional. Corresponding boxes to images.
Dimension is `num boxes` x 4.
scale_size (int): optinal. If not None, resize the images to scale_size before
performing any crop.
Returns:
cropped (tensor): images with dimension of
`num frames` x `channel` x `size` x `size`.
cropped_boxes (ndarray or None): the cropped boxes with dimension of
`num boxes` x 4.
"""
assert spatial_idx in [0, 1, 2]
ndim = len(images.shape)
if ndim == 3:
images = images.unsqueeze(0)
height = images.shape[2]
width = images.shape[3]
if scale_size is not None:
if width <= height:
width, height = scale_size, int(height / width * scale_size)
else:
width, height = int(width / height * scale_size), scale_size
images = torch.nn.functional.interpolate(
images,
size=(height, width),
mode="bilinear",
align_corners=False,
)
y_offset = int(math.ceil((height - size) / 2))
x_offset = int(math.ceil((width - size) / 2))
if height > width:
if spatial_idx == 0:
y_offset = 0
elif spatial_idx == 2:
y_offset = height - size
else:
if spatial_idx == 0:
x_offset = 0
elif spatial_idx == 2:
x_offset = width - size
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
if ndim == 3:
cropped = cropped.squeeze(0)
return cropped, cropped_boxes
class SpatialCrop(nn.Module):
"""
Convert the video into 3 smaller clips spatially. Must be used after the
temporal crops to get spatial crops, and should be used with
-2 in the spatial crop at the slowfast augmentation stage (so full
frames are passed in here). Will return a larger list with the
3x spatial crops as well.
"""
def __init__(self, crop_size: int = 224, num_crops: int = 3):
super().__init__()
self.crop_size = crop_size
if num_crops == 3:
self.crops_to_ext = [0, 1, 2]
self.flipped_crops_to_ext = []
elif num_crops == 1:
self.crops_to_ext = [1]
self.flipped_crops_to_ext = []
else:
raise NotImplementedError("Nothing else supported yet")
def forward(self, videos):
"""
Args:
videos: A list of C, T, H, W videos.
Returns:
videos: A list with 3x the number of elements. Each video converted
to C, T, H', W' by spatial cropping.
"""
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
res = []
for video in videos:
for spatial_idx in self.crops_to_ext:
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
if not self.flipped_crops_to_ext:
continue
flipped_video = transforms.functional.hflip(video)
for spatial_idx in self.flipped_crops_to_ext:
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
return res
def load_and_transform_video_data(
video_paths,
device,
clip_duration=2,
clips_per_video=5,
sample_rate=16000,
):
if video_paths is None:
return None
video_outputs = []
video_transform = transforms.Compose(
[
pv_transforms.ShortSideScale(224),
NormalizeVideo(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
clip_sampler = ConstantClipsPerVideoSampler(
clip_duration=clip_duration, clips_per_video=clips_per_video
)
frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
for video_path in video_paths:
video = EncodedVideo.from_path(
video_path,
decoder="decord",
decode_audio=False,
**{"sample_rate": sample_rate},
)
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
all_video = []
for clip_timepoints in all_clips_timepoints:
# Read the clip, get frames
clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
if clip is None:
raise ValueError("No clip found")
video_clip = frame_sampler(clip["video"])
video_clip = video_clip / 255.0 # since this is float, need 0-1
all_video.append(video_clip)
all_video = [video_transform(clip) for clip in all_video]
all_video = SpatialCrop(224, num_crops=3)(all_video)
all_video = torch.stack(all_video, dim=0)
video_outputs.append(all_video)
return torch.stack(video_outputs, dim=0).to(device)

View File

141
imagebind/models/helper.py Normal file
View File

@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import einops
import numpy as np
import torch
import torch.nn as nn
class Normalize(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x):
return torch.nn.functional.normalize(x, dim=self.dim, p=2)
class LearnableLogitScaling(nn.Module):
def __init__(
self,
logit_scale_init: float = 1 / 0.07,
learnable: bool = True,
max_logit_scale: float = 100,
) -> None:
super().__init__()
self.max_logit_scale = max_logit_scale
self.logit_scale_init = logit_scale_init
self.learnable = learnable
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
if learnable:
self.log_logit_scale = nn.Parameter(log_logit_scale)
else:
self.register_buffer("log_logit_scale", log_logit_scale)
def forward(self, x):
return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
def extra_repr(self):
st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
return st
class EinOpsRearrange(nn.Module):
def __init__(self, rearrange_expr: str, **kwargs) -> None:
super().__init__()
self.rearrange_expr = rearrange_expr
self.kwargs = kwargs
def forward(self, x):
assert isinstance(x, torch.Tensor)
return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
class VerboseNNModule(nn.Module):
"""
Wrapper around nn.Module that prints registered buffers and parameter names.
"""
@staticmethod
def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
st = (
"("
+ name
+ "): "
+ "tensor("
+ str(tuple(tensor[1].shape))
+ ", requires_grad="
+ str(tensor[1].requires_grad)
+ ")\n"
)
return st
def extra_repr(self) -> str:
named_modules = set()
for p in self.named_modules():
named_modules.update([p[0]])
named_modules = list(named_modules)
string_repr = ""
for p in self.named_parameters():
name = p[0].split(".")[0]
if name not in named_modules:
string_repr += self.get_readable_tensor_repr(name, p)
for p in self.named_buffers():
name = p[0].split(".")[0]
string_repr += self.get_readable_tensor_repr(name, p)
return string_repr
def cast_if_src_dtype(
tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
updated = False
if tensor.dtype == src_dtype:
tensor = tensor.to(dtype=tgt_dtype)
updated = True
return tensor, updated
class QuickGELU(nn.Module):
# From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class SelectElement(nn.Module):
def __init__(self, index) -> None:
super().__init__()
self.index = index
def forward(self, x):
assert x.ndim >= 3
return x[:, self.index, ...]
class SelectEOSAndProject(nn.Module):
"""
Text Pooling used in OpenCLIP
"""
def __init__(self, proj: nn.Module) -> None:
super().__init__()
self.proj = proj
def forward(self, x, seq_len):
assert x.ndim == 3
# x is of shape B x L x D
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), seq_len]
x = self.proj(x)
return x

View File

@ -0,0 +1,587 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from functools import partial
from types import SimpleNamespace
from typing import Union, Optional, Tuple, Dict, List
import torch
import torch.nn as nn
from torch import Tensor
from imagebind.models.helper import (
EinOpsRearrange,
LearnableLogitScaling,
Normalize,
SelectElement,
SelectEOSAndProject,
)
from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train
from imagebind.models.multimodal_preprocessors import (
AudioPreprocessor,
IMUPreprocessor,
PadIm2Video,
PatchEmbedGeneric,
RGBDTPreprocessor,
SpatioTemporalPosEmbeddingHelper,
TextPreprocessor,
ThermalPreprocessor,
)
from imagebind.models.multimodal_projectors import create_projectors, create_pre_projector
from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
ModalityType = SimpleNamespace(
VISION="Vision",
TEXT="Text",
AUDIO="Audio",
THERMAL="Thermal",
DEPTH="Depth",
IMU="Imu",
)
class ImageBindJoiner(nn.Module):
def __init__(self,
vision_query_token_num: int,
vision_qformer_frozen: bool = False,
vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model
vision_pre_dims: List[int] = (), # Projection before Q-Former
vision_post_dims: List[int] = (768, 768) # Projection after Q-Former
):
super().__init__()
assert not (vision_qformer_frozen and vision_qformer_model == "")
self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
vision_qformer_frozen,
vision_qformer_model)
self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
def _create_modality_pre_projectors(self,
vision_pre_dims
):
modality_pre_projectors = {
ModalityType.VISION: create_projectors(vision_pre_dims)
}
return modality_pre_projectors
def _create_modality_qformers(self,
vision_query_token_num,
vision_qformer_frozen,
vision_qformer_model
):
vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
freeze_qformer=vision_qformer_frozen,
q_former_model=vision_qformer_model)
modality_qformers = {
ModalityType.VISION: vision_qformer
}
return nn.ModuleDict(modality_qformers)
def _create_modality_post_projectors(self, vision_post_dims):
vision_projector = create_projectors(vision_post_dims)
modality_projectors = {
ModalityType.VISION: vision_projector
}
return nn.ModuleDict(modality_projectors)
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
outputs = {}
for modality_key, modality_value in inputs.items():
assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
if modality_value is not None:
modality_value = self.modality_pre_projectors[modality_key](modality_value)
modality_value = self.modality_qformers[modality_key](modality_value)
modality_value = self.modality_post_projectors[modality_key](modality_value)
outputs[modality_key] = modality_value
return outputs
class ImageBindModel(nn.Module):
def __init__(
self,
video_frames=2,
kernel_size=(2, 14, 14),
audio_kernel_size=16,
audio_stride=10,
out_embed_dim=768,
vision_embed_dim=1024,
vision_num_blocks=24,
vision_num_heads=16,
audio_embed_dim=768,
audio_num_blocks=12,
audio_num_heads=12,
audio_num_mel_bins=128,
audio_target_len=204,
audio_drop_path=0.1,
text_embed_dim=768,
text_num_blocks=12,
text_num_heads=12,
depth_embed_dim=384,
depth_kernel_size=16,
depth_num_blocks=12,
depth_num_heads=8,
depth_drop_path=0.0,
thermal_embed_dim=768,
thermal_kernel_size=16,
thermal_num_blocks=12,
thermal_num_heads=12,
thermal_drop_path=0.0,
imu_embed_dim=512,
imu_kernel_size=8,
imu_num_blocks=6,
imu_num_heads=8,
imu_drop_path=0.7,
):
super().__init__()
self.modality_preprocessors = self._create_modality_preprocessors(
video_frames,
vision_embed_dim,
kernel_size,
text_embed_dim,
audio_embed_dim,
audio_kernel_size,
audio_stride,
audio_num_mel_bins,
audio_target_len,
depth_embed_dim,
depth_kernel_size,
thermal_embed_dim,
thermal_kernel_size,
imu_embed_dim,
)
self.modality_trunks = self._create_modality_trunks(
vision_embed_dim,
vision_num_blocks,
vision_num_heads,
text_embed_dim,
text_num_blocks,
text_num_heads,
audio_embed_dim,
audio_num_blocks,
audio_num_heads,
audio_drop_path,
depth_embed_dim,
depth_num_blocks,
depth_num_heads,
depth_drop_path,
thermal_embed_dim,
thermal_num_blocks,
thermal_num_heads,
thermal_drop_path,
imu_embed_dim,
imu_num_blocks,
imu_num_heads,
imu_drop_path,
)
self.modality_heads = self._create_modality_heads(
out_embed_dim,
vision_embed_dim,
text_embed_dim,
audio_embed_dim,
depth_embed_dim,
thermal_embed_dim,
imu_embed_dim,
)
self.modality_postprocessors = self._create_modality_postprocessors(
out_embed_dim
)
def _create_modality_preprocessors(
self,
video_frames=2,
vision_embed_dim=1024,
kernel_size=(2, 14, 14),
text_embed_dim=768,
audio_embed_dim=768,
audio_kernel_size=16,
audio_stride=10,
audio_num_mel_bins=128,
audio_target_len=204,
depth_embed_dim=768,
depth_kernel_size=16,
thermal_embed_dim=768,
thermal_kernel_size=16,
imu_embed_dim=512,
):
rgbt_stem = PatchEmbedGeneric(
proj_stem=[
PadIm2Video(pad_type="repeat", ntimes=2),
nn.Conv3d(
in_channels=3,
kernel_size=kernel_size,
out_channels=vision_embed_dim,
stride=kernel_size,
bias=False,
),
]
)
rgbt_preprocessor = RGBDTPreprocessor(
img_size=[3, video_frames, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
rgbt_stem=rgbt_stem,
depth_stem=None,
)
text_preprocessor = TextPreprocessor(
context_length=77,
vocab_size=49408,
embed_dim=text_embed_dim,
causal_masking=True,
)
audio_stem = PatchEmbedGeneric(
proj_stem=[
nn.Conv2d(
in_channels=1,
kernel_size=audio_kernel_size,
stride=audio_stride,
out_channels=audio_embed_dim,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
)
audio_preprocessor = AudioPreprocessor(
img_size=[1, audio_num_mel_bins, audio_target_len],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
audio_stem=audio_stem,
)
depth_stem = PatchEmbedGeneric(
[
nn.Conv2d(
kernel_size=depth_kernel_size,
in_channels=1,
out_channels=depth_embed_dim,
stride=depth_kernel_size,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
)
depth_preprocessor = RGBDTPreprocessor(
img_size=[1, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
rgbt_stem=None,
depth_stem=depth_stem,
)
thermal_stem = PatchEmbedGeneric(
[
nn.Conv2d(
kernel_size=thermal_kernel_size,
in_channels=1,
out_channels=thermal_embed_dim,
stride=thermal_kernel_size,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
)
thermal_preprocessor = ThermalPreprocessor(
img_size=[1, 224, 224],
num_cls_tokens=1,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
thermal_stem=thermal_stem,
)
imu_stem = PatchEmbedGeneric(
[
nn.Linear(
in_features=48,
out_features=imu_embed_dim,
bias=False,
),
],
norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
)
imu_preprocessor = IMUPreprocessor(
img_size=[6, 2000],
num_cls_tokens=1,
kernel_size=8,
embed_dim=imu_embed_dim,
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
imu_stem=imu_stem,
)
modality_preprocessors = {
ModalityType.VISION: rgbt_preprocessor,
ModalityType.TEXT: text_preprocessor,
ModalityType.AUDIO: audio_preprocessor,
ModalityType.DEPTH: depth_preprocessor,
ModalityType.THERMAL: thermal_preprocessor,
ModalityType.IMU: imu_preprocessor,
}
return nn.ModuleDict(modality_preprocessors)
def _create_modality_trunks(
self,
vision_embed_dim=1024,
vision_num_blocks=24,
vision_num_heads=16,
text_embed_dim=768,
text_num_blocks=12,
text_num_heads=12,
audio_embed_dim=768,
audio_num_blocks=12,
audio_num_heads=12,
audio_drop_path=0.0,
depth_embed_dim=768,
depth_num_blocks=12,
depth_num_heads=12,
depth_drop_path=0.0,
thermal_embed_dim=768,
thermal_num_blocks=12,
thermal_num_heads=12,
thermal_drop_path=0.0,
imu_embed_dim=512,
imu_num_blocks=6,
imu_num_heads=8,
imu_drop_path=0.7,
):
def instantiate_trunk(
embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
):
return SimpleTransformer(
embed_dim=embed_dim,
num_blocks=num_blocks,
ffn_dropout_rate=0.0,
drop_path_rate=drop_path,
attn_target=partial(
MultiheadAttention,
embed_dim=embed_dim,
num_heads=num_heads,
bias=True,
add_bias_kv=add_bias_kv,
),
pre_transformer_layer=nn.Sequential(
nn.LayerNorm(embed_dim, eps=1e-6)
if pre_transformer_ln
else nn.Identity(),
EinOpsRearrange("b l d -> l b d"),
),
post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
)
modality_trunks = {}
modality_trunks[ModalityType.VISION] = instantiate_trunk(
vision_embed_dim,
vision_num_blocks,
vision_num_heads,
pre_transformer_ln=True,
add_bias_kv=False,
drop_path=0.0,
)
modality_trunks[ModalityType.TEXT] = instantiate_trunk(
text_embed_dim,
text_num_blocks,
text_num_heads,
pre_transformer_ln=False,
add_bias_kv=False,
drop_path=0.0,
)
modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
audio_embed_dim,
audio_num_blocks,
audio_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=audio_drop_path,
)
modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
depth_embed_dim,
depth_num_blocks,
depth_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=depth_drop_path,
)
modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
thermal_embed_dim,
thermal_num_blocks,
thermal_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=thermal_drop_path,
)
modality_trunks[ModalityType.IMU] = instantiate_trunk(
imu_embed_dim,
imu_num_blocks,
imu_num_heads,
pre_transformer_ln=False,
add_bias_kv=True,
drop_path=imu_drop_path,
)
return nn.ModuleDict(modality_trunks)
def _create_modality_heads(
self,
out_embed_dim,
vision_embed_dim,
text_embed_dim,
audio_embed_dim,
depth_embed_dim,
thermal_embed_dim,
imu_embed_dim,
):
modality_heads = {}
modality_heads[ModalityType.VISION] = nn.Sequential(
nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
proj=nn.Sequential(
nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
nn.Linear(text_embed_dim, out_embed_dim, bias=False),
)
)
modality_heads[ModalityType.AUDIO] = nn.Sequential(
nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.DEPTH] = nn.Sequential(
nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.THERMAL] = nn.Sequential(
nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
)
modality_heads[ModalityType.IMU] = nn.Sequential(
nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
SelectElement(index=0),
nn.Dropout(p=0.5),
nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
)
return nn.ModuleDict(modality_heads)
def _create_modality_postprocessors(self, out_embed_dim):
modality_postprocessors = {}
modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
Normalize(dim=-1), LearnableLogitScaling(learnable=True)
)
modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
)
modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
)
modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
)
modality_postprocessors[ModalityType.IMU] = nn.Sequential(
Normalize(dim=-1),
LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
)
return nn.ModuleDict(modality_postprocessors)
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
outputs = {}
for modality_key, modality_value in inputs.items():
reduce_list = (
modality_value.ndim >= 5
) # Audio and Video inputs consist of multiple clips
if reduce_list:
B, S = modality_value.shape[:2]
modality_value = modality_value.reshape(
B * S, *modality_value.shape[2:]
)
if modality_value is not None:
modality_value = self.modality_preprocessors[modality_key](
**{modality_key: modality_value}
)
trunk_inputs = modality_value["trunk"]
modality_value = self.modality_trunks[modality_key](**trunk_inputs)
# NOTE: No heads are needed any more.
# head_inputs = modality_value["head"]
# modality_value = self.modality_heads[modality_key](
# modality_value, **head_inputs
# )
modality_value = self.modality_postprocessors[modality_key](
modality_value
)
# NOTE: The reduction operation has been modified.
if reduce_list:
modality_value = modality_value.reshape(B, S, *modality_value[2:])
modality_value = modality_value.mean(dim=1)
outputs[modality_key] = modality_value
return outputs
def imagebind_huge(pretrained=False, freeze_imagebind=False):
model = ImageBindModel(
vision_embed_dim=1280,
vision_num_blocks=32,
vision_num_heads=16,
text_embed_dim=1024,
text_num_blocks=24,
text_num_heads=16,
out_embed_dim=1024,
audio_drop_path=0.1,
imu_drop_path=0.7,
)
if pretrained:
if not os.path.exists(".checkpoints/imagebind_huge.pth"):
print(
"Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
)
os.makedirs(".checkpoints", exist_ok=True)
torch.hub.download_url_to_file(
"https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
".checkpoints/imagebind_huge.pth",
progress=True,
)
model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
if freeze_imagebind:
for name, param in model.named_parameters():
param.requires_grad = False
model = model.eval()
model.train = disabled_train
return model

View File

@ -0,0 +1,103 @@
import logging
import os
import torch
from torch import nn, Tensor
from minigpt4.common.dist_utils import download_cached_file
from minigpt4.common.utils import is_url
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class BaseQFormer(nn.Module):
def __init__(self, freeze_qformer=False):
super().__init__()
self.freeze_qformer = freeze_qformer
self.Qformer = None
def check_and_freeze(self):
assert self.Qformer is not None
if self.freeze_qformer:
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
self.query_tokens.requires_grad = False
logging.info("Freeze This QFormer")
@classmethod
def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
msg = self.load_state_dict(state_dict, strict=False)
# logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)
return msg
class SequenceGenericQFormer(BaseQFormer):
def __init__(self,
num_query_token: int,
encoder_width: int = 768,
freeze_qformer: bool = False,
q_former_model: str = "",
cross_attention_freq: int = 2
):
super().__init__(freeze_qformer)
self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq)
if q_former_model != "":
self.load_Qformer(q_former_model)
self.check_and_freeze()
def load_Qformer(self, q_former_model):
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(url_or_filename=q_former_model)
@classmethod
def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = encoder_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
return Qformer, query_tokens
def forward(self, input_embeds: Tensor) -> Tensor:
input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=input_embeds,
encoder_attention_mask=input_atts,
return_dict=True,
)
return query_output.last_hidden_state

View File

@ -0,0 +1,686 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import html
import io
import math
from functools import lru_cache
from typing import Callable, List, Optional
import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from timm.models.layers import trunc_normal_
from imagebind.models.helper import VerboseNNModule, cast_if_src_dtype
def get_sinusoid_encoding_table(n_position, d_hid):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
N = pos_embed.shape[1]
if N == target_spatial_size:
return pos_embed
dim = pos_embed.shape[-1]
# nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
pos_embed = nn.functional.interpolate(
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
0, 3, 1, 2
),
scale_factor=math.sqrt(target_spatial_size / N),
mode="bicubic",
)
if updated:
pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return pos_embed
def interpolate_pos_encoding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape=None,
first_patch_idx=1,
):
assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
if npatch_per_img == N:
return pos_embed
assert (
patches_layout[-1] == patches_layout[-2]
), "Interpolation of pos embed not supported for non-square layouts"
class_emb = pos_embed[:, :first_patch_idx]
pos_embed = pos_embed[:, first_patch_idx:]
if input_shape is None or patches_layout[0] == 1:
# simple 2D pos embedding, no temporal component
pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
elif patches_layout[0] > 1:
# pos embed has a temporal component
assert len(input_shape) == 4, "temporal interpolation not supported"
# we only support 2D interpolation in this case
num_frames = patches_layout[0]
num_spatial_tokens = patches_layout[1] * patches_layout[2]
pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
# interpolate embedding for zeroth frame
pos_embed = interpolate_pos_encoding_2d(
npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
)
else:
raise ValueError("This type of interpolation isn't implemented")
return torch.cat((class_emb, pos_embed), dim=1)
def _get_pos_embedding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape,
first_patch_idx=1,
):
pos_embed = interpolate_pos_encoding(
npatch_per_img,
pos_embed,
patches_layout,
input_shape=input_shape,
first_patch_idx=first_patch_idx,
)
return pos_embed
class PatchEmbedGeneric(nn.Module):
"""
PatchEmbed from Hydra
"""
def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
super().__init__()
if len(proj_stem) > 1:
self.proj = nn.Sequential(*proj_stem)
else:
# Special case to be able to load pre-trained models that were
# trained with a standard stem
self.proj = proj_stem[0]
self.norm_layer = norm_layer
def get_patch_layout(self, img_size):
with torch.no_grad():
dummy_img = torch.zeros(
[
1,
]
+ img_size
)
dummy_out = self.proj(dummy_img)
embed_dim = dummy_out.shape[1]
patches_layout = tuple(dummy_out.shape[2:])
num_patches = np.prod(patches_layout)
return patches_layout, num_patches, embed_dim
def forward(self, x):
x = self.proj(x)
# B C (T) H W -> B (T)HW C
x = x.flatten(2).transpose(1, 2)
if self.norm_layer is not None:
x = self.norm_layer(x)
return x
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
def __init__(
self,
patches_layout: List,
num_patches: int,
num_cls_tokens: int,
embed_dim: int,
learnable: bool,
) -> None:
super().__init__()
self.num_cls_tokens = num_cls_tokens
self.patches_layout = patches_layout
self.num_patches = num_patches
self.num_tokens = num_cls_tokens + num_patches
self.learnable = learnable
if self.learnable:
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
trunc_normal_(self.pos_embed, std=0.02)
else:
self.register_buffer(
"pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
)
def get_pos_embedding(self, vision_input, all_vision_tokens):
input_shape = vision_input.shape
pos_embed = _get_pos_embedding(
all_vision_tokens.size(1) - self.num_cls_tokens,
pos_embed=self.pos_embed,
patches_layout=self.patches_layout,
input_shape=input_shape,
first_patch_idx=self.num_cls_tokens,
)
return pos_embed
class RGBDTPreprocessor(VerboseNNModule):
def __init__(
self,
rgbt_stem: PatchEmbedGeneric,
depth_stem: PatchEmbedGeneric,
img_size: List = (3, 224, 224),
num_cls_tokens: int = 1,
pos_embed_fn: Callable = None,
use_type_embed: bool = False,
init_param_style: str = "openclip",
) -> None:
super().__init__()
stem = rgbt_stem if rgbt_stem is not None else depth_stem
(
self.patches_layout,
self.num_patches,
self.embed_dim,
) = stem.get_patch_layout(img_size)
self.rgbt_stem = rgbt_stem
self.depth_stem = depth_stem
self.use_pos_embed = pos_embed_fn is not None
self.use_type_embed = use_type_embed
self.num_cls_tokens = num_cls_tokens
if self.use_pos_embed:
self.pos_embedding_helper = pos_embed_fn(
patches_layout=self.patches_layout,
num_cls_tokens=num_cls_tokens,
num_patches=self.num_patches,
embed_dim=self.embed_dim,
)
if self.num_cls_tokens > 0:
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
)
if self.use_type_embed:
self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style):
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.use_pos_embed:
nn.init.normal_(self.pos_embedding_helper.pos_embed)
self.pos_embedding_helper.pos_embed *= scale
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
if self.use_type_embed:
nn.init.normal_(self.type_embed)
def tokenize_input_and_cls_pos(self, input, stem, mask):
# tokens is of shape B x L x D
tokens = stem(input)
assert tokens.ndim == 3
assert tokens.shape[2] == self.embed_dim
B = tokens.shape[0]
if self.num_cls_tokens > 0:
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
tokens = torch.cat((class_tokens, tokens), dim=1)
if self.use_pos_embed:
pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
tokens = tokens + pos_embed
if self.use_type_embed:
tokens = tokens + self.type_embed.expand(B, -1, -1)
return tokens
def forward(self, vision=None, depth=None, patch_mask=None):
if patch_mask is not None:
raise NotImplementedError()
if vision is not None:
vision_tokens = self.tokenize_input_and_cls_pos(
vision, self.rgbt_stem, patch_mask
)
if depth is not None:
depth_tokens = self.tokenize_input_and_cls_pos(
depth, self.depth_stem, patch_mask
)
# aggregate tokens
if vision is not None and depth is not None:
final_tokens = vision_tokens + depth_tokens
else:
final_tokens = vision_tokens if vision is not None else depth_tokens
return_dict = {
"trunk": {
"tokens": final_tokens,
},
"head": {},
}
return return_dict
class AudioPreprocessor(RGBDTPreprocessor):
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
def forward(self, audio=None):
return super().forward(vision=audio)
class ThermalPreprocessor(RGBDTPreprocessor):
def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
def forward(self, thermal=None):
return super().forward(vision=thermal)
def build_causal_attention_mask(context_length):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(context_length, context_length, requires_grad=False)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
class TextPreprocessor(VerboseNNModule):
def __init__(
self,
vocab_size: int,
context_length: int,
embed_dim: int,
causal_masking: bool,
supply_seq_len_to_head: bool = True,
num_cls_tokens: int = 0,
init_param_style: str = "openclip",
) -> None:
super().__init__()
self.vocab_size = vocab_size
self.context_length = context_length
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Parameter(
torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
)
self.causal_masking = causal_masking
if self.causal_masking:
mask = build_causal_attention_mask(self.context_length)
# register the mask as a buffer so it can be moved to the right device
self.register_buffer("mask", mask)
self.supply_seq_len_to_head = supply_seq_len_to_head
self.num_cls_tokens = num_cls_tokens
self.embed_dim = embed_dim
if num_cls_tokens > 0:
assert self.causal_masking is False, "Masking + CLS token isn't implemented"
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, embed_dim)
)
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style="openclip"):
# OpenCLIP style initialization
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.pos_embed, std=0.01)
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
def forward(self, text):
# text tokens are of shape B x L x D
text_tokens = self.token_embedding(text)
# concat CLS tokens if any
if self.num_cls_tokens > 0:
B = text_tokens.shape[0]
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
text_tokens = text_tokens + self.pos_embed
return_dict = {
"trunk": {
"tokens": text_tokens,
},
"head": {},
}
# Compute sequence length after adding CLS tokens
if self.supply_seq_len_to_head:
text_lengths = text.argmax(dim=-1)
return_dict["head"] = {
"seq_len": text_lengths,
}
if self.causal_masking:
return_dict["trunk"].update({"attn_mask": self.mask})
return return_dict
class Im2Video(nn.Module):
"""Convert an image into a trivial video."""
def __init__(self, time_dim=2):
super().__init__()
self.time_dim = time_dim
def forward(self, x):
if x.ndim == 4:
# B, C, H, W -> B, C, T, H, W
return x.unsqueeze(self.time_dim)
elif x.ndim == 5:
return x
else:
raise ValueError(f"Dimension incorrect {x.shape}")
class PadIm2Video(Im2Video):
def __init__(self, ntimes, pad_type, time_dim=2):
super().__init__(time_dim=time_dim)
assert ntimes > 0
assert pad_type in ["zero", "repeat"]
self.ntimes = ntimes
self.pad_type = pad_type
def forward(self, x):
x = super().forward(x)
if x.shape[self.time_dim] == 1:
if self.pad_type == "repeat":
new_shape = [1] * len(x.shape)
new_shape[self.time_dim] = self.ntimes
x = x.repeat(new_shape)
elif self.pad_type == "zero":
padarg = [0, 0] * len(x.shape)
padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
x = nn.functional.pad(x, padarg)
return x
# Modified from github.com/openai/CLIP
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str, context_length=77):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with g_pathmgr.open(bpe_path, "rb") as fh:
bpe_bytes = io.BytesIO(fh.read())
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
"<|startoftext|>": "<|startoftext|>",
"<|endoftext|>": "<|endoftext|>",
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.context_length = context_length
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
def __call__(self, texts, context_length=None):
if not context_length:
context_length = self.context_length
if isinstance(texts, str):
texts = [texts]
sot_token = self.encoder["<|startoftext|>"]
eot_token = self.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
tokens = tokens[:context_length]
result[i, : len(tokens)] = torch.tensor(tokens)
if len(result) == 1:
return result[0]
return result
class IMUPreprocessor(VerboseNNModule):
def __init__(
self,
kernel_size: int,
imu_stem: PatchEmbedGeneric,
embed_dim: int,
img_size: List = (6, 2000),
num_cls_tokens: int = 1,
pos_embed_fn: Callable = None,
init_param_style: str = "openclip",
) -> None:
super().__init__()
stem = imu_stem
self.imu_stem = imu_stem
self.embed_dim = embed_dim
self.use_pos_embed = pos_embed_fn is not None
self.num_cls_tokens = num_cls_tokens
self.kernel_size = kernel_size
self.pos_embed = nn.Parameter(
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
)
if self.num_cls_tokens > 0:
self.cls_token = nn.Parameter(
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
)
self.init_parameters(init_param_style)
@torch.no_grad()
def init_parameters(self, init_param_style):
nn.init.normal_(self.pos_embed, std=0.01)
if init_param_style == "openclip":
# OpenCLIP style initialization
scale = self.embed_dim**-0.5
if self.num_cls_tokens > 0:
nn.init.normal_(self.cls_token)
self.cls_token *= scale
elif init_param_style == "vit":
self.cls_token.data.fill_(0)
else:
raise ValueError(f"Unknown init {init_param_style}")
def tokenize_input_and_cls_pos(self, input, stem):
# tokens is of shape B x L x D
tokens = stem.norm_layer(stem.proj(input))
assert tokens.ndim == 3
assert tokens.shape[2] == self.embed_dim
B = tokens.shape[0]
if self.num_cls_tokens > 0:
class_tokens = self.cls_token.expand(
B, -1, -1
) # stole class_tokens impl from Phil Wang, thanks
tokens = torch.cat((class_tokens, tokens), dim=1)
if self.use_pos_embed:
tokens = tokens + self.pos_embed
return tokens
def forward(self, imu):
# Patchify
imu = imu.unfold(
-1,
self.kernel_size,
self.kernel_size,
).permute(0, 2, 1, 3)
imu = imu.reshape(imu.size(0), imu.size(1), -1)
imu_tokens = self.tokenize_input_and_cls_pos(
imu,
self.imu_stem,
)
return_dict = {
"trunk": {
"tokens": imu_tokens,
},
"head": {},
}
return return_dict

View File

@ -0,0 +1,45 @@
from torch import nn, Tensor
from typing import Union, Optional, Tuple
class BaseProjector(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: Tensor) -> Tensor:
raise NotImplementedError
class LinearProjector(BaseProjector):
def __init__(self, in_dim, out_dim):
super().__init__()
self.fc = nn.Linear(in_dim, out_dim)
def forward(self, x: Tensor) -> Tensor:
return self.fc(x)
class AdapterProjector(BaseProjector):
def __init__(self, in_dim, mid_dim, out_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(in_dim, mid_dim, bias=False),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, out_dim, bias=False),
nn.ReLU(inplace=True)
)
def forward(self, x: Tensor) -> Tensor:
return self.fc(x)
def create_projectors(dims):
if len(dims) == 0:
return nn.Identity()
elif len(dims) == 2:
return LinearProjector(*dims)
elif len(dims) == 3:
return AdapterProjector(*dims)
else:
raise NotImplementedError

View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Code modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
# https://github.com/facebookresearch/deit/blob/main/models.py
# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
import copy
import fnmatch
import logging
from functools import partial
from typing import Callable, List
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version,
# can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class MultiheadAttention(nn.MultiheadAttention):
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
class ViTAttention(Attention):
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
assert attn_mask is None
return super().forward(x)
class BlockWithMasking(nn.Module):
def __init__(
self,
dim: int,
attn_target: Callable,
mlp_ratio: int = 4,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
ffn_dropout_rate: float = 0.0,
drop_path: float = 0.0,
layer_scale_type: str = None,
layer_scale_init_value: float = 1e-4,
):
super().__init__()
assert not isinstance(
attn_target, nn.Module
), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
self.attn = attn_target()
if drop_path > 0.0:
self.drop_path = DropPath(drop_path)
else:
self.drop_path = nn.Identity()
self.norm_1 = norm_layer(dim)
mlp_hidden_dim = int(mlp_ratio * dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=ffn_dropout_rate,
)
self.norm_2 = norm_layer(dim)
self.layer_scale_type = layer_scale_type
if self.layer_scale_type is not None:
assert self.layer_scale_type in [
"per_channel",
"scalar",
], f"Found Layer scale type {self.layer_scale_type}"
if self.layer_scale_type == "per_channel":
# one gamma value per channel
gamma_shape = [1, 1, dim]
elif self.layer_scale_type == "scalar":
# single gamma value for all channels
gamma_shape = [1, 1, 1]
# two gammas: for each part of the fwd in the encoder
self.layer_scale_gamma1 = nn.Parameter(
torch.ones(size=gamma_shape) * layer_scale_init_value,
requires_grad=True,
)
self.layer_scale_gamma2 = nn.Parameter(
torch.ones(size=gamma_shape) * layer_scale_init_value,
requires_grad=True,
)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
if self.layer_scale_type is None:
x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
x = x + self.drop_path(self.mlp(self.norm_2(x)))
else:
x = (
x
+ self.drop_path(self.attn(self.norm_1(x), attn_mask))
* self.layer_scale_gamma1
)
x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
return x
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
class SimpleTransformer(nn.Module):
def __init__(
self,
attn_target: Callable,
embed_dim: int,
num_blocks: int,
block: Callable = BlockWithMasking,
pre_transformer_layer: Callable = None,
post_transformer_layer: Callable = None,
drop_path_rate: float = 0.0,
drop_path_type: str = "progressive",
norm_layer: Callable = _LAYER_NORM,
mlp_ratio: int = 4,
ffn_dropout_rate: float = 0.0,
layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
layer_scale_init_value: float = 1e-4, # from cait; float
weight_init_style: str = "jax", # possible values jax or pytorch
):
"""
Simple Transformer with the following features
1. Supports masked attention
2. Supports DropPath
3. Supports LayerScale
4. Supports Dropout in Attention and FFN
5. Makes few assumptions about the input except that it is a Tensor
"""
super().__init__()
self.pre_transformer_layer = pre_transformer_layer
if drop_path_type == "progressive":
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
elif drop_path_type == "uniform":
dpr = [drop_path_rate for i in range(num_blocks)]
else:
raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
self.blocks = nn.Sequential(
*[
block(
dim=embed_dim,
attn_target=attn_target,
mlp_ratio=mlp_ratio,
ffn_dropout_rate=ffn_dropout_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
layer_scale_type=layer_scale_type,
layer_scale_init_value=layer_scale_init_value,
)
for i in range(num_blocks)
]
)
self.post_transformer_layer = post_transformer_layer
self.weight_init_style = weight_init_style
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
if self.weight_init_style == "jax":
# Based on MAE and official Jax ViT implementation
torch.nn.init.xavier_uniform_(m.weight)
elif self.weight_init_style == "pytorch":
# PyTorch ViT uses trunc_normal_
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(
self,
tokens: torch.Tensor,
attn_mask: torch.Tensor = None,
use_checkpoint: bool = False,
checkpoint_every_n: int = 1,
checkpoint_blk_ids: List[int] = None,
):
"""
Inputs
- tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
- attn: mask of shape L x L
Output
- x: data of shape N x L x D (or L x N x D depending on the attention implementation)
"""
if self.pre_transformer_layer:
tokens = self.pre_transformer_layer(tokens)
if use_checkpoint and checkpoint_blk_ids is None:
checkpoint_blk_ids = [
blk_id
for blk_id in range(len(self.blocks))
if blk_id % checkpoint_every_n == 0
]
if checkpoint_blk_ids:
checkpoint_blk_ids = set(checkpoint_blk_ids)
for blk_id, blk in enumerate(self.blocks):
if use_checkpoint and blk_id in checkpoint_blk_ids:
tokens = checkpoint.checkpoint(
blk, tokens, attn_mask, use_reentrant=False
)
else:
tokens = blk(tokens, attn_mask=attn_mask)
if self.post_transformer_layer:
tokens = self.post_transformer_layer(tokens)
return tokens

View File

View File

@ -0,0 +1,200 @@
import random
from typing import Dict, Tuple
import torch
from torch import Tensor
from transformers import LlamaTokenizer
from imagebind.models.image_bind import imagebind_huge, ImageBindJoiner, ModalityType
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import BaseModel
from minigpt4.models.modeling_llama import LlamaForCausalLM
@registry.register_model("bind_gpt4")
class BindGPT4(BaseModel):
"""
ImageBind GPT-LLAMA model.
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/models/bindgpt4.yaml",
}
def __init__(
self,
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
freeze_imagebind=True,
freeze_qformer=False,
num_query_token=32,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0 # the device of 8bit model should be set when loading and cannot be changed anymore.
):
super().__init__()
assert not low_resource, "Low Resource Mode is Currently Unavailable."
self.low_resource = low_resource
print('Loading ImageBind')
self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind)
print('Loading ImageBind Done')
print('Loading Q-Former and Adapter/Projector')
self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
vision_qformer_frozen=freeze_qformer
# vision_qformer_model=q_former_model,
# vision_pre_dims=(1280, 1408)
)
print('Loading Q-Former and Adapter/Projector Done')
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
)
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
print('Loading LLAMA Done')
self.max_txt_len = max_txt_len
self.end_sym = end_sym
print("Preparing Prompts")
if prompt_path:
with open(prompt_path, 'r') as f:
raw_prompts = f.read().splitlines()
self.prompt_list = [prompt_template.format(p) for p in raw_prompts]
print('Load {} training prompts'.format(len(self.prompt_list)))
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
else:
self.prompt_list = []
print("Preparing Prompts Done")
def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
imagebind_outputs = self.multimodal_encoder(inputs)
llama_inputs = self.multimodal_joiner(imagebind_outputs)
return llama_inputs
def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
# TODO: Accept More Modalities.
input_embeds = inputs[modality_name]
attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
if prompt:
batch_size = input_embeds.shape[0]
p_before, p_after = prompt.split('<{}Here>'.format(modality_name))
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1)
wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
return wrapped_input_embeds, wrapped_atts_input
else:
return input_embeds, attns_input
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
TODO: More Modalities.
Only accept image inputs here.
Other modalities will conflict with the pre-defined prompt and wrapping strategy.
"""
embeds = self.encode_inputs(inputs)
assert "Vision" in embeds
prompt = random.choice(self.prompt_list)
img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
# NOTE: No modifications from the next line to the end. Except for the autocast part.
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in inputs["text_input"]]
to_regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(img_embeds.device)
targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
)
empty_targets = (
torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos
)
targets = torch.cat([empty_targets, targets], dim=1)
batch_size = img_embeds.shape[0]
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
bos_embeds = self.llama_model.model.embed_tokens(bos)
atts_bos = atts_img[:, :1]
to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
)
loss = outputs.loss
return {"loss": loss}
@classmethod
def from_config(cls, cfg):
q_former_model = cfg.get("q_former_model",
"https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
freeze_imagebind = cfg.get("freeze_imagebind", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
model = cls(
q_former_model=q_former_model,
freeze_imagebind=freeze_imagebind,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load ImageBind-LLM Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model

View File

@ -1,4 +1,4 @@
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?
<Vision><VisionHere></Vision> Describe this image in detail.
<Vision><VisionHere></Vision> Take a look at this image and describe what you notice.
<Vision><VisionHere></Vision> Please provide a detailed description of the picture.
<Vision><VisionHere></Vision> Could you describe the contents of this image for me?