diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..e7e9d11 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml diff --git a/.idea/MiniGPT-4.iml b/.idea/MiniGPT-4.iml new file mode 100644 index 0000000..4f2c9af --- /dev/null +++ b/.idea/MiniGPT-4.iml @@ -0,0 +1,15 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/dictionaries/Desmond.xml b/.idea/dictionaries/Desmond.xml new file mode 100644 index 0000000..71faffe --- /dev/null +++ b/.idea/dictionaries/Desmond.xml @@ -0,0 +1,7 @@ + + + + imagebind + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..8656114 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..8ab91a0 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/MiniGPT_4.pdf b/MiniGPT_4.pdf deleted file mode 100644 index 5450815..0000000 Binary files a/MiniGPT_4.pdf and /dev/null differ diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml new file mode 100644 index 0000000..e69de29 diff --git a/examples/ad_1.png b/examples/ad_1.png deleted file mode 100644 index d0378e4..0000000 Binary files a/examples/ad_1.png and /dev/null differ diff --git a/examples/ad_2.png b/examples/ad_2.png deleted file mode 100644 index 674248b..0000000 Binary files a/examples/ad_2.png and /dev/null differ diff --git a/examples/cook_1.png b/examples/cook_1.png deleted file mode 100644 index d8cdb45..0000000 Binary files a/examples/cook_1.png and /dev/null differ diff --git a/examples/cook_2.png b/examples/cook_2.png deleted file mode 100644 index d08272b..0000000 Binary files a/examples/cook_2.png and /dev/null differ diff --git a/examples/describe_1.png b/examples/describe_1.png deleted file mode 100644 index 02f3c92..0000000 Binary files a/examples/describe_1.png and /dev/null differ diff --git a/examples/describe_2.png b/examples/describe_2.png deleted file mode 100644 index 20bf8c7..0000000 Binary files a/examples/describe_2.png and /dev/null differ diff --git a/examples/fact_1.png b/examples/fact_1.png deleted file mode 100644 index 1f75228..0000000 Binary files a/examples/fact_1.png and /dev/null differ diff --git a/examples/fact_2.png b/examples/fact_2.png deleted file mode 100644 index de6ef53..0000000 Binary files a/examples/fact_2.png and /dev/null differ diff --git a/examples/fix_1.png b/examples/fix_1.png deleted file mode 100644 index 023cfe6..0000000 Binary files a/examples/fix_1.png and /dev/null differ diff --git a/examples/fix_2.png b/examples/fix_2.png deleted file mode 100644 index f60da5f..0000000 Binary files a/examples/fix_2.png and /dev/null differ diff --git a/examples/fun_1.png b/examples/fun_1.png deleted file mode 100644 index f720ea6..0000000 Binary files a/examples/fun_1.png and /dev/null differ diff --git a/examples/fun_2.png b/examples/fun_2.png deleted file mode 100644 index 1d37a80..0000000 Binary files a/examples/fun_2.png and /dev/null differ diff --git a/examples/logo_1.png b/examples/logo_1.png deleted file mode 100644 index 8bbe438..0000000 Binary files a/examples/logo_1.png and /dev/null differ diff --git a/examples/op_1.png b/examples/op_1.png deleted file mode 100644 index 3dbb2ff..0000000 Binary files a/examples/op_1.png and /dev/null differ diff --git a/examples/op_2.png b/examples/op_2.png deleted file mode 100644 index 2cd3e1f..0000000 Binary files a/examples/op_2.png and /dev/null differ diff --git a/examples/people_1.png b/examples/people_1.png deleted file mode 100644 index 7e95c42..0000000 Binary files a/examples/people_1.png and /dev/null differ diff --git a/examples/people_2.png b/examples/people_2.png deleted file mode 100644 index aec6c83..0000000 Binary files a/examples/people_2.png and /dev/null differ diff --git a/examples/rhyme_1.png b/examples/rhyme_1.png deleted file mode 100644 index 7d13387..0000000 Binary files a/examples/rhyme_1.png and /dev/null differ diff --git a/examples/rhyme_2.png b/examples/rhyme_2.png deleted file mode 100644 index 6cf9bf8..0000000 Binary files a/examples/rhyme_2.png and /dev/null differ diff --git a/examples/story_1.png b/examples/story_1.png deleted file mode 100644 index 3eb6ccb..0000000 Binary files a/examples/story_1.png and /dev/null differ diff --git a/examples/story_2.png b/examples/story_2.png deleted file mode 100644 index 9d37142..0000000 Binary files a/examples/story_2.png and /dev/null differ diff --git a/examples/web_1.png b/examples/web_1.png deleted file mode 100644 index 8943842..0000000 Binary files a/examples/web_1.png and /dev/null differ diff --git a/examples/wop_1.png b/examples/wop_1.png deleted file mode 100644 index 88f37d6..0000000 Binary files a/examples/wop_1.png and /dev/null differ diff --git a/examples/wop_2.png b/examples/wop_2.png deleted file mode 100644 index 8255974..0000000 Binary files a/examples/wop_2.png and /dev/null differ diff --git a/figs/examples/ad_1.png b/figs/examples/ad_1.png deleted file mode 100644 index d0378e4..0000000 Binary files a/figs/examples/ad_1.png and /dev/null differ diff --git a/figs/examples/ad_2.png b/figs/examples/ad_2.png deleted file mode 100644 index 674248b..0000000 Binary files a/figs/examples/ad_2.png and /dev/null differ diff --git a/figs/examples/cook_1.png b/figs/examples/cook_1.png deleted file mode 100644 index d8cdb45..0000000 Binary files a/figs/examples/cook_1.png and /dev/null differ diff --git a/figs/examples/cook_2.png b/figs/examples/cook_2.png deleted file mode 100644 index d08272b..0000000 Binary files a/figs/examples/cook_2.png and /dev/null differ diff --git a/figs/examples/describe_1.png b/figs/examples/describe_1.png deleted file mode 100644 index 02f3c92..0000000 Binary files a/figs/examples/describe_1.png and /dev/null differ diff --git a/figs/examples/describe_2.png b/figs/examples/describe_2.png deleted file mode 100644 index 20bf8c7..0000000 Binary files a/figs/examples/describe_2.png and /dev/null differ diff --git a/figs/examples/fact_1.png b/figs/examples/fact_1.png deleted file mode 100644 index 1f75228..0000000 Binary files a/figs/examples/fact_1.png and /dev/null differ diff --git a/figs/examples/fact_2.png b/figs/examples/fact_2.png deleted file mode 100644 index de6ef53..0000000 Binary files a/figs/examples/fact_2.png and /dev/null differ diff --git a/figs/examples/fix_1.png b/figs/examples/fix_1.png deleted file mode 100644 index 023cfe6..0000000 Binary files a/figs/examples/fix_1.png and /dev/null differ diff --git a/figs/examples/fix_2.png b/figs/examples/fix_2.png deleted file mode 100644 index f60da5f..0000000 Binary files a/figs/examples/fix_2.png and /dev/null differ diff --git a/figs/examples/fun_1.png b/figs/examples/fun_1.png deleted file mode 100644 index f720ea6..0000000 Binary files a/figs/examples/fun_1.png and /dev/null differ diff --git a/figs/examples/fun_2.png b/figs/examples/fun_2.png deleted file mode 100644 index 1d37a80..0000000 Binary files a/figs/examples/fun_2.png and /dev/null differ diff --git a/figs/examples/logo_1.png b/figs/examples/logo_1.png deleted file mode 100644 index 8bbe438..0000000 Binary files a/figs/examples/logo_1.png and /dev/null differ diff --git a/figs/examples/op_1.png b/figs/examples/op_1.png deleted file mode 100644 index 3dbb2ff..0000000 Binary files a/figs/examples/op_1.png and /dev/null differ diff --git a/figs/examples/op_2.png b/figs/examples/op_2.png deleted file mode 100644 index 2cd3e1f..0000000 Binary files a/figs/examples/op_2.png and /dev/null differ diff --git a/figs/examples/people_1.png b/figs/examples/people_1.png deleted file mode 100644 index 7e95c42..0000000 Binary files a/figs/examples/people_1.png and /dev/null differ diff --git a/figs/examples/people_2.png b/figs/examples/people_2.png deleted file mode 100644 index aec6c83..0000000 Binary files a/figs/examples/people_2.png and /dev/null differ diff --git a/figs/examples/rhyme_1.png b/figs/examples/rhyme_1.png deleted file mode 100644 index 7d13387..0000000 Binary files a/figs/examples/rhyme_1.png and /dev/null differ diff --git a/figs/examples/rhyme_2.png b/figs/examples/rhyme_2.png deleted file mode 100644 index 6cf9bf8..0000000 Binary files a/figs/examples/rhyme_2.png and /dev/null differ diff --git a/figs/examples/story_1.png b/figs/examples/story_1.png deleted file mode 100644 index 3eb6ccb..0000000 Binary files a/figs/examples/story_1.png and /dev/null differ diff --git a/figs/examples/story_2.png b/figs/examples/story_2.png deleted file mode 100644 index 9d37142..0000000 Binary files a/figs/examples/story_2.png and /dev/null differ diff --git a/figs/examples/web_1.png b/figs/examples/web_1.png deleted file mode 100644 index 8943842..0000000 Binary files a/figs/examples/web_1.png and /dev/null differ diff --git a/figs/examples/wop_1.png b/figs/examples/wop_1.png deleted file mode 100644 index 88f37d6..0000000 Binary files a/figs/examples/wop_1.png and /dev/null differ diff --git a/figs/examples/wop_2.png b/figs/examples/wop_2.png deleted file mode 100644 index 8255974..0000000 Binary files a/figs/examples/wop_2.png and /dev/null differ diff --git a/figs/online_demo.png b/figs/online_demo.png deleted file mode 100644 index 716e438..0000000 Binary files a/figs/online_demo.png and /dev/null differ diff --git a/figs/overview.png b/figs/overview.png deleted file mode 100644 index 10b952e..0000000 Binary files a/figs/overview.png and /dev/null differ diff --git a/imagebind/__init__.py b/imagebind/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imagebind/data/__init__.py b/imagebind/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imagebind/data/data_utils.py b/imagebind/data/data_utils.py new file mode 100644 index 0000000..49ec571 --- /dev/null +++ b/imagebind/data/data_utils.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torchaudio +import logging + +from imagebind.models.multimodal_preprocessors import SimpleTokenizer +from PIL import Image +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler +from pytorchvideo.data.encoded_video import EncodedVideo + +from torchvision import transforms +from torchvision.transforms._transforms_video import NormalizeVideo + +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds + +BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz" + + +def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): + # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 + waveform -= waveform.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform, + htk_compat=True, + sample_frequency=sample_rate, + use_energy=False, + window_type="hanning", + num_mel_bins=num_mel_bins, + dither=0.0, + frame_length=25, + frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, + ) + # Convert to [mel_bins, num_frames] shape + fbank = fbank.transpose(0, 1) + # Pad to target_length + n_frames = fbank.size(1) + p = target_length - n_frames + # if p is too large (say >20%), flash a warning + if abs(p) / n_frames > 0.2: + logging.warning( + "Large gap between audio n_frames(%d) and " + "target_length (%d). Is the audio_target_length " + "setting correct?", + n_frames, + target_length, + ) + # cut and pad + if p > 0: + fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) + elif p < 0: + fbank = fbank[:, 0:target_length] + # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 + # channel image + fbank = fbank.unsqueeze(0) + return fbank + + +def get_clip_timepoints(clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + +def load_and_transform_vision_data(image_paths, device): + if image_paths is None: + return None + + image_ouputs = [] + for image_path in image_paths: + data_transform = transforms.Compose( + [ + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + with open(image_path, "rb") as fopen: + image = Image.open(fopen).convert("RGB") + + image = data_transform(image).to(device) + image_ouputs.append(image) + return torch.stack(image_ouputs, dim=0) + + +def load_and_transform_text(text, device): + if text is None: + return None + tokenizer = SimpleTokenizer(bpe_path=BPE_PATH) + tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text] + tokens = torch.cat(tokens, dim=0) + return tokens + + +def load_and_transform_audio_data( + audio_paths, + device, + num_mel_bins=128, + target_length=204, + sample_rate=16000, + clip_duration=2, + clips_per_video=3, + mean=-4.268, + std=9.138, +): + if audio_paths is None: + return None + + audio_outputs = [] + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + + for audio_path in audio_paths: + waveform, sr = torchaudio.load(audio_path) + if sample_rate != sr: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sr, new_freq=sample_rate + ) + all_clips_timepoints = get_clip_timepoints( + clip_sampler, waveform.size(1) / sample_rate + ) + all_clips = [] + for clip_timepoints in all_clips_timepoints: + waveform_clip = waveform[ + :, + int(clip_timepoints[0] * sample_rate) : int( + clip_timepoints[1] * sample_rate + ), + ] + waveform_melspec = waveform2melspec( + waveform_clip, sample_rate, num_mel_bins, target_length + ) + all_clips.append(waveform_melspec) + + normalize = transforms.Normalize(mean=mean, std=std) + all_clips = [normalize(ac).to(device) for ac in all_clips] + + all_clips = torch.stack(all_clips, dim=0) + audio_outputs.append(all_clips) + + return torch.stack(audio_outputs, dim=0) + + +def get_clip_timepoints(clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Perform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to perform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +class SpatialCrop(nn.Module): + """ + Convert the video into 3 smaller clips spatially. Must be used after the + temporal crops to get spatial crops, and should be used with + -2 in the spatial crop at the slowfast augmentation stage (so full + frames are passed in here). Will return a larger list with the + 3x spatial crops as well. + """ + + def __init__(self, crop_size: int = 224, num_crops: int = 3): + super().__init__() + self.crop_size = crop_size + if num_crops == 3: + self.crops_to_ext = [0, 1, 2] + self.flipped_crops_to_ext = [] + elif num_crops == 1: + self.crops_to_ext = [1] + self.flipped_crops_to_ext = [] + else: + raise NotImplementedError("Nothing else supported yet") + + def forward(self, videos): + """ + Args: + videos: A list of C, T, H, W videos. + Returns: + videos: A list with 3x the number of elements. Each video converted + to C, T, H', W' by spatial cropping. + """ + assert isinstance(videos, list), "Must be a list of videos after temporal crops" + assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" + res = [] + for video in videos: + for spatial_idx in self.crops_to_ext: + res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) + if not self.flipped_crops_to_ext: + continue + flipped_video = transforms.functional.hflip(video) + for spatial_idx in self.flipped_crops_to_ext: + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) + return res + + +def load_and_transform_video_data( + video_paths, + device, + clip_duration=2, + clips_per_video=5, + sample_rate=16000, +): + if video_paths is None: + return None + + video_outputs = [] + video_transform = transforms.Compose( + [ + pv_transforms.ShortSideScale(224), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) + + for video_path in video_paths: + video = EncodedVideo.from_path( + video_path, + decoder="decord", + decode_audio=False, + **{"sample_rate": sample_rate}, + ) + + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + + all_video = [video_transform(clip) for clip in all_video] + all_video = SpatialCrop(224, num_crops=3)(all_video) + + all_video = torch.stack(all_video, dim=0) + video_outputs.append(all_video) + + return torch.stack(video_outputs, dim=0).to(device) \ No newline at end of file diff --git a/imagebind/models/__init__.py b/imagebind/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imagebind/models/helper.py b/imagebind/models/helper.py new file mode 100644 index 0000000..514ea46 --- /dev/null +++ b/imagebind/models/helper.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import einops +import numpy as np +import torch + +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=self.dim, p=2) + + +class LearnableLogitScaling(nn.Module): + def __init__( + self, + logit_scale_init: float = 1 / 0.07, + learnable: bool = True, + max_logit_scale: float = 100, + ) -> None: + super().__init__() + self.max_logit_scale = max_logit_scale + self.logit_scale_init = logit_scale_init + self.learnable = learnable + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, x): + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x + + def extra_repr(self): + st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" + return st + + +class EinOpsRearrange(nn.Module): + def __init__(self, rearrange_expr: str, **kwargs) -> None: + super().__init__() + self.rearrange_expr = rearrange_expr + self.kwargs = kwargs + + def forward(self, x): + assert isinstance(x, torch.Tensor) + return einops.rearrange(x, self.rearrange_expr, **self.kwargs) + + +class VerboseNNModule(nn.Module): + """ + Wrapper around nn.Module that prints registered buffers and parameter names. + """ + + @staticmethod + def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: + st = ( + "(" + + name + + "): " + + "tensor(" + + str(tuple(tensor[1].shape)) + + ", requires_grad=" + + str(tensor[1].requires_grad) + + ")\n" + ) + return st + + def extra_repr(self) -> str: + named_modules = set() + for p in self.named_modules(): + named_modules.update([p[0]]) + named_modules = list(named_modules) + + string_repr = "" + for p in self.named_parameters(): + name = p[0].split(".")[0] + if name not in named_modules: + string_repr += self.get_readable_tensor_repr(name, p) + + for p in self.named_buffers(): + name = p[0].split(".")[0] + string_repr += self.get_readable_tensor_repr(name, p) + + return string_repr + + +def cast_if_src_dtype( + tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype +): + updated = False + if tensor.dtype == src_dtype: + tensor = tensor.to(dtype=tgt_dtype) + updated = True + return tensor, updated + + +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class SelectElement(nn.Module): + def __init__(self, index) -> None: + super().__init__() + self.index = index + + def forward(self, x): + assert x.ndim >= 3 + return x[:, self.index, ...] + + +class SelectEOSAndProject(nn.Module): + """ + Text Pooling used in OpenCLIP + """ + + def __init__(self, proj: nn.Module) -> None: + super().__init__() + self.proj = proj + + def forward(self, x, seq_len): + assert x.ndim == 3 + # x is of shape B x L x D + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), seq_len] + x = self.proj(x) + return x \ No newline at end of file diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py new file mode 100644 index 0000000..83cd641 --- /dev/null +++ b/imagebind/models/image_bind.py @@ -0,0 +1,587 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from functools import partial +from types import SimpleNamespace +from typing import Union, Optional, Tuple, Dict, List + +import torch +import torch.nn as nn +from torch import Tensor + +from imagebind.models.helper import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train +from imagebind.models.multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, +) +from imagebind.models.multimodal_projectors import create_projectors, create_pre_projector + +from imagebind.models.transformer import MultiheadAttention, SimpleTransformer + +ModalityType = SimpleNamespace( + VISION="Vision", + TEXT="Text", + AUDIO="Audio", + THERMAL="Thermal", + DEPTH="Depth", + IMU="Imu", +) + + +class ImageBindJoiner(nn.Module): + def __init__(self, + vision_query_token_num: int, + vision_qformer_frozen: bool = False, + vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model + vision_pre_dims: List[int] = (), # Projection before Q-Former + vision_post_dims: List[int] = (768, 768) # Projection after Q-Former + ): + super().__init__() + assert not (vision_qformer_frozen and vision_qformer_model == "") + self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims) + self.modality_qformers = self._create_modality_qformers(vision_query_token_num, + vision_qformer_frozen, + vision_qformer_model) + self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims) + + def _create_modality_pre_projectors(self, + vision_pre_dims + ): + modality_pre_projectors = { + ModalityType.VISION: create_projectors(vision_pre_dims) + } + return modality_pre_projectors + + def _create_modality_qformers(self, + vision_query_token_num, + vision_qformer_frozen, + vision_qformer_model + ): + vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num, + freeze_qformer=vision_qformer_frozen, + q_former_model=vision_qformer_model) + modality_qformers = { + ModalityType.VISION: vision_qformer + } + + return nn.ModuleDict(modality_qformers) + + def _create_modality_post_projectors(self, vision_post_dims): + vision_projector = create_projectors(vision_post_dims) + modality_projectors = { + ModalityType.VISION: vision_projector + } + + return nn.ModuleDict(modality_projectors) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + outputs = {} + for modality_key, modality_value in inputs.items(): + assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported." + if modality_value is not None: + modality_value = self.modality_pre_projectors[modality_key](modality_value) + modality_value = self.modality_qformers[modality_key](modality_value) + modality_value = self.modality_post_projectors[modality_key](modality_value) + outputs[modality_key] = modality_value + return outputs + + +class ImageBindModel(nn.Module): + def __init__( + self, + video_frames=2, + kernel_size=(2, 14, 14), + audio_kernel_size=16, + audio_stride=10, + out_embed_dim=768, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_num_mel_bins=128, + audio_target_len=204, + audio_drop_path=0.1, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + depth_embed_dim=384, + depth_kernel_size=16, + depth_num_blocks=12, + depth_num_heads=8, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_kernel_size=16, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_kernel_size=8, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + super().__init__() + + self.modality_preprocessors = self._create_modality_preprocessors( + video_frames, + vision_embed_dim, + kernel_size, + text_embed_dim, + audio_embed_dim, + audio_kernel_size, + audio_stride, + audio_num_mel_bins, + audio_target_len, + depth_embed_dim, + depth_kernel_size, + thermal_embed_dim, + thermal_kernel_size, + imu_embed_dim, + ) + + self.modality_trunks = self._create_modality_trunks( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + text_embed_dim, + text_num_blocks, + text_num_heads, + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + audio_drop_path, + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + depth_drop_path, + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + thermal_drop_path, + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + imu_drop_path, + ) + + self.modality_heads = self._create_modality_heads( + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ) + + self.modality_postprocessors = self._create_modality_postprocessors( + out_embed_dim + ) + + def _create_modality_preprocessors( + self, + video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = { + ModalityType.VISION: rgbt_preprocessor, + ModalityType.TEXT: text_preprocessor, + ModalityType.AUDIO: audio_preprocessor, + ModalityType.DEPTH: depth_preprocessor, + ModalityType.THERMAL: thermal_preprocessor, + ModalityType.IMU: imu_preprocessor, + } + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity(), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ): + modality_heads = {} + + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + return nn.ModuleDict(modality_postprocessors) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + + # NOTE: No heads are needed any more. + # head_inputs = modality_value["head"] + # modality_value = self.modality_heads[modality_key]( + # modality_value, **head_inputs + # ) + + modality_value = self.modality_postprocessors[modality_key]( + modality_value + ) + + # NOTE: The reduction operation has been modified. + if reduce_list: + modality_value = modality_value.reshape(B, S, *modality_value[2:]) + modality_value = modality_value.mean(dim=1) + + outputs[modality_key] = modality_value + + return outputs + + +def imagebind_huge(pretrained=False, freeze_imagebind=False): + model = ImageBindModel( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + + if pretrained: + if not os.path.exists(".checkpoints/imagebind_huge.pth"): + print( + "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..." + ) + os.makedirs(".checkpoints", exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + ".checkpoints/imagebind_huge.pth", + progress=True, + ) + + model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth")) + + if freeze_imagebind: + for name, param in model.named_parameters(): + param.requires_grad = False + model = model.eval() + model.train = disabled_train + + return model diff --git a/imagebind/models/multimodal_formers.py b/imagebind/models/multimodal_formers.py new file mode 100644 index 0000000..0567fd6 --- /dev/null +++ b/imagebind/models/multimodal_formers.py @@ -0,0 +1,103 @@ +import logging +import os + +import torch +from torch import nn, Tensor + +from minigpt4.common.dist_utils import download_cached_file +from minigpt4.common.utils import is_url +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class BaseQFormer(nn.Module): + def __init__(self, freeze_qformer=False): + super().__init__() + self.freeze_qformer = freeze_qformer + self.Qformer = None + + def check_and_freeze(self): + assert self.Qformer is not None + if self.freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + self.query_tokens.requires_grad = False + logging.info("Freeze This QFormer") + + @classmethod + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +class SequenceGenericQFormer(BaseQFormer): + def __init__(self, + num_query_token: int, + encoder_width: int = 768, + freeze_qformer: bool = False, + q_former_model: str = "", + cross_attention_freq: int = 2 + ): + super().__init__(freeze_qformer) + self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq) + if q_former_model != "": + self.load_Qformer(q_former_model) + self.check_and_freeze() + + def load_Qformer(self, q_former_model): + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + + @classmethod + def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = encoder_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + def forward(self, input_embeds: Tensor) -> Tensor: + input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) + query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=input_embeds, + encoder_attention_mask=input_atts, + return_dict=True, + ) + return query_output.last_hidden_state diff --git a/imagebind/models/multimodal_preprocessors.py b/imagebind/models/multimodal_preprocessors.py new file mode 100644 index 0000000..24f5c05 --- /dev/null +++ b/imagebind/models/multimodal_preprocessors.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional + +import ftfy + +import numpy as np +import regex as re +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ +from imagebind.models.helper import VerboseNNModule, cast_if_src_dtype + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + return pos_embed + + assert ( + patches_layout[-1] == patches_layout[-2] + ), "Interpolation of pos embed not supported for non-square layouts" + + class_emb = pos_embed[:, :first_patch_idx] + pos_embed = pos_embed[:, first_patch_idx:] + + if input_shape is None or patches_layout[0] == 1: + # simple 2D pos embedding, no temporal component + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) + elif patches_layout[0] > 1: + # pos embed has a temporal component + assert len(input_shape) == 4, "temporal interpolation not supported" + # we only support 2D interpolation in this case + num_frames = patches_layout[0] + num_spatial_tokens = patches_layout[1] * patches_layout[2] + pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) + # interpolate embedding for zeroth frame + pos_embed = interpolate_pos_encoding_2d( + npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) + ) + else: + raise ValueError("This type of interpolation isn't implemented") + + return torch.cat((class_emb, pos_embed), dim=1) + + +def _get_pos_embedding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape, + first_patch_idx=1, +): + pos_embed = interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=input_shape, + first_patch_idx=first_patch_idx, + ) + return pos_embed + + +class PatchEmbedGeneric(nn.Module): + """ + PatchEmbed from Hydra + """ + + def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): + super().__init__() + + if len(proj_stem) > 1: + self.proj = nn.Sequential(*proj_stem) + else: + # Special case to be able to load pre-trained models that were + # trained with a standard stem + self.proj = proj_stem[0] + self.norm_layer = norm_layer + + def get_patch_layout(self, img_size): + with torch.no_grad(): + dummy_img = torch.zeros( + [ + 1, + ] + + img_size + ) + dummy_out = self.proj(dummy_img) + embed_dim = dummy_out.shape[1] + patches_layout = tuple(dummy_out.shape[2:]) + num_patches = np.prod(patches_layout) + return patches_layout, num_patches, embed_dim + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + x = x.flatten(2).transpose(1, 2) + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): + def __init__( + self, + patches_layout: List, + num_patches: int, + num_cls_tokens: int, + embed_dim: int, + learnable: bool, + ) -> None: + super().__init__() + self.num_cls_tokens = num_cls_tokens + self.patches_layout = patches_layout + self.num_patches = num_patches + self.num_tokens = num_cls_tokens + num_patches + self.learnable = learnable + if self.learnable: + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + else: + self.register_buffer( + "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) + ) + + def get_pos_embedding(self, vision_input, all_vision_tokens): + input_shape = vision_input.shape + pos_embed = _get_pos_embedding( + all_vision_tokens.size(1) - self.num_cls_tokens, + pos_embed=self.pos_embed, + patches_layout=self.patches_layout, + input_shape=input_shape, + first_patch_idx=self.num_cls_tokens, + ) + return pos_embed + + +class RGBDTPreprocessor(VerboseNNModule): + def __init__( + self, + rgbt_stem: PatchEmbedGeneric, + depth_stem: PatchEmbedGeneric, + img_size: List = (3, 224, 224), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + use_type_embed: bool = False, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = rgbt_stem if rgbt_stem is not None else depth_stem + ( + self.patches_layout, + self.num_patches, + self.embed_dim, + ) = stem.get_patch_layout(img_size) + self.rgbt_stem = rgbt_stem + self.depth_stem = depth_stem + self.use_pos_embed = pos_embed_fn is not None + self.use_type_embed = use_type_embed + self.num_cls_tokens = num_cls_tokens + + if self.use_pos_embed: + self.pos_embedding_helper = pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if vision is not None and depth is not None: + final_tokens = vision_tokens + depth_tokens + else: + final_tokens = vision_tokens if vision is not None else depth_tokens + return_dict = { + "trunk": { + "tokens": final_tokens, + }, + "head": {}, + } + return return_dict + + +class AudioPreprocessor(RGBDTPreprocessor): + def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) + + def forward(self, audio=None): + return super().forward(vision=audio) + + +class ThermalPreprocessor(RGBDTPreprocessor): + def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) + + def forward(self, thermal=None): + return super().forward(vision=thermal) + + +def build_causal_attention_mask(context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length, requires_grad=False) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + +class TextPreprocessor(VerboseNNModule): + def __init__( + self, + vocab_size: int, + context_length: int, + embed_dim: int, + causal_masking: bool, + supply_seq_len_to_head: bool = True, + num_cls_tokens: int = 0, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.pos_embed = nn.Parameter( + torch.empty(1, self.context_length + num_cls_tokens, embed_dim) + ) + self.causal_masking = causal_masking + if self.causal_masking: + mask = build_causal_attention_mask(self.context_length) + # register the mask as a buffer so it can be moved to the right device + self.register_buffer("mask", mask) + + self.supply_seq_len_to_head = supply_seq_len_to_head + self.num_cls_tokens = num_cls_tokens + self.embed_dim = embed_dim + if num_cls_tokens > 0: + assert self.causal_masking is False, "Masking + CLS token isn't implemented" + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style="openclip"): + # OpenCLIP style initialization + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def forward(self, text): + # text tokens are of shape B x L x D + text_tokens = self.token_embedding(text) + # concat CLS tokens if any + if self.num_cls_tokens > 0: + B = text_tokens.shape[0] + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + text_tokens = torch.cat((class_tokens, text_tokens), dim=1) + text_tokens = text_tokens + self.pos_embed + return_dict = { + "trunk": { + "tokens": text_tokens, + }, + "head": {}, + } + # Compute sequence length after adding CLS tokens + if self.supply_seq_len_to_head: + text_lengths = text.argmax(dim=-1) + return_dict["head"] = { + "seq_len": text_lengths, + } + if self.causal_masking: + return_dict["trunk"].update({"attn_mask": self.mask}) + return return_dict + + +class Im2Video(nn.Module): + """Convert an image into a trivial video.""" + + def __init__(self, time_dim=2): + super().__init__() + self.time_dim = time_dim + + def forward(self, x): + if x.ndim == 4: + # B, C, H, W -> B, C, T, H, W + return x.unsqueeze(self.time_dim) + elif x.ndim == 5: + return x + else: + raise ValueError(f"Dimension incorrect {x.shape}") + + +class PadIm2Video(Im2Video): + def __init__(self, ntimes, pad_type, time_dim=2): + super().__init__(time_dim=time_dim) + assert ntimes > 0 + assert pad_type in ["zero", "repeat"] + self.ntimes = ntimes + self.pad_type = pad_type + + def forward(self, x): + x = super().forward(x) + if x.shape[self.time_dim] == 1: + if self.pad_type == "repeat": + new_shape = [1] * len(x.shape) + new_shape[self.time_dim] = self.ntimes + x = x.repeat(new_shape) + elif self.pad_type == "zero": + padarg = [0, 0] * len(x.shape) + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] + x = nn.functional.pad(x, padarg) + return x + + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class IMUPreprocessor(VerboseNNModule): + def __init__( + self, + kernel_size: int, + imu_stem: PatchEmbedGeneric, + embed_dim: int, + img_size: List = (6, 2000), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = imu_stem + self.imu_stem = imu_stem + self.embed_dim = embed_dim + self.use_pos_embed = pos_embed_fn is not None + self.num_cls_tokens = num_cls_tokens + self.kernel_size = kernel_size + self.pos_embed = nn.Parameter( + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict \ No newline at end of file diff --git a/imagebind/models/multimodal_projectors.py b/imagebind/models/multimodal_projectors.py new file mode 100644 index 0000000..420a00c --- /dev/null +++ b/imagebind/models/multimodal_projectors.py @@ -0,0 +1,45 @@ +from torch import nn, Tensor + +from typing import Union, Optional, Tuple + + +class BaseProjector(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + raise NotImplementedError + + +class LinearProjector(BaseProjector): + def __init__(self, in_dim, out_dim): + super().__init__() + self.fc = nn.Linear(in_dim, out_dim) + + def forward(self, x: Tensor) -> Tensor: + return self.fc(x) + + +class AdapterProjector(BaseProjector): + def __init__(self, in_dim, mid_dim, out_dim): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(in_dim, mid_dim, bias=False), + nn.ReLU(inplace=True), + nn.Linear(mid_dim, out_dim, bias=False), + nn.ReLU(inplace=True) + ) + + def forward(self, x: Tensor) -> Tensor: + return self.fc(x) + + +def create_projectors(dims): + if len(dims) == 0: + return nn.Identity() + elif len(dims) == 2: + return LinearProjector(*dims) + elif len(dims) == 3: + return AdapterProjector(*dims) + else: + raise NotImplementedError diff --git a/imagebind/models/transformer.py b/imagebind/models/transformer.py new file mode 100644 index 0000000..4bb3cfd --- /dev/null +++ b/imagebind/models/transformer.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Code modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; +# https://github.com/facebookresearch/deit/blob/main/models.py +# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py + + +import copy +import fnmatch +import logging +from functools import partial +from typing import Callable, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import DropPath, trunc_normal_ + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, + # can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiheadAttention(nn.MultiheadAttention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + +class ViTAttention(Attention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + assert attn_mask is None + return super().forward(x) + + +class BlockWithMasking(nn.Module): + def __init__( + self, + dim: int, + attn_target: Callable, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ffn_dropout_rate: float = 0.0, + drop_path: float = 0.0, + layer_scale_type: str = None, + layer_scale_init_value: float = 1e-4, + ): + super().__init__() + + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" + self.attn = attn_target() + if drop_path > 0.0: + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm_1 = norm_layer(dim) + mlp_hidden_dim = int(mlp_ratio * dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=ffn_dropout_rate, + ) + self.norm_2 = norm_layer(dim) + self.layer_scale_type = layer_scale_type + if self.layer_scale_type is not None: + assert self.layer_scale_type in [ + "per_channel", + "scalar", + ], f"Found Layer scale type {self.layer_scale_type}" + if self.layer_scale_type == "per_channel": + # one gamma value per channel + gamma_shape = [1, 1, dim] + elif self.layer_scale_type == "scalar": + # single gamma value for all channels + gamma_shape = [1, 1, 1] + # two gammas: for each part of the fwd in the encoder + self.layer_scale_gamma1 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + self.layer_scale_gamma2 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + if self.layer_scale_type is None: + x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path(self.mlp(self.norm_2(x))) + else: + x = ( + x + + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + * self.layer_scale_gamma1 + ) + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 + return x + + +_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) + + +class SimpleTransformer(nn.Module): + def __init__( + self, + attn_target: Callable, + embed_dim: int, + num_blocks: int, + block: Callable = BlockWithMasking, + pre_transformer_layer: Callable = None, + post_transformer_layer: Callable = None, + drop_path_rate: float = 0.0, + drop_path_type: str = "progressive", + norm_layer: Callable = _LAYER_NORM, + mlp_ratio: int = 4, + ffn_dropout_rate: float = 0.0, + layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar" + layer_scale_init_value: float = 1e-4, # from cait; float + weight_init_style: str = "jax", # possible values jax or pytorch + ): + """ + Simple Transformer with the following features + 1. Supports masked attention + 2. Supports DropPath + 3. Supports LayerScale + 4. Supports Dropout in Attention and FFN + 5. Makes few assumptions about the input except that it is a Tensor + """ + super().__init__() + self.pre_transformer_layer = pre_transformer_layer + if drop_path_type == "progressive": + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] + elif drop_path_type == "uniform": + dpr = [drop_path_rate for i in range(num_blocks)] + else: + raise ValueError(f"Unknown drop_path_type: {drop_path_type}") + + self.blocks = nn.Sequential( + *[ + block( + dim=embed_dim, + attn_target=attn_target, + mlp_ratio=mlp_ratio, + ffn_dropout_rate=ffn_dropout_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + layer_scale_type=layer_scale_type, + layer_scale_init_value=layer_scale_init_value, + ) + for i in range(num_blocks) + ] + ) + self.post_transformer_layer = post_transformer_layer + self.weight_init_style = weight_init_style + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if self.weight_init_style == "jax": + # Based on MAE and official Jax ViT implementation + torch.nn.init.xavier_uniform_(m.weight) + elif self.weight_init_style == "pytorch": + # PyTorch ViT uses trunc_normal_ + trunc_normal_(m.weight, std=0.02) + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + tokens: torch.Tensor, + attn_mask: torch.Tensor = None, + use_checkpoint: bool = False, + checkpoint_every_n: int = 1, + checkpoint_blk_ids: List[int] = None, + ): + """ + Inputs + - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) + - attn: mask of shape L x L + + Output + - x: data of shape N x L x D (or L x N x D depending on the attention implementation) + """ + if self.pre_transformer_layer: + tokens = self.pre_transformer_layer(tokens) + if use_checkpoint and checkpoint_blk_ids is None: + checkpoint_blk_ids = [ + blk_id + for blk_id in range(len(self.blocks)) + if blk_id % checkpoint_every_n == 0 + ] + if checkpoint_blk_ids: + checkpoint_blk_ids = set(checkpoint_blk_ids) + for blk_id, blk in enumerate(self.blocks): + if use_checkpoint and blk_id in checkpoint_blk_ids: + tokens = checkpoint.checkpoint( + blk, tokens, attn_mask, use_reentrant=False + ) + else: + tokens = blk(tokens, attn_mask=attn_mask) + if self.post_transformer_layer: + tokens = self.post_transformer_layer(tokens) + return tokens \ No newline at end of file diff --git a/minigpt4/configs/models/bindgpt4.yaml b/minigpt4/configs/models/bindgpt4.yaml new file mode 100644 index 0000000..e69de29 diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py new file mode 100644 index 0000000..605d2f2 --- /dev/null +++ b/minigpt4/models/bind_gpt4.py @@ -0,0 +1,200 @@ +import random +from typing import Dict, Tuple + +import torch +from torch import Tensor +from transformers import LlamaTokenizer + +from imagebind.models.image_bind import imagebind_huge, ImageBindJoiner, ModalityType +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import BaseModel +from minigpt4.models.modeling_llama import LlamaForCausalLM + + +@registry.register_model("bind_gpt4") +class BindGPT4(BaseModel): + """ + ImageBind GPT-LLAMA model. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/bindgpt4.yaml", + } + + def __init__( + self, + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + freeze_imagebind=True, + freeze_qformer=False, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0 # the device of 8bit model should be set when loading and cannot be changed anymore. + ): + super().__init__() + assert not low_resource, "Low Resource Mode is Currently Unavailable." + + self.low_resource = low_resource + + print('Loading ImageBind') + self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind) + print('Loading ImageBind Done') + + print('Loading Q-Former and Adapter/Projector') + self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token, + vision_qformer_frozen=freeze_qformer + # vision_qformer_model=q_former_model, + # vision_pre_dims=(1280, 1408) + ) + print('Loading Q-Former and Adapter/Projector Done') + + print('Loading LLAMA') + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float16, + ) + + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + print('Loading LLAMA Done') + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + print("Preparing Prompts") + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + self.prompt_list = [prompt_template.format(p) for p in raw_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + print("Preparing Prompts Done") + + def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + imagebind_outputs = self.multimodal_encoder(inputs) + llama_inputs = self.multimodal_joiner(imagebind_outputs) + return llama_inputs + + def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]: + # TODO: Accept More Modalities. + input_embeds = inputs[modality_name] + attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) + if prompt: + batch_size = input_embeds.shape[0] + p_before, p_after = prompt.split('<{}Here>'.format(modality_name)) + p_before_tokens = self.llama_tokenizer( + p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device) + p_after_tokens = self.llama_tokenizer( + p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device) + p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) + p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) + wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1) + wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1]) + return wrapped_input_embeds, wrapped_atts_input + else: + return input_embeds, attns_input + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + TODO: More Modalities. + Only accept image inputs here. + Other modalities will conflict with the pre-defined prompt and wrapping strategy. + """ + embeds = self.encode_inputs(inputs) + assert "Vision" in embeds + prompt = random.choice(self.prompt_list) + img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt) + + # NOTE: No modifications from the next line to the end. Except for the autocast part. + + self.llama_tokenizer.padding_side = "right" + + text = [t + self.end_sym for t in inputs["text_input"]] + + to_regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(img_embeds.device) + + targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + empty_targets = ( + torch.ones([atts_img.shape[0], atts_img.shape[1] + 1], + dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos + ) + targets = torch.cat([empty_targets, targets], dim=1) + + batch_size = img_embeds.shape[0] + bos = torch.ones([batch_size, 1], + dtype=to_regress_tokens.input_ids.dtype, + device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + atts_bos = atts_img[:, :1] + + to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) + inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) + attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {"loss": loss} + + @classmethod + def from_config(cls, cfg): + q_former_model = cfg.get("q_former_model", + "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + freeze_imagebind = cfg.get("freeze_imagebind", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", '\n') + + model = cls( + q_former_model=q_former_model, + freeze_imagebind=freeze_imagebind, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load ImageBind-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/prompts/alignment.txt b/prompts/alignment.txt index 38ae75a..90ae57b 100644 --- a/prompts/alignment.txt +++ b/prompts/alignment.txt @@ -1,4 +1,4 @@ - Describe this image in detail. - Take a look at this image and describe what you notice. - Please provide a detailed description of the picture. - Could you describe the contents of this image for me? \ No newline at end of file + Describe this image in detail. + Take a look at this image and describe what you notice. + Please provide a detailed description of the picture. + Could you describe the contents of this image for me? \ No newline at end of file