#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.nn as nn
import torchaudio
from imagebind.models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import (
    ConstantClipsPerVideoSampler,
    RandomMultiClipSampler,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds

BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"


def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    """Convert a [channels, samples] waveform into a log-mel spectrogram of
    shape [1, num_mel_bins, target_length], padding or truncating in time."""
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    waveform -= waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # If the mismatch is large (say > 20%), log a warning
    if abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )
    # Cut or pad the time axis to exactly target_length frames
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially a 1-channel image
    fbank = fbank.unsqueeze(0)
    return fbank
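
# A minimal usage sketch (the path "sample.wav" is a hypothetical stand-in for
# a real 16 kHz audio file):
#
#   waveform, sr = torchaudio.load("sample.wav")
#   fbank = waveform2melspec(waveform, sr, num_mel_bins=128, target_length=204)
#   assert fbank.shape == (1, 128, 204)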


def load_and_transform_vision_data(image_paths, device):
    """Load images from disk and apply the CLIP-style resize/crop/normalize
    transform, returning a [num_images, 3, 224, 224] tensor on `device`."""
    if image_paths is None:
        return None

    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    image_outputs = []
    for image_path in image_paths:
        with open(image_path, "rb") as fopen:
            image = Image.open(fopen).convert("RGB")

        image = data_transform(image).to(device)
        image_outputs.append(image)
    return torch.stack(image_outputs, dim=0)
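
# A minimal usage sketch (the image paths are hypothetical placeholders):
#
#   images = load_and_transform_vision_data(["cat.jpg", "dog.jpg"], "cpu")
#   assert images.shape == (2, 3, 224, 224)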


def load_and_transform_text(text, device):
    """Tokenize a list of strings with ImageBind's SimpleTokenizer and return
    a [num_texts, context_length] tensor of token ids on `device`."""
    if text is None:
        return None
    tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
    tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
    tokens = torch.cat(tokens, dim=0)
    return tokens
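
# A minimal usage sketch (requires the BPE vocabulary file at BPE_PATH):
#
#   tokens = load_and_transform_text(["a photo of a dog"], "cpu")
#   # tokens has shape [1, context_length]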


def load_and_transform_audio_data(
    audio_paths,
    device,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
):
    """Load audio files, resample them to `sample_rate`, cut each one into
    `clips_per_video` clips, and convert every clip to a normalized log-mel
    spectrogram. Returns a tensor of shape
    [num_files, clips_per_video, 1, num_mel_bins, target_length] on `device`."""
    if audio_paths is None:
        return None

    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )

    for audio_path in audio_paths:
        waveform, sr = torchaudio.load(audio_path)
        if sample_rate != sr:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=sample_rate
            )
        all_clips_timepoints = get_constant_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate): int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)

        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]

        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)

    return torch.stack(audio_outputs, dim=0)
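
# A minimal usage sketch (the path is a hypothetical placeholder):
#
#   specs = load_and_transform_audio_data(["speech.wav"], "cpu")
#   # With the defaults, specs has shape [1, 3, 1, 128, 204]:
#   # one file, three clips, each a 1-channel 128 x 204 log-mel "image".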


def get_constant_clip_timepoints(clip_sampler, duration):
    """Enumerate all (start, end) clip boundaries, in seconds, that a
    ConstantClipsPerVideoSampler produces for a recording of `duration`."""
    assert isinstance(
        clip_sampler, ConstantClipsPerVideoSampler
    ), "Incompatible Type of Sampler!"
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints
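
# A minimal usage sketch: three 2-second clips spread roughly evenly over a
# 10-second recording.
#
#   sampler = ConstantClipsPerVideoSampler(clip_duration=2, clips_per_video=3)
#   timepoints = get_constant_clip_timepoints(sampler, 10.0)
#   # e.g. [(0.0, 2.0), (~2.7, ~4.7), (~5.3, ~7.3)] -- exact start times
#   # depend on the pytorchvideo sampler's spacing rule.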


def get_random_clip_timepoints(clip_sampler, duration):
    """Sample random (start, end) clip boundaries with a RandomMultiClipSampler
    and return them sorted by start time."""
    assert isinstance(
        clip_sampler, RandomMultiClipSampler
    ), "Incompatible Type of Sampler!"
    starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None)
    all_clips_timepoints = sorted(zip(starts, ends), key=lambda x: x[0])
    return all_clips_timepoints
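
# Analogous usage with the random sampler (the constructor arguments shown
# here are illustrative; check the pytorchvideo signature):
#
#   sampler = RandomMultiClipSampler(clip_duration=2, num_clips=3)
#   timepoints = get_random_clip_timepoints(sampler, 10.0)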


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes
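
# A minimal usage sketch: shifting a single [x1, y1, x2, y2] box into the
# coordinate frame of a crop that starts at (x=5, y=10).
#
#   import numpy as np
#   boxes = np.array([[10.0, 20.0, 110.0, 120.0]])
#   crop_boxes(boxes, x_offset=5, y_offset=10)
#   # -> [[  5.  10. 105. 110.]]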


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size
            before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
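
# A minimal usage sketch: center crop (spatial_idx=1) of a wide frame stack.
#
#   frames = torch.randn(8, 3, 224, 320)  # [num frames, channel, H, W]
#   cropped, _ = uniform_crop(frames, size=224, spatial_idx=1)
#   assert cropped.shape == (8, 3, 224, 224)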


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with
    -2 in the spatial crop at the slowfast augmentation stage (so full
    frames are passed in here). Will return a larger list with the
    3x spatial crops as well.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError("Nothing else supported yet")

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
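
# A minimal usage sketch: three spatial crops per clip.
#
#   clips = [torch.randn(3, 2, 256, 320)]  # one [C, T, H, W] clip
#   crops = SpatialCrop(224, num_crops=3)(clips)
#   # len(crops) == 3, and each crop has shape [3, 2, 224, 224]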


def load_and_transform_video_data(
    video_paths,
    device,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
):
    """Load videos, sample `clips_per_video` temporal clips from each, subsample
    frames, normalize, and take 3 spatial crops per clip. Returns a tensor of
    shape [num_videos, clips_per_video * 3, C, T, 224, 224] on `device`."""
    if video_paths is None:
        return None

    video_outputs = []
    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    # Note: clip_duration doubles as the number of frames kept per clip here.
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    for video_path in video_paths:
        video = EncodedVideo.from_path(
            video_path,
            decoder="decord",
            decode_audio=False,
            sample_rate=sample_rate,
        )

        all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration)

        all_video = []
        for clip_timepoints in all_clips_timepoints:
            # Read the clip and temporally subsample its frames
            clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
            if clip is None:
                raise ValueError("No clip found")
            video_clip = frame_sampler(clip["video"])
            video_clip = video_clip / 255.0  # uint8 pixel values scaled to [0, 1]

            all_video.append(video_clip)

        all_video = [video_transform(clip) for clip in all_video]
        all_video = SpatialCrop(224, num_crops=3)(all_video)

        all_video = torch.stack(all_video, dim=0)
        video_outputs.append(all_video)

    return torch.stack(video_outputs, dim=0).to(device)
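
# A minimal usage sketch (the path is a hypothetical placeholder; requires the
# decord decoder to be installed):
#
#   clips = load_and_transform_video_data(["clip.mp4"], "cpu")
#   # With the defaults, clips has shape [1, 15, 3, 2, 224, 224]:
#   # 5 temporal clips x 3 spatial crops, 2 frames each.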