diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..e7e9d11
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
diff --git a/.idea/MiniGPT-4.iml b/.idea/MiniGPT-4.iml
new file mode 100644
index 0000000..4f2c9af
--- /dev/null
+++ b/.idea/MiniGPT-4.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/dictionaries/Desmond.xml b/.idea/dictionaries/Desmond.xml
new file mode 100644
index 0000000..71faffe
--- /dev/null
+++ b/.idea/dictionaries/Desmond.xml
@@ -0,0 +1,7 @@
+<component name="ProjectDictionaryState">
+  <dictionary name="Desmond">
+    <words>
+      <word>imagebind</word>
+    </words>
+  </dictionary>
+</component>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..8656114
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..8ab91a0
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/MiniGPT-4.iml" filepath="$PROJECT_DIR$/.idea/MiniGPT-4.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/MiniGPT_4.pdf b/MiniGPT_4.pdf
deleted file mode 100644
index 5450815..0000000
Binary files a/MiniGPT_4.pdf and /dev/null differ
diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/examples/ad_1.png b/examples/ad_1.png
deleted file mode 100644
index d0378e4..0000000
Binary files a/examples/ad_1.png and /dev/null differ
diff --git a/examples/ad_2.png b/examples/ad_2.png
deleted file mode 100644
index 674248b..0000000
Binary files a/examples/ad_2.png and /dev/null differ
diff --git a/examples/cook_1.png b/examples/cook_1.png
deleted file mode 100644
index d8cdb45..0000000
Binary files a/examples/cook_1.png and /dev/null differ
diff --git a/examples/cook_2.png b/examples/cook_2.png
deleted file mode 100644
index d08272b..0000000
Binary files a/examples/cook_2.png and /dev/null differ
diff --git a/examples/describe_1.png b/examples/describe_1.png
deleted file mode 100644
index 02f3c92..0000000
Binary files a/examples/describe_1.png and /dev/null differ
diff --git a/examples/describe_2.png b/examples/describe_2.png
deleted file mode 100644
index 20bf8c7..0000000
Binary files a/examples/describe_2.png and /dev/null differ
diff --git a/examples/fact_1.png b/examples/fact_1.png
deleted file mode 100644
index 1f75228..0000000
Binary files a/examples/fact_1.png and /dev/null differ
diff --git a/examples/fact_2.png b/examples/fact_2.png
deleted file mode 100644
index de6ef53..0000000
Binary files a/examples/fact_2.png and /dev/null differ
diff --git a/examples/fix_1.png b/examples/fix_1.png
deleted file mode 100644
index 023cfe6..0000000
Binary files a/examples/fix_1.png and /dev/null differ
diff --git a/examples/fix_2.png b/examples/fix_2.png
deleted file mode 100644
index f60da5f..0000000
Binary files a/examples/fix_2.png and /dev/null differ
diff --git a/examples/fun_1.png b/examples/fun_1.png
deleted file mode 100644
index f720ea6..0000000
Binary files a/examples/fun_1.png and /dev/null differ
diff --git a/examples/fun_2.png b/examples/fun_2.png
deleted file mode 100644
index 1d37a80..0000000
Binary files a/examples/fun_2.png and /dev/null differ
diff --git a/examples/logo_1.png b/examples/logo_1.png
deleted file mode 100644
index 8bbe438..0000000
Binary files a/examples/logo_1.png and /dev/null differ
diff --git a/examples/op_1.png b/examples/op_1.png
deleted file mode 100644
index 3dbb2ff..0000000
Binary files a/examples/op_1.png and /dev/null differ
diff --git a/examples/op_2.png b/examples/op_2.png
deleted file mode 100644
index 2cd3e1f..0000000
Binary files a/examples/op_2.png and /dev/null differ
diff --git a/examples/people_1.png b/examples/people_1.png
deleted file mode 100644
index 7e95c42..0000000
Binary files a/examples/people_1.png and /dev/null differ
diff --git a/examples/people_2.png b/examples/people_2.png
deleted file mode 100644
index aec6c83..0000000
Binary files a/examples/people_2.png and /dev/null differ
diff --git a/examples/rhyme_1.png b/examples/rhyme_1.png
deleted file mode 100644
index 7d13387..0000000
Binary files a/examples/rhyme_1.png and /dev/null differ
diff --git a/examples/rhyme_2.png b/examples/rhyme_2.png
deleted file mode 100644
index 6cf9bf8..0000000
Binary files a/examples/rhyme_2.png and /dev/null differ
diff --git a/examples/story_1.png b/examples/story_1.png
deleted file mode 100644
index 3eb6ccb..0000000
Binary files a/examples/story_1.png and /dev/null differ
diff --git a/examples/story_2.png b/examples/story_2.png
deleted file mode 100644
index 9d37142..0000000
Binary files a/examples/story_2.png and /dev/null differ
diff --git a/examples/web_1.png b/examples/web_1.png
deleted file mode 100644
index 8943842..0000000
Binary files a/examples/web_1.png and /dev/null differ
diff --git a/examples/wop_1.png b/examples/wop_1.png
deleted file mode 100644
index 88f37d6..0000000
Binary files a/examples/wop_1.png and /dev/null differ
diff --git a/examples/wop_2.png b/examples/wop_2.png
deleted file mode 100644
index 8255974..0000000
Binary files a/examples/wop_2.png and /dev/null differ
diff --git a/figs/examples/ad_1.png b/figs/examples/ad_1.png
deleted file mode 100644
index d0378e4..0000000
Binary files a/figs/examples/ad_1.png and /dev/null differ
diff --git a/figs/examples/ad_2.png b/figs/examples/ad_2.png
deleted file mode 100644
index 674248b..0000000
Binary files a/figs/examples/ad_2.png and /dev/null differ
diff --git a/figs/examples/cook_1.png b/figs/examples/cook_1.png
deleted file mode 100644
index d8cdb45..0000000
Binary files a/figs/examples/cook_1.png and /dev/null differ
diff --git a/figs/examples/cook_2.png b/figs/examples/cook_2.png
deleted file mode 100644
index d08272b..0000000
Binary files a/figs/examples/cook_2.png and /dev/null differ
diff --git a/figs/examples/describe_1.png b/figs/examples/describe_1.png
deleted file mode 100644
index 02f3c92..0000000
Binary files a/figs/examples/describe_1.png and /dev/null differ
diff --git a/figs/examples/describe_2.png b/figs/examples/describe_2.png
deleted file mode 100644
index 20bf8c7..0000000
Binary files a/figs/examples/describe_2.png and /dev/null differ
diff --git a/figs/examples/fact_1.png b/figs/examples/fact_1.png
deleted file mode 100644
index 1f75228..0000000
Binary files a/figs/examples/fact_1.png and /dev/null differ
diff --git a/figs/examples/fact_2.png b/figs/examples/fact_2.png
deleted file mode 100644
index de6ef53..0000000
Binary files a/figs/examples/fact_2.png and /dev/null differ
diff --git a/figs/examples/fix_1.png b/figs/examples/fix_1.png
deleted file mode 100644
index 023cfe6..0000000
Binary files a/figs/examples/fix_1.png and /dev/null differ
diff --git a/figs/examples/fix_2.png b/figs/examples/fix_2.png
deleted file mode 100644
index f60da5f..0000000
Binary files a/figs/examples/fix_2.png and /dev/null differ
diff --git a/figs/examples/fun_1.png b/figs/examples/fun_1.png
deleted file mode 100644
index f720ea6..0000000
Binary files a/figs/examples/fun_1.png and /dev/null differ
diff --git a/figs/examples/fun_2.png b/figs/examples/fun_2.png
deleted file mode 100644
index 1d37a80..0000000
Binary files a/figs/examples/fun_2.png and /dev/null differ
diff --git a/figs/examples/logo_1.png b/figs/examples/logo_1.png
deleted file mode 100644
index 8bbe438..0000000
Binary files a/figs/examples/logo_1.png and /dev/null differ
diff --git a/figs/examples/op_1.png b/figs/examples/op_1.png
deleted file mode 100644
index 3dbb2ff..0000000
Binary files a/figs/examples/op_1.png and /dev/null differ
diff --git a/figs/examples/op_2.png b/figs/examples/op_2.png
deleted file mode 100644
index 2cd3e1f..0000000
Binary files a/figs/examples/op_2.png and /dev/null differ
diff --git a/figs/examples/people_1.png b/figs/examples/people_1.png
deleted file mode 100644
index 7e95c42..0000000
Binary files a/figs/examples/people_1.png and /dev/null differ
diff --git a/figs/examples/people_2.png b/figs/examples/people_2.png
deleted file mode 100644
index aec6c83..0000000
Binary files a/figs/examples/people_2.png and /dev/null differ
diff --git a/figs/examples/rhyme_1.png b/figs/examples/rhyme_1.png
deleted file mode 100644
index 7d13387..0000000
Binary files a/figs/examples/rhyme_1.png and /dev/null differ
diff --git a/figs/examples/rhyme_2.png b/figs/examples/rhyme_2.png
deleted file mode 100644
index 6cf9bf8..0000000
Binary files a/figs/examples/rhyme_2.png and /dev/null differ
diff --git a/figs/examples/story_1.png b/figs/examples/story_1.png
deleted file mode 100644
index 3eb6ccb..0000000
Binary files a/figs/examples/story_1.png and /dev/null differ
diff --git a/figs/examples/story_2.png b/figs/examples/story_2.png
deleted file mode 100644
index 9d37142..0000000
Binary files a/figs/examples/story_2.png and /dev/null differ
diff --git a/figs/examples/web_1.png b/figs/examples/web_1.png
deleted file mode 100644
index 8943842..0000000
Binary files a/figs/examples/web_1.png and /dev/null differ
diff --git a/figs/examples/wop_1.png b/figs/examples/wop_1.png
deleted file mode 100644
index 88f37d6..0000000
Binary files a/figs/examples/wop_1.png and /dev/null differ
diff --git a/figs/examples/wop_2.png b/figs/examples/wop_2.png
deleted file mode 100644
index 8255974..0000000
Binary files a/figs/examples/wop_2.png and /dev/null differ
diff --git a/figs/online_demo.png b/figs/online_demo.png
deleted file mode 100644
index 716e438..0000000
Binary files a/figs/online_demo.png and /dev/null differ
diff --git a/figs/overview.png b/figs/overview.png
deleted file mode 100644
index 10b952e..0000000
Binary files a/figs/overview.png and /dev/null differ
diff --git a/imagebind/__init__.py b/imagebind/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/imagebind/data/__init__.py b/imagebind/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/imagebind/data/data_utils.py b/imagebind/data/data_utils.py
new file mode 100644
index 0000000..49ec571
--- /dev/null
+++ b/imagebind/data/data_utils.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+import torchaudio
+import logging
+
+from imagebind.models.multimodal_preprocessors import SimpleTokenizer
+from PIL import Image
+from pytorchvideo import transforms as pv_transforms
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.encoded_video import EncodedVideo
+
+from torchvision import transforms
+from torchvision.transforms._transforms_video import NormalizeVideo
+
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
+
+BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
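+# NOTE: relative path -- the gzipped BPE vocab file must be reachable from the working directory.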
+
+
+def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
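+    # Computes a Kaldi-style log-mel filterbank, then pads or crops the time axis
+    # to `target_length`, returning a tensor of shape (1, num_mel_bins, target_length).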
+ waveform -= waveform.mean()
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform,
+ htk_compat=True,
+ sample_frequency=sample_rate,
+ use_energy=False,
+ window_type="hanning",
+ num_mel_bins=num_mel_bins,
+ dither=0.0,
+ frame_length=25,
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
+ )
+ # Convert to [mel_bins, num_frames] shape
+ fbank = fbank.transpose(0, 1)
+ # Pad to target_length
+ n_frames = fbank.size(1)
+ p = target_length - n_frames
+ # if p is too large (say >20%), flash a warning
+ if abs(p) / n_frames > 0.2:
+ logging.warning(
+ "Large gap between audio n_frames(%d) and "
+ "target_length (%d). Is the audio_target_length "
+ "setting correct?",
+ n_frames,
+ target_length,
+ )
+ # cut and pad
+ if p > 0:
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
+ elif p < 0:
+ fbank = fbank[:, 0:target_length]
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
+ # channel image
+ fbank = fbank.unsqueeze(0)
+ return fbank
+
+
+def get_clip_timepoints(clip_sampler, duration):
+ # Read out all clips in this video
+ all_clips_timepoints = []
+ is_last_clip = False
+ end = 0.0
+ while not is_last_clip:
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+ all_clips_timepoints.append((start, end))
+ return all_clips_timepoints
+
+
+def load_and_transform_vision_data(image_paths, device):
+ if image_paths is None:
+ return None
+
+    image_outputs = []
+ for image_path in image_paths:
+ data_transform = transforms.Compose(
+ [
+ transforms.Resize(
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )
+ with open(image_path, "rb") as fopen:
+ image = Image.open(fopen).convert("RGB")
+
+ image = data_transform(image).to(device)
+        image_outputs.append(image)
+    return torch.stack(image_outputs, dim=0)
+
+
+def load_and_transform_text(text, device):
+ if text is None:
+ return None
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
+ tokens = torch.cat(tokens, dim=0)
+ return tokens
+
+
+def load_and_transform_audio_data(
+ audio_paths,
+ device,
+ num_mel_bins=128,
+ target_length=204,
+ sample_rate=16000,
+ clip_duration=2,
+ clips_per_video=3,
+ mean=-4.268,
+ std=9.138,
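+    # mean/std: normalization statistics applied to each log-mel spectrogram below.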
+):
+ if audio_paths is None:
+ return None
+
+ audio_outputs = []
+ clip_sampler = ConstantClipsPerVideoSampler(
+ clip_duration=clip_duration, clips_per_video=clips_per_video
+ )
+
+ for audio_path in audio_paths:
+ waveform, sr = torchaudio.load(audio_path)
+ if sample_rate != sr:
+ waveform = torchaudio.functional.resample(
+ waveform, orig_freq=sr, new_freq=sample_rate
+ )
+ all_clips_timepoints = get_clip_timepoints(
+ clip_sampler, waveform.size(1) / sample_rate
+ )
+ all_clips = []
+ for clip_timepoints in all_clips_timepoints:
+ waveform_clip = waveform[
+ :,
+ int(clip_timepoints[0] * sample_rate) : int(
+ clip_timepoints[1] * sample_rate
+ ),
+ ]
+ waveform_melspec = waveform2melspec(
+ waveform_clip, sample_rate, num_mel_bins, target_length
+ )
+ all_clips.append(waveform_melspec)
+
+ normalize = transforms.Normalize(mean=mean, std=std)
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
+
+ all_clips = torch.stack(all_clips, dim=0)
+ audio_outputs.append(all_clips)
+
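+    # Shape: (num_files, clips_per_video, 1, num_mel_bins, target_length)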
+ return torch.stack(audio_outputs, dim=0)
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+ """
+ Perform crop on the bounding boxes given the offsets.
+ Args:
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = boxes.copy()
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+        size (int): size of height and width to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+        scale_size (int): optional. If not None, resize the images to scale_size before
+ performing any crop.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ ndim = len(images.shape)
+ if ndim == 3:
+ images = images.unsqueeze(0)
+ height = images.shape[2]
+ width = images.shape[3]
+
+ if scale_size is not None:
+ if width <= height:
+ width, height = scale_size, int(height / width * scale_size)
+ else:
+ width, height = int(width / height * scale_size), scale_size
+ images = torch.nn.functional.interpolate(
+ images,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ if ndim == 3:
+ cropped = cropped.squeeze(0)
+ return cropped, cropped_boxes
+
+
+class SpatialCrop(nn.Module):
+ """
+ Convert the video into 3 smaller clips spatially. Must be used after the
+ temporal crops to get spatial crops, and should be used with
+ -2 in the spatial crop at the slowfast augmentation stage (so full
+ frames are passed in here). Will return a larger list with the
+ 3x spatial crops as well.
+ """
+
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
+ super().__init__()
+ self.crop_size = crop_size
+ if num_crops == 3:
+ self.crops_to_ext = [0, 1, 2]
+ self.flipped_crops_to_ext = []
+ elif num_crops == 1:
+ self.crops_to_ext = [1]
+ self.flipped_crops_to_ext = []
+ else:
+ raise NotImplementedError("Nothing else supported yet")
+
+ def forward(self, videos):
+ """
+ Args:
+ videos: A list of C, T, H, W videos.
+ Returns:
+ videos: A list with 3x the number of elements. Each video converted
+ to C, T, H', W' by spatial cropping.
+ """
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
+ res = []
+ for video in videos:
+ for spatial_idx in self.crops_to_ext:
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
+ if not self.flipped_crops_to_ext:
+ continue
+ flipped_video = transforms.functional.hflip(video)
+ for spatial_idx in self.flipped_crops_to_ext:
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
+ return res
+
+
+def load_and_transform_video_data(
+ video_paths,
+ device,
+ clip_duration=2,
+ clips_per_video=5,
+ sample_rate=16000,
+):
+ if video_paths is None:
+ return None
+
+ video_outputs = []
+ video_transform = transforms.Compose(
+ [
+ pv_transforms.ShortSideScale(224),
+ NormalizeVideo(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )
+
+ clip_sampler = ConstantClipsPerVideoSampler(
+ clip_duration=clip_duration, clips_per_video=clips_per_video
+ )
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
+
+ for video_path in video_paths:
+ video = EncodedVideo.from_path(
+ video_path,
+ decoder="decord",
+ decode_audio=False,
+ **{"sample_rate": sample_rate},
+ )
+
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+
+ all_video = []
+ for clip_timepoints in all_clips_timepoints:
+ # Read the clip, get frames
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
+ if clip is None:
+ raise ValueError("No clip found")
+ video_clip = frame_sampler(clip["video"])
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
+
+ all_video.append(video_clip)
+
+ all_video = [video_transform(clip) for clip in all_video]
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
+
+ all_video = torch.stack(all_video, dim=0)
+ video_outputs.append(all_video)
+
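+    # Shape: (num_files, clips_per_video * 3, C, T, H, W) -- three spatial crops per temporal clip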
+ return torch.stack(video_outputs, dim=0).to(device)
\ No newline at end of file
diff --git a/imagebind/models/__init__.py b/imagebind/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/imagebind/models/helper.py b/imagebind/models/helper.py
new file mode 100644
index 0000000..514ea46
--- /dev/null
+++ b/imagebind/models/helper.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import einops
+import numpy as np
+import torch
+
+import torch.nn as nn
+
+
+class Normalize(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
+
+
+class LearnableLogitScaling(nn.Module):
+ def __init__(
+ self,
+ logit_scale_init: float = 1 / 0.07,
+ learnable: bool = True,
+ max_logit_scale: float = 100,
+ ) -> None:
+ super().__init__()
+ self.max_logit_scale = max_logit_scale
+ self.logit_scale_init = logit_scale_init
+ self.learnable = learnable
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
+ if learnable:
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
+ else:
+ self.register_buffer("log_logit_scale", log_logit_scale)
+
+ def forward(self, x):
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
+
+ def extra_repr(self):
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
+ return st
+
+
+class EinOpsRearrange(nn.Module):
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
+ super().__init__()
+ self.rearrange_expr = rearrange_expr
+ self.kwargs = kwargs
+
+ def forward(self, x):
+ assert isinstance(x, torch.Tensor)
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
+
+
+class VerboseNNModule(nn.Module):
+ """
+ Wrapper around nn.Module that prints registered buffers and parameter names.
+ """
+
+ @staticmethod
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
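+        # In practice `tensor` is the (name, tensor) tuple yielded by
+        # named_parameters()/named_buffers(), hence the tensor[1] indexing below.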
+ st = (
+ "("
+ + name
+ + "): "
+ + "tensor("
+ + str(tuple(tensor[1].shape))
+ + ", requires_grad="
+ + str(tensor[1].requires_grad)
+ + ")\n"
+ )
+ return st
+
+ def extra_repr(self) -> str:
+ named_modules = set()
+ for p in self.named_modules():
+ named_modules.update([p[0]])
+ named_modules = list(named_modules)
+
+ string_repr = ""
+ for p in self.named_parameters():
+ name = p[0].split(".")[0]
+ if name not in named_modules:
+ string_repr += self.get_readable_tensor_repr(name, p)
+
+ for p in self.named_buffers():
+ name = p[0].split(".")[0]
+ string_repr += self.get_readable_tensor_repr(name, p)
+
+ return string_repr
+
+
+def cast_if_src_dtype(
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
+):
+ updated = False
+ if tensor.dtype == src_dtype:
+ tensor = tensor.to(dtype=tgt_dtype)
+ updated = True
+ return tensor, updated
+
+
+class QuickGELU(nn.Module):
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class SelectElement(nn.Module):
+ def __init__(self, index) -> None:
+ super().__init__()
+ self.index = index
+
+ def forward(self, x):
+ assert x.ndim >= 3
+ return x[:, self.index, ...]
+
+
+class SelectEOSAndProject(nn.Module):
+ """
+ Text Pooling used in OpenCLIP
+ """
+
+ def __init__(self, proj: nn.Module) -> None:
+ super().__init__()
+ self.proj = proj
+
+ def forward(self, x, seq_len):
+ assert x.ndim == 3
+ # x is of shape B x L x D
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), seq_len]
+ x = self.proj(x)
+ return x
\ No newline at end of file
diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py
new file mode 100644
index 0000000..83cd641
--- /dev/null
+++ b/imagebind/models/image_bind.py
@@ -0,0 +1,587 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+from functools import partial
+from types import SimpleNamespace
+from typing import Union, Optional, Tuple, Dict, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from imagebind.models.helper import (
+ EinOpsRearrange,
+ LearnableLogitScaling,
+ Normalize,
+ SelectElement,
+ SelectEOSAndProject,
+)
+from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train
+from imagebind.models.multimodal_preprocessors import (
+ AudioPreprocessor,
+ IMUPreprocessor,
+ PadIm2Video,
+ PatchEmbedGeneric,
+ RGBDTPreprocessor,
+ SpatioTemporalPosEmbeddingHelper,
+ TextPreprocessor,
+ ThermalPreprocessor,
+)
+from imagebind.models.multimodal_projectors import create_projectors, create_pre_projector
+
+from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
+
+ModalityType = SimpleNamespace(
+ VISION="Vision",
+ TEXT="Text",
+ AUDIO="Audio",
+ THERMAL="Thermal",
+ DEPTH="Depth",
+ IMU="Imu",
+)
+
+
+class ImageBindJoiner(nn.Module):
+ def __init__(self,
+ vision_query_token_num: int,
+ vision_qformer_frozen: bool = False,
+ vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model
+ vision_pre_dims: List[int] = (), # Projection before Q-Former
+ vision_post_dims: List[int] = (768, 768) # Projection after Q-Former
+ ):
+ super().__init__()
+ assert not (vision_qformer_frozen and vision_qformer_model == "")
+ self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
+ self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
+ vision_qformer_frozen,
+ vision_qformer_model)
+ self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
+
+ def _create_modality_pre_projectors(self,
+ vision_pre_dims
+ ):
+        modality_pre_projectors = {
+            ModalityType.VISION: create_projectors(vision_pre_dims)
+        }
+        # Register the pre-projectors as submodules, matching the Q-Former and
+        # post-projector containers, so their parameters are tracked and moved with the model.
+        return nn.ModuleDict(modality_pre_projectors)
+
+ def _create_modality_qformers(self,
+ vision_query_token_num,
+ vision_qformer_frozen,
+ vision_qformer_model
+ ):
+ vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
+ freeze_qformer=vision_qformer_frozen,
+ q_former_model=vision_qformer_model)
+ modality_qformers = {
+ ModalityType.VISION: vision_qformer
+ }
+
+ return nn.ModuleDict(modality_qformers)
+
+ def _create_modality_post_projectors(self, vision_post_dims):
+ vision_projector = create_projectors(vision_post_dims)
+ modality_projectors = {
+ ModalityType.VISION: vision_projector
+ }
+
+ return nn.ModuleDict(modality_projectors)
+
+ def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
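+        # Each modality flows through: pre-projector -> Q-Former -> post-projector.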
+ outputs = {}
+ for modality_key, modality_value in inputs.items():
+ assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
+ if modality_value is not None:
+ modality_value = self.modality_pre_projectors[modality_key](modality_value)
+ modality_value = self.modality_qformers[modality_key](modality_value)
+ modality_value = self.modality_post_projectors[modality_key](modality_value)
+ outputs[modality_key] = modality_value
+ return outputs
+
+
+class ImageBindModel(nn.Module):
+ def __init__(
+ self,
+ video_frames=2,
+ kernel_size=(2, 14, 14),
+ audio_kernel_size=16,
+ audio_stride=10,
+ out_embed_dim=768,
+ vision_embed_dim=1024,
+ vision_num_blocks=24,
+ vision_num_heads=16,
+ audio_embed_dim=768,
+ audio_num_blocks=12,
+ audio_num_heads=12,
+ audio_num_mel_bins=128,
+ audio_target_len=204,
+ audio_drop_path=0.1,
+ text_embed_dim=768,
+ text_num_blocks=12,
+ text_num_heads=12,
+ depth_embed_dim=384,
+ depth_kernel_size=16,
+ depth_num_blocks=12,
+ depth_num_heads=8,
+ depth_drop_path=0.0,
+ thermal_embed_dim=768,
+ thermal_kernel_size=16,
+ thermal_num_blocks=12,
+ thermal_num_heads=12,
+ thermal_drop_path=0.0,
+ imu_embed_dim=512,
+ imu_kernel_size=8,
+ imu_num_blocks=6,
+ imu_num_heads=8,
+ imu_drop_path=0.7,
+ ):
+ super().__init__()
+
+ self.modality_preprocessors = self._create_modality_preprocessors(
+ video_frames,
+ vision_embed_dim,
+ kernel_size,
+ text_embed_dim,
+ audio_embed_dim,
+ audio_kernel_size,
+ audio_stride,
+ audio_num_mel_bins,
+ audio_target_len,
+ depth_embed_dim,
+ depth_kernel_size,
+ thermal_embed_dim,
+ thermal_kernel_size,
+ imu_embed_dim,
+ )
+
+ self.modality_trunks = self._create_modality_trunks(
+ vision_embed_dim,
+ vision_num_blocks,
+ vision_num_heads,
+ text_embed_dim,
+ text_num_blocks,
+ text_num_heads,
+ audio_embed_dim,
+ audio_num_blocks,
+ audio_num_heads,
+ audio_drop_path,
+ depth_embed_dim,
+ depth_num_blocks,
+ depth_num_heads,
+ depth_drop_path,
+ thermal_embed_dim,
+ thermal_num_blocks,
+ thermal_num_heads,
+ thermal_drop_path,
+ imu_embed_dim,
+ imu_num_blocks,
+ imu_num_heads,
+ imu_drop_path,
+ )
+
+ self.modality_heads = self._create_modality_heads(
+ out_embed_dim,
+ vision_embed_dim,
+ text_embed_dim,
+ audio_embed_dim,
+ depth_embed_dim,
+ thermal_embed_dim,
+ imu_embed_dim,
+ )
+
+ self.modality_postprocessors = self._create_modality_postprocessors(
+ out_embed_dim
+ )
+
+ def _create_modality_preprocessors(
+ self,
+ video_frames=2,
+ vision_embed_dim=1024,
+ kernel_size=(2, 14, 14),
+ text_embed_dim=768,
+ audio_embed_dim=768,
+ audio_kernel_size=16,
+ audio_stride=10,
+ audio_num_mel_bins=128,
+ audio_target_len=204,
+ depth_embed_dim=768,
+ depth_kernel_size=16,
+ thermal_embed_dim=768,
+ thermal_kernel_size=16,
+ imu_embed_dim=512,
+ ):
+ rgbt_stem = PatchEmbedGeneric(
+ proj_stem=[
+ PadIm2Video(pad_type="repeat", ntimes=2),
+ nn.Conv3d(
+ in_channels=3,
+ kernel_size=kernel_size,
+ out_channels=vision_embed_dim,
+ stride=kernel_size,
+ bias=False,
+ ),
+ ]
+ )
+ rgbt_preprocessor = RGBDTPreprocessor(
+ img_size=[3, video_frames, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ rgbt_stem=rgbt_stem,
+ depth_stem=None,
+ )
+
+ text_preprocessor = TextPreprocessor(
+ context_length=77,
+ vocab_size=49408,
+ embed_dim=text_embed_dim,
+ causal_masking=True,
+ )
+
+ audio_stem = PatchEmbedGeneric(
+ proj_stem=[
+ nn.Conv2d(
+ in_channels=1,
+ kernel_size=audio_kernel_size,
+ stride=audio_stride,
+ out_channels=audio_embed_dim,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
+ )
+ audio_preprocessor = AudioPreprocessor(
+ img_size=[1, audio_num_mel_bins, audio_target_len],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ audio_stem=audio_stem,
+ )
+
+ depth_stem = PatchEmbedGeneric(
+ [
+ nn.Conv2d(
+ kernel_size=depth_kernel_size,
+ in_channels=1,
+ out_channels=depth_embed_dim,
+ stride=depth_kernel_size,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
+ )
+
+ depth_preprocessor = RGBDTPreprocessor(
+ img_size=[1, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ rgbt_stem=None,
+ depth_stem=depth_stem,
+ )
+
+ thermal_stem = PatchEmbedGeneric(
+ [
+ nn.Conv2d(
+ kernel_size=thermal_kernel_size,
+ in_channels=1,
+ out_channels=thermal_embed_dim,
+ stride=thermal_kernel_size,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
+ )
+ thermal_preprocessor = ThermalPreprocessor(
+ img_size=[1, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ thermal_stem=thermal_stem,
+ )
+
+ imu_stem = PatchEmbedGeneric(
+ [
+ nn.Linear(
+ in_features=48,
+ out_features=imu_embed_dim,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
+ )
+
+ imu_preprocessor = IMUPreprocessor(
+ img_size=[6, 2000],
+ num_cls_tokens=1,
+ kernel_size=8,
+ embed_dim=imu_embed_dim,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ imu_stem=imu_stem,
+ )
+
+ modality_preprocessors = {
+ ModalityType.VISION: rgbt_preprocessor,
+ ModalityType.TEXT: text_preprocessor,
+ ModalityType.AUDIO: audio_preprocessor,
+ ModalityType.DEPTH: depth_preprocessor,
+ ModalityType.THERMAL: thermal_preprocessor,
+ ModalityType.IMU: imu_preprocessor,
+ }
+
+ return nn.ModuleDict(modality_preprocessors)
+
+ def _create_modality_trunks(
+ self,
+ vision_embed_dim=1024,
+ vision_num_blocks=24,
+ vision_num_heads=16,
+ text_embed_dim=768,
+ text_num_blocks=12,
+ text_num_heads=12,
+ audio_embed_dim=768,
+ audio_num_blocks=12,
+ audio_num_heads=12,
+ audio_drop_path=0.0,
+ depth_embed_dim=768,
+ depth_num_blocks=12,
+ depth_num_heads=12,
+ depth_drop_path=0.0,
+ thermal_embed_dim=768,
+ thermal_num_blocks=12,
+ thermal_num_heads=12,
+ thermal_drop_path=0.0,
+ imu_embed_dim=512,
+ imu_num_blocks=6,
+ imu_num_heads=8,
+ imu_drop_path=0.7,
+ ):
+ def instantiate_trunk(
+ embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
+ ):
+ return SimpleTransformer(
+ embed_dim=embed_dim,
+ num_blocks=num_blocks,
+ ffn_dropout_rate=0.0,
+ drop_path_rate=drop_path,
+ attn_target=partial(
+ MultiheadAttention,
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ bias=True,
+ add_bias_kv=add_bias_kv,
+ ),
+ pre_transformer_layer=nn.Sequential(
+ nn.LayerNorm(embed_dim, eps=1e-6)
+ if pre_transformer_ln
+ else nn.Identity(),
+ EinOpsRearrange("b l d -> l b d"),
+ ),
+ post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
+ )
+
+ modality_trunks = {}
+ modality_trunks[ModalityType.VISION] = instantiate_trunk(
+ vision_embed_dim,
+ vision_num_blocks,
+ vision_num_heads,
+ pre_transformer_ln=True,
+ add_bias_kv=False,
+ drop_path=0.0,
+ )
+ modality_trunks[ModalityType.TEXT] = instantiate_trunk(
+ text_embed_dim,
+ text_num_blocks,
+ text_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=False,
+ drop_path=0.0,
+ )
+ modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
+ audio_embed_dim,
+ audio_num_blocks,
+ audio_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=audio_drop_path,
+ )
+ modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
+ depth_embed_dim,
+ depth_num_blocks,
+ depth_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=depth_drop_path,
+ )
+ modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
+ thermal_embed_dim,
+ thermal_num_blocks,
+ thermal_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=thermal_drop_path,
+ )
+ modality_trunks[ModalityType.IMU] = instantiate_trunk(
+ imu_embed_dim,
+ imu_num_blocks,
+ imu_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=imu_drop_path,
+ )
+
+ return nn.ModuleDict(modality_trunks)
+
+ def _create_modality_heads(
+ self,
+ out_embed_dim,
+ vision_embed_dim,
+ text_embed_dim,
+ audio_embed_dim,
+ depth_embed_dim,
+ thermal_embed_dim,
+ imu_embed_dim,
+ ):
+ modality_heads = {}
+
+ modality_heads[ModalityType.VISION] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
+ proj=nn.Sequential(
+ nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
+ nn.Linear(text_embed_dim, out_embed_dim, bias=False),
+ )
+ )
+
+ modality_heads[ModalityType.AUDIO] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.DEPTH] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.THERMAL] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.IMU] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Dropout(p=0.5),
+ nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
+ )
+
+ return nn.ModuleDict(modality_heads)
+
+ def _create_modality_postprocessors(self, out_embed_dim):
+ modality_postprocessors = {}
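+        # L2-normalize every modality's features; text gets a learnable logit scale,
+        # audio/depth/thermal/imu get fixed scales, vision gets none.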
+
+ modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
+ modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
+ Normalize(dim=-1), LearnableLogitScaling(learnable=True)
+ )
+ modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.IMU] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+ )
+
+ return nn.ModuleDict(modality_postprocessors)
+
+ def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ outputs = {}
+ for modality_key, modality_value in inputs.items():
+ reduce_list = (
+ modality_value.ndim >= 5
+ ) # Audio and Video inputs consist of multiple clips
+ if reduce_list:
+ B, S = modality_value.shape[:2]
+ modality_value = modality_value.reshape(
+ B * S, *modality_value.shape[2:]
+ )
+
+ if modality_value is not None:
+                # ModalityType values are capitalized ("Vision", ...) while the
+                # preprocessors take lowercase keyword arguments (vision=, audio=, ...),
+                # so lower-case the key before dispatching.
+                modality_value = self.modality_preprocessors[modality_key](
+                    **{modality_key.lower(): modality_value}
+                )
+ trunk_inputs = modality_value["trunk"]
+ modality_value = self.modality_trunks[modality_key](**trunk_inputs)
+
+ # NOTE: No heads are needed any more.
+ # head_inputs = modality_value["head"]
+ # modality_value = self.modality_heads[modality_key](
+ # modality_value, **head_inputs
+ # )
+
+ modality_value = self.modality_postprocessors[modality_key](
+ modality_value
+ )
+
+ # NOTE: The reduction operation has been modified.
+ if reduce_list:
+                modality_value = modality_value.reshape(B, S, *modality_value.shape[2:])
+ modality_value = modality_value.mean(dim=1)
+
+ outputs[modality_key] = modality_value
+
+ return outputs
+
+
+def imagebind_huge(pretrained=False, freeze_imagebind=False):
+ model = ImageBindModel(
+ vision_embed_dim=1280,
+ vision_num_blocks=32,
+ vision_num_heads=16,
+ text_embed_dim=1024,
+ text_num_blocks=24,
+ text_num_heads=16,
+ out_embed_dim=1024,
+ audio_drop_path=0.1,
+ imu_drop_path=0.7,
+ )
+
+ if pretrained:
+ if not os.path.exists(".checkpoints/imagebind_huge.pth"):
+ print(
+ "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
+ )
+ os.makedirs(".checkpoints", exist_ok=True)
+ torch.hub.download_url_to_file(
+ "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
+ ".checkpoints/imagebind_huge.pth",
+ progress=True,
+ )
+
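+        # The modality heads are unused in forward() but still constructed in __init__,
+        # so the full official checkpoint loads with the default strict=True.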
+ model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
+
+ if freeze_imagebind:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model = model.eval()
+ model.train = disabled_train
+
+ return model
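+
+
+# Usage sketch (illustrative only): embed a batch of images with a frozen trunk.
+# `image_paths` is a hypothetical list of image file paths.
+#
+#   from imagebind.data.data_utils import load_and_transform_vision_data
+#   model = imagebind_huge(pretrained=True, freeze_imagebind=True)
+#   inputs = {ModalityType.VISION: load_and_transform_vision_data(image_paths, "cpu")}
+#   vision_tokens = model(inputs)[ModalityType.VISION]  # (B, num_tokens, vision_embed_dim)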
diff --git a/imagebind/models/multimodal_formers.py b/imagebind/models/multimodal_formers.py
new file mode 100644
index 0000000..0567fd6
--- /dev/null
+++ b/imagebind/models/multimodal_formers.py
@@ -0,0 +1,103 @@
+import logging
+import os
+
+import torch
+from torch import nn, Tensor
+
+from minigpt4.common.dist_utils import download_cached_file
+from minigpt4.common.utils import is_url
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class BaseQFormer(nn.Module):
+ def __init__(self, freeze_qformer=False):
+ super().__init__()
+ self.freeze_qformer = freeze_qformer
+ self.Qformer = None
+
+ def check_and_freeze(self):
+ assert self.Qformer is not None
+ if self.freeze_qformer:
+ for name, param in self.Qformer.named_parameters():
+ param.requires_grad = False
+ self.Qformer = self.Qformer.eval()
+ self.Qformer.train = disabled_train
+ self.query_tokens.requires_grad = False
+ logging.info("Freeze This QFormer")
+
+    def load_from_pretrained(self, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ # logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+
+class SequenceGenericQFormer(BaseQFormer):
+ def __init__(self,
+ num_query_token: int,
+ encoder_width: int = 768,
+ freeze_qformer: bool = False,
+ q_former_model: str = "",
+ cross_attention_freq: int = 2
+ ):
+ super().__init__(freeze_qformer)
+ self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq)
+ if q_former_model != "":
+ self.load_Qformer(q_former_model)
+ self.check_and_freeze()
+
+ def load_Qformer(self, q_former_model):
+ self.Qformer.cls = None
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ self.load_from_pretrained(url_or_filename=q_former_model)
+
+ @classmethod
+ def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.encoder_width = encoder_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ return Qformer, query_tokens
+
+ def forward(self, input_embeds: Tensor) -> Tensor:
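+        # input_embeds: (B, seq_len, encoder_width); returns (B, num_query_token, hidden_size).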
+ input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
+ query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=input_embeds,
+ encoder_attention_mask=input_atts,
+ return_dict=True,
+ )
+ return query_output.last_hidden_state
diff --git a/imagebind/models/multimodal_preprocessors.py b/imagebind/models/multimodal_preprocessors.py
new file mode 100644
index 0000000..24f5c05
--- /dev/null
+++ b/imagebind/models/multimodal_preprocessors.py
@@ -0,0 +1,686 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import gzip
+import html
+import io
+import math
+from functools import lru_cache
+from typing import Callable, List, Optional
+
+import ftfy
+
+import numpy as np
+import regex as re
+import torch
+import torch.nn as nn
+from iopath.common.file_io import g_pathmgr
+from timm.models.layers import trunc_normal_
+from imagebind.models.helper import VerboseNNModule, cast_if_src_dtype
+
+
+def get_sinusoid_encoding_table(n_position, d_hid):
+ """Sinusoid position encoding table"""
+
+ # TODO: make it with torch instead of numpy
+ def get_position_angle_vec(position):
+ return [
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+ for hid_j in range(d_hid)
+ ]
+
+ sinusoid_table = np.array(
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
+ )
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+
+def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
+ N = pos_embed.shape[1]
+ if N == target_spatial_size:
+ return pos_embed
+ dim = pos_embed.shape[-1]
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
+ 0, 3, 1, 2
+ ),
+ scale_factor=math.sqrt(target_spatial_size / N),
+ mode="bicubic",
+ )
+ if updated:
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return pos_embed
+
+
+def interpolate_pos_encoding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape=None,
+ first_patch_idx=1,
+):
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
+ if npatch_per_img == N:
+ return pos_embed
+
+ assert (
+ patches_layout[-1] == patches_layout[-2]
+ ), "Interpolation of pos embed not supported for non-square layouts"
+
+ class_emb = pos_embed[:, :first_patch_idx]
+ pos_embed = pos_embed[:, first_patch_idx:]
+
+ if input_shape is None or patches_layout[0] == 1:
+ # simple 2D pos embedding, no temporal component
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
+ elif patches_layout[0] > 1:
+ # pos embed has a temporal component
+ assert len(input_shape) == 4, "temporal interpolation not supported"
+ # we only support 2D interpolation in this case
+ num_frames = patches_layout[0]
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
+ # interpolate embedding for zeroth frame
+ pos_embed = interpolate_pos_encoding_2d(
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
+ )
+ else:
+ raise ValueError("This type of interpolation isn't implemented")
+
+ return torch.cat((class_emb, pos_embed), dim=1)
+
+
+def _get_pos_embedding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape,
+ first_patch_idx=1,
+):
+ pos_embed = interpolate_pos_encoding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape=input_shape,
+ first_patch_idx=first_patch_idx,
+ )
+ return pos_embed
+
+
+class PatchEmbedGeneric(nn.Module):
+ """
+ PatchEmbed from Hydra
+ """
+
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
+ super().__init__()
+
+ if len(proj_stem) > 1:
+ self.proj = nn.Sequential(*proj_stem)
+ else:
+ # Special case to be able to load pre-trained models that were
+ # trained with a standard stem
+ self.proj = proj_stem[0]
+ self.norm_layer = norm_layer
+
+ def get_patch_layout(self, img_size):
+ with torch.no_grad():
+ dummy_img = torch.zeros(
+ [
+ 1,
+ ]
+ + img_size
+ )
+ dummy_out = self.proj(dummy_img)
+ embed_dim = dummy_out.shape[1]
+ patches_layout = tuple(dummy_out.shape[2:])
+ num_patches = np.prod(patches_layout)
+ return patches_layout, num_patches, embed_dim
+
+ def forward(self, x):
+ x = self.proj(x)
+ # B C (T) H W -> B (T)HW C
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm_layer is not None:
+ x = self.norm_layer(x)
+ return x
+
+
+class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
+ def __init__(
+ self,
+ patches_layout: List,
+ num_patches: int,
+ num_cls_tokens: int,
+ embed_dim: int,
+ learnable: bool,
+ ) -> None:
+ super().__init__()
+ self.num_cls_tokens = num_cls_tokens
+ self.patches_layout = patches_layout
+ self.num_patches = num_patches
+ self.num_tokens = num_cls_tokens + num_patches
+ self.learnable = learnable
+ if self.learnable:
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
+ trunc_normal_(self.pos_embed, std=0.02)
+ else:
+ self.register_buffer(
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
+ )
+
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
+ input_shape = vision_input.shape
+ pos_embed = _get_pos_embedding(
+ all_vision_tokens.size(1) - self.num_cls_tokens,
+ pos_embed=self.pos_embed,
+ patches_layout=self.patches_layout,
+ input_shape=input_shape,
+ first_patch_idx=self.num_cls_tokens,
+ )
+ return pos_embed
+
+
+class RGBDTPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ rgbt_stem: PatchEmbedGeneric,
+ depth_stem: PatchEmbedGeneric,
+ img_size: List = (3, 224, 224),
+ num_cls_tokens: int = 1,
+ pos_embed_fn: Callable = None,
+ use_type_embed: bool = False,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
+ (
+ self.patches_layout,
+ self.num_patches,
+ self.embed_dim,
+ ) = stem.get_patch_layout(img_size)
+ self.rgbt_stem = rgbt_stem
+ self.depth_stem = depth_stem
+ self.use_pos_embed = pos_embed_fn is not None
+ self.use_type_embed = use_type_embed
+ self.num_cls_tokens = num_cls_tokens
+
+ if self.use_pos_embed:
+ self.pos_embedding_helper = pos_embed_fn(
+ patches_layout=self.patches_layout,
+ num_cls_tokens=num_cls_tokens,
+ num_patches=self.num_patches,
+ embed_dim=self.embed_dim,
+ )
+ if self.num_cls_tokens > 0:
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
+ )
+ if self.use_type_embed:
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style):
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+ if self.use_pos_embed:
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
+ self.pos_embedding_helper.pos_embed *= scale
+
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ if self.use_type_embed:
+ nn.init.normal_(self.type_embed)
+
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
+ # tokens is of shape B x L x D
+ tokens = stem(input)
+ assert tokens.ndim == 3
+ assert tokens.shape[2] == self.embed_dim
+ B = tokens.shape[0]
+ if self.num_cls_tokens > 0:
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ tokens = torch.cat((class_tokens, tokens), dim=1)
+ if self.use_pos_embed:
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
+ tokens = tokens + pos_embed
+ if self.use_type_embed:
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
+ return tokens
+
+ def forward(self, vision=None, depth=None, patch_mask=None):
+ if patch_mask is not None:
+ raise NotImplementedError()
+
+ if vision is not None:
+ vision_tokens = self.tokenize_input_and_cls_pos(
+ vision, self.rgbt_stem, patch_mask
+ )
+
+ if depth is not None:
+ depth_tokens = self.tokenize_input_and_cls_pos(
+ depth, self.depth_stem, patch_mask
+ )
+
+ # aggregate tokens
+ if vision is not None and depth is not None:
+ final_tokens = vision_tokens + depth_tokens
+ else:
+ final_tokens = vision_tokens if vision is not None else depth_tokens
+ return_dict = {
+ "trunk": {
+ "tokens": final_tokens,
+ },
+ "head": {},
+ }
+ return return_dict
+
+
+class AudioPreprocessor(RGBDTPreprocessor):
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
+
+ def forward(self, audio=None):
+ return super().forward(vision=audio)
+
+
+class ThermalPreprocessor(RGBDTPreprocessor):
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
+
+ def forward(self, thermal=None):
+ return super().forward(vision=thermal)
+
+
+def build_causal_attention_mask(context_length):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(context_length, context_length, requires_grad=False)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+
+class TextPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ vocab_size: int,
+ context_length: int,
+ embed_dim: int,
+ causal_masking: bool,
+ supply_seq_len_to_head: bool = True,
+ num_cls_tokens: int = 0,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.context_length = context_length
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+ self.pos_embed = nn.Parameter(
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
+ )
+ self.causal_masking = causal_masking
+ if self.causal_masking:
+ mask = build_causal_attention_mask(self.context_length)
+ # register the mask as a buffer so it can be moved to the right device
+ self.register_buffer("mask", mask)
+
+ self.supply_seq_len_to_head = supply_seq_len_to_head
+ self.num_cls_tokens = num_cls_tokens
+ self.embed_dim = embed_dim
+ if num_cls_tokens > 0:
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
+ )
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style="openclip"):
+ # OpenCLIP style initialization
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.pos_embed, std=0.01)
+
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ def forward(self, text):
+ # text tokens are of shape B x L x D
+ text_tokens = self.token_embedding(text)
+ # concat CLS tokens if any
+ if self.num_cls_tokens > 0:
+ B = text_tokens.shape[0]
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
+ text_tokens = text_tokens + self.pos_embed
+ return_dict = {
+ "trunk": {
+ "tokens": text_tokens,
+ },
+ "head": {},
+ }
+ # Compute sequence length after adding CLS tokens
+ if self.supply_seq_len_to_head:
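+            # The EOT token has the highest id in the vocab, so argmax over token ids
+            # gives each sequence's end position.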
+ text_lengths = text.argmax(dim=-1)
+ return_dict["head"] = {
+ "seq_len": text_lengths,
+ }
+ if self.causal_masking:
+ return_dict["trunk"].update({"attn_mask": self.mask})
+ return return_dict
+
+
+class Im2Video(nn.Module):
+ """Convert an image into a trivial video."""
+
+ def __init__(self, time_dim=2):
+ super().__init__()
+ self.time_dim = time_dim
+
+ def forward(self, x):
+ if x.ndim == 4:
+ # B, C, H, W -> B, C, T, H, W
+ return x.unsqueeze(self.time_dim)
+ elif x.ndim == 5:
+ return x
+ else:
+ raise ValueError(f"Dimension incorrect {x.shape}")
+
+
+class PadIm2Video(Im2Video):
+ def __init__(self, ntimes, pad_type, time_dim=2):
+ super().__init__(time_dim=time_dim)
+ assert ntimes > 0
+ assert pad_type in ["zero", "repeat"]
+ self.ntimes = ntimes
+ self.pad_type = pad_type
+
+ def forward(self, x):
+ x = super().forward(x)
+ if x.shape[self.time_dim] == 1:
+ if self.pad_type == "repeat":
+ new_shape = [1] * len(x.shape)
+ new_shape[self.time_dim] = self.ntimes
+ x = x.repeat(new_shape)
+ elif self.pad_type == "zero":
+ padarg = [0, 0] * len(x.shape)
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
+ x = nn.functional.pad(x, padarg)
+ return x
+
+
+# Modified from github.com/openai/CLIP
+@lru_cache()
+def bytes_to_unicode():
+ """
+    Returns a mapping between utf-8 bytes and corresponding unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings
+    that avoid mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
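+# Illustration: printable bytes map to themselves (e.g. b"a" -> "a"), while bytes the
+# BPE code would choke on (whitespace/control) are shifted to unused code points
+# (e.g. byte 0 -> "Ā"), so any byte string round-trips through printable unicode.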
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
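+# e.g. get_pairs(("h", "e", "l", "l", "o")) == {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}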
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str, context_length=77):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+ with g_pathmgr.open(bpe_path, "rb") as fh:
+ bpe_bytes = io.BytesIO(fh.read())
+ merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
+ merges = merges[1 : 49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+ for merge in merges:
+ vocab.append("".join(merge))
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ "<|startoftext|>": "<|startoftext|>",
+ "<|endoftext|>": "<|endoftext|>",
+ }
+ self.pat = re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE,
+ )
+ self.context_length = context_length
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+                except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = (
+ bytearray([self.byte_decoder[c] for c in text])
+ .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+ )
+ return text
+
+ def __call__(self, texts, context_length=None):
+ if not context_length:
+ context_length = self.context_length
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder["<|startoftext|>"]
+ eot_token = self.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ tokens = tokens[:context_length]
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ if len(result) == 1:
+ return result[0]
+ return result
+
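+# Usage sketch (the BPE file path is illustrative, pass whatever vocabulary the caller ships):
+#   tok = SimpleTokenizer(bpe_path="bpe/bpe_simple_vocab_16e6.txt.gz")
+#   tok("a photo of a dog")   # LongTensor of shape (context_length,)
+#   tok(["a cat", "a dog"])   # LongTensor of shape (2, context_length)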
+
+class IMUPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ kernel_size: int,
+ imu_stem: PatchEmbedGeneric,
+ embed_dim: int,
+ img_size: List = (6, 2000),
+ num_cls_tokens: int = 1,
+ pos_embed_fn: Callable = None,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ stem = imu_stem
+ self.imu_stem = imu_stem
+ self.embed_dim = embed_dim
+ self.use_pos_embed = pos_embed_fn is not None
+ self.num_cls_tokens = num_cls_tokens
+ self.kernel_size = kernel_size
+ self.pos_embed = nn.Parameter(
+ torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
+ )
+
+ if self.num_cls_tokens > 0:
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
+ )
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style):
+ nn.init.normal_(self.pos_embed, std=0.01)
+
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ def tokenize_input_and_cls_pos(self, input, stem):
+ # tokens is of shape B x L x D
+ tokens = stem.norm_layer(stem.proj(input))
+ assert tokens.ndim == 3
+ assert tokens.shape[2] == self.embed_dim
+ B = tokens.shape[0]
+ if self.num_cls_tokens > 0:
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ tokens = torch.cat((class_tokens, tokens), dim=1)
+ if self.use_pos_embed:
+ tokens = tokens + self.pos_embed
+ return tokens
+
+ def forward(self, imu):
+ # Patchify
+ imu = imu.unfold(
+ -1,
+ self.kernel_size,
+ self.kernel_size,
+ ).permute(0, 2, 1, 3)
+ imu = imu.reshape(imu.size(0), imu.size(1), -1)
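+        # e.g. a (B, 6, 2000) IMU clip with kernel_size=8 becomes (B, 250, 48) patch tokens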
+
+ imu_tokens = self.tokenize_input_and_cls_pos(
+ imu,
+ self.imu_stem,
+ )
+
+ return_dict = {
+ "trunk": {
+ "tokens": imu_tokens,
+ },
+ "head": {},
+ }
+ return return_dict
\ No newline at end of file
diff --git a/imagebind/models/multimodal_projectors.py b/imagebind/models/multimodal_projectors.py
new file mode 100644
index 0000000..420a00c
--- /dev/null
+++ b/imagebind/models/multimodal_projectors.py
@@ -0,0 +1,45 @@
+from torch import nn, Tensor
+
+from typing import Union, Optional, Tuple
+
+
+class BaseProjector(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: Tensor) -> Tensor:
+ raise NotImplementedError
+
+
+class LinearProjector(BaseProjector):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+ self.fc = nn.Linear(in_dim, out_dim)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.fc(x)
+
+
+class AdapterProjector(BaseProjector):
+ def __init__(self, in_dim, mid_dim, out_dim):
+ super().__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(in_dim, mid_dim, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(mid_dim, out_dim, bias=False),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.fc(x)
+
+
+def create_projectors(dims):
+ if len(dims) == 0:
+ return nn.Identity()
+ elif len(dims) == 2:
+ return LinearProjector(*dims)
+ elif len(dims) == 3:
+ return AdapterProjector(*dims)
+ else:
+ raise NotImplementedError
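+
+
+# Rough usage (dims are illustrative): create_projectors([]) returns nn.Identity(),
+# create_projectors([1280, 1408]) a LinearProjector, and
+# create_projectors([1280, 512, 4096]) an AdapterProjector.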
diff --git a/imagebind/models/transformer.py b/imagebind/models/transformer.py
new file mode 100644
index 0000000..4bb3cfd
--- /dev/null
+++ b/imagebind/models/transformer.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Code modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
+# https://github.com/facebookresearch/deit/blob/main/models.py
+# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
+
+
+import copy
+import fnmatch
+import logging
+from functools import partial
+from typing import Callable, List
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import DropPath, trunc_normal_
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version,
+ # can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
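+        # qkv now has shape 3 x B x num_heads x N x head_dim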
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class MultiheadAttention(nn.MultiheadAttention):
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+
+class ViTAttention(Attention):
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ assert attn_mask is None
+ return super().forward(x)
+
+
+class BlockWithMasking(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ attn_target: Callable,
+ mlp_ratio: int = 4,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
+ ffn_dropout_rate: float = 0.0,
+ drop_path: float = 0.0,
+ layer_scale_type: str = None,
+ layer_scale_init_value: float = 1e-4,
+ ):
+ super().__init__()
+
+ assert not isinstance(
+ attn_target, nn.Module
+ ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
+ self.attn = attn_target()
+ if drop_path > 0.0:
+ self.drop_path = DropPath(drop_path)
+ else:
+ self.drop_path = nn.Identity()
+ self.norm_1 = norm_layer(dim)
+ mlp_hidden_dim = int(mlp_ratio * dim)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=ffn_dropout_rate,
+ )
+ self.norm_2 = norm_layer(dim)
+ self.layer_scale_type = layer_scale_type
+ if self.layer_scale_type is not None:
+ assert self.layer_scale_type in [
+ "per_channel",
+ "scalar",
+ ], f"Found Layer scale type {self.layer_scale_type}"
+ if self.layer_scale_type == "per_channel":
+ # one gamma value per channel
+ gamma_shape = [1, 1, dim]
+ elif self.layer_scale_type == "scalar":
+ # single gamma value for all channels
+ gamma_shape = [1, 1, 1]
+ # two gammas: for each part of the fwd in the encoder
+ self.layer_scale_gamma1 = nn.Parameter(
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
+ requires_grad=True,
+ )
+ self.layer_scale_gamma2 = nn.Parameter(
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
+ requires_grad=True,
+ )
+
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ if self.layer_scale_type is None:
+ x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+ x = x + self.drop_path(self.mlp(self.norm_2(x)))
+ else:
+ x = (
+ x
+ + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+ * self.layer_scale_gamma1
+ )
+ x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
+ return x
+
+
+_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
+
+
+class SimpleTransformer(nn.Module):
+ def __init__(
+ self,
+ attn_target: Callable,
+ embed_dim: int,
+ num_blocks: int,
+ block: Callable = BlockWithMasking,
+ pre_transformer_layer: Callable = None,
+ post_transformer_layer: Callable = None,
+ drop_path_rate: float = 0.0,
+ drop_path_type: str = "progressive",
+ norm_layer: Callable = _LAYER_NORM,
+ mlp_ratio: int = 4,
+ ffn_dropout_rate: float = 0.0,
+ layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
+ layer_scale_init_value: float = 1e-4, # from cait; float
+ weight_init_style: str = "jax", # possible values jax or pytorch
+ ):
+ """
+ Simple Transformer with the following features
+ 1. Supports masked attention
+ 2. Supports DropPath
+ 3. Supports LayerScale
+ 4. Supports Dropout in Attention and FFN
+ 5. Makes few assumptions about the input except that it is a Tensor
+ """
+ super().__init__()
+ self.pre_transformer_layer = pre_transformer_layer
+ if drop_path_type == "progressive":
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
+ elif drop_path_type == "uniform":
+ dpr = [drop_path_rate for i in range(num_blocks)]
+ else:
+ raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
+
+ self.blocks = nn.Sequential(
+ *[
+ block(
+ dim=embed_dim,
+ attn_target=attn_target,
+ mlp_ratio=mlp_ratio,
+ ffn_dropout_rate=ffn_dropout_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ layer_scale_type=layer_scale_type,
+ layer_scale_init_value=layer_scale_init_value,
+ )
+ for i in range(num_blocks)
+ ]
+ )
+ self.post_transformer_layer = post_transformer_layer
+ self.weight_init_style = weight_init_style
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ if self.weight_init_style == "jax":
+ # Based on MAE and official Jax ViT implementation
+ torch.nn.init.xavier_uniform_(m.weight)
+ elif self.weight_init_style == "pytorch":
+ # PyTorch ViT uses trunc_normal_
+ trunc_normal_(m.weight, std=0.02)
+
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (nn.LayerNorm)):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ attn_mask: torch.Tensor = None,
+ use_checkpoint: bool = False,
+ checkpoint_every_n: int = 1,
+ checkpoint_blk_ids: List[int] = None,
+ ):
+ """
+ Inputs
+ - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
+        - attn_mask: mask of shape L x L
+
+ Output
+ - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
+ """
+ if self.pre_transformer_layer:
+ tokens = self.pre_transformer_layer(tokens)
+ if use_checkpoint and checkpoint_blk_ids is None:
+ checkpoint_blk_ids = [
+ blk_id
+ for blk_id in range(len(self.blocks))
+ if blk_id % checkpoint_every_n == 0
+ ]
+ if checkpoint_blk_ids:
+ checkpoint_blk_ids = set(checkpoint_blk_ids)
+ for blk_id, blk in enumerate(self.blocks):
+ if use_checkpoint and blk_id in checkpoint_blk_ids:
+ tokens = checkpoint.checkpoint(
+ blk, tokens, attn_mask, use_reentrant=False
+ )
+ else:
+ tokens = blk(tokens, attn_mask=attn_mask)
+ if self.post_transformer_layer:
+ tokens = self.post_transformer_layer(tokens)
+ return tokens
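+
+
+# Rough usage sketch (dims, head count and bias flags are illustrative, not fixed by this diff):
+#   attn = partial(MultiheadAttention, embed_dim=768, num_heads=12, bias=True, add_bias_kv=True)
+#   trunk = SimpleTransformer(attn_target=attn, embed_dim=768, num_blocks=12)
+#   out = trunk(tokens)  # tokens laid out as L x N x 768 (nn.MultiheadAttention default)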
\ No newline at end of file
diff --git a/minigpt4/configs/models/bindgpt4.yaml b/minigpt4/configs/models/bindgpt4.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
new file mode 100644
index 0000000..605d2f2
--- /dev/null
+++ b/minigpt4/models/bind_gpt4.py
@@ -0,0 +1,200 @@
+import random
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+from transformers import LlamaTokenizer
+
+from imagebind.models.image_bind import imagebind_huge, ImageBindJoiner, ModalityType
+from minigpt4.common.registry import registry
+from minigpt4.models.blip2 import BaseModel
+from minigpt4.models.modeling_llama import LlamaForCausalLM
+
+
+@registry.register_model("bind_gpt4")
+class BindGPT4(BaseModel):
+ """
+ ImageBind GPT-LLAMA model.
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain_vicuna": "configs/models/bindgpt4.yaml",
+ }
+
+ def __init__(
+ self,
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+ freeze_imagebind=True,
+ freeze_qformer=False,
+ num_query_token=32,
+ llama_model="",
+ prompt_path="",
+ prompt_template="",
+ max_txt_len=32,
+ end_sym='\n',
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0 # the device of 8bit model should be set when loading and cannot be changed anymore.
+ ):
+ super().__init__()
+ assert not low_resource, "Low Resource Mode is Currently Unavailable."
+
+ self.low_resource = low_resource
+
+ print('Loading ImageBind')
+ self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind)
+ print('Loading ImageBind Done')
+
+ print('Loading Q-Former and Adapter/Projector')
+ self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
+ vision_qformer_frozen=freeze_qformer
+ # vision_qformer_model=q_former_model,
+ # vision_pre_dims=(1280, 1408)
+ )
+ print('Loading Q-Former and Adapter/Projector Done')
+
+ print('Loading LLAMA')
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,
+ )
+
+ for name, param in self.llama_model.named_parameters():
+ param.requires_grad = False
+ print('Loading LLAMA Done')
+
+ self.max_txt_len = max_txt_len
+ self.end_sym = end_sym
+
+ print("Preparing Prompts")
+ if prompt_path:
+ with open(prompt_path, 'r') as f:
+ raw_prompts = f.read().splitlines()
+ self.prompt_list = [prompt_template.format(p) for p in raw_prompts]
+ print('Load {} training prompts'.format(len(self.prompt_list)))
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+ else:
+ self.prompt_list = []
+ print("Preparing Prompts Done")
+
+ def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
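+        # ImageBind produces one feature tensor per modality; the joiner (Q-Former plus
+        # projector, built in __init__) maps each of them into the LLaMA embedding space.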
+ imagebind_outputs = self.multimodal_encoder(inputs)
+ llama_inputs = self.multimodal_joiner(imagebind_outputs)
+ return llama_inputs
+
+ def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
+ # TODO: Accept More Modalities.
+ input_embeds = inputs[modality_name]
+ attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
+ if prompt:
+ batch_size = input_embeds.shape[0]
+ p_before, p_after = prompt.split('<{}Here>'.format(modality_name))
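+            # each prompt must contain exactly one '<{modality_name}Here>' placeholder,
+            # e.g. '<VisionHere>' when modality_name is 'Vision'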
+ p_before_tokens = self.llama_tokenizer(
+ p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
+ p_after_tokens = self.llama_tokenizer(
+ p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
+ p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
+ p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
+            wrapped_input_embeds = torch.cat([p_before_embeds, input_embeds, p_after_embeds], dim=1)
+ wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
+ return wrapped_input_embeds, wrapped_atts_input
+ else:
+ return input_embeds, attns_input
+
+ def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ """
+ TODO: More Modalities.
+ Only accept image inputs here.
+ Other modalities will conflict with the pre-defined prompt and wrapping strategy.
+ """
+ embeds = self.encode_inputs(inputs)
+ assert "Vision" in embeds
+ prompt = random.choice(self.prompt_list)
+ img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
+
+        # NOTE: apart from the autocast handling, the code below follows MiniGPT-4 unchanged.
+
+ self.llama_tokenizer.padding_side = "right"
+
+ text = [t + self.end_sym for t in inputs["text_input"]]
+
+ to_regress_tokens = self.llama_tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_txt_len,
+ add_special_tokens=False
+ ).to(img_embeds.device)
+
+ targets = to_regress_tokens.input_ids.masked_fill(
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+ )
+
+ empty_targets = (
+ torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
+ dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos
+ )
+ targets = torch.cat([empty_targets, targets], dim=1)
+
+ batch_size = img_embeds.shape[0]
+ bos = torch.ones([batch_size, 1],
+ dtype=to_regress_tokens.input_ids.dtype,
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
+ atts_bos = atts_img[:, :1]
+
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
+ inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
+ attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
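+        # sequence layout: [BOS] + [prompt-wrapped image embeds] + [caption embeds];
+        # targets are -100 (ignored) everywhere except the non-padded caption tokens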
+
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ )
+ loss = outputs.loss
+
+ return {"loss": loss}
+
+ @classmethod
+ def from_config(cls, cfg):
+ q_former_model = cfg.get("q_former_model",
+ "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
+ num_query_token = cfg.get("num_query_token")
+ llama_model = cfg.get("llama_model")
+
+ freeze_imagebind = cfg.get("freeze_imagebind", True)
+ freeze_qformer = cfg.get("freeze_qformer", True)
+ low_resource = cfg.get("low_resource", False)
+ device_8bit = cfg.get("device_8bit", 0)
+
+ prompt_path = cfg.get("prompt_path", "")
+ prompt_template = cfg.get("prompt_template", "")
+ max_txt_len = cfg.get("max_txt_len", 32)
+ end_sym = cfg.get("end_sym", '\n')
+
+ model = cls(
+ q_former_model=q_former_model,
+ freeze_imagebind=freeze_imagebind,
+ freeze_qformer=freeze_qformer,
+ num_query_token=num_query_token,
+ llama_model=llama_model,
+ prompt_path=prompt_path,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ end_sym=end_sym,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load ImageBind-LLM Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(ckpt['model'], strict=False)
+
+ return model
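+
+
+# Illustrative config sketch for from_config (paths and the prompt template are placeholders,
+# not values shipped with this diff):
+#   num_query_token: 32
+#   llama_model: "/path/to/vicuna-7b"
+#   freeze_imagebind: True
+#   freeze_qformer: False
+#   prompt_path: "prompts/alignment.txt"
+#   prompt_template: "###Human: {} ###Assistant: "
+#   max_txt_len: 32
+#   ckpt: "/path/to/stage1_checkpoint.pth"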
diff --git a/prompts/alignment.txt b/prompts/alignment.txt
index 38ae75a..90ae57b 100644
--- a/prompts/alignment.txt
+++ b/prompts/alignment.txt
@@ -1,4 +1,4 @@
-<Img><ImageHere></Img> Describe this image in detail.
-<Img><ImageHere></Img> Take a look at this image and describe what you notice.
-<Img><ImageHere></Img> Please provide a detailed description of the picture.
-<Img><ImageHere></Img> Could you describe the contents of this image for me?
\ No newline at end of file
+<VisionHere> Describe this image in detail.
+<VisionHere> Take a look at this image and describe what you notice.
+<VisionHere> Please provide a detailed description of the picture.
+<VisionHere> Could you describe the contents of this image for me?
\ No newline at end of file