mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 19:10:45 +00:00
191 lines
6.8 KiB
Python
191 lines
6.8 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# -*- coding: utf-8 -*-
|
||
|
# @Author : Xinhao Mei @CVSSP, University of Surrey
|
||
|
# @E-mail : x.mei@surrey.ac.uk
|
||
|
|
||
|
"""
|
||
|
Implementation of SpecAugment++,
|
||
|
Adapted from Qiuqiang Kong's torchlibrosa:
|
||
|
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class DropStripes:
    """Zero out randomly placed stripes of a 4-D feature map along one axis (in place)."""

    def __init__(self, dim, drop_width, stripes_num):
        """ Drop stripes.

        args:
            dim: int, dimension along which to drop (2: time, 3: frequency)
            drop_width: int, exclusive upper bound on stripe width; each stripe's
                width is drawn uniformly from [0, drop_width) and clipped to the
                size of the axis. A value <= 0 disables dropping.
            stripes_num: int, how many stripes to drop per sample
        """
        super(DropStripes, self).__init__()

        assert dim in [2, 3]  # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Zero stripes of every sample in place and return ``input``.

        input: (batch_size, channels, time_steps, freq_bins)
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        for n in range(batch_size):
            self.transform_slice(input[n], total_width)

        return input

    def transform_slice(self, e, total_width):
        """Zero ``self.stripes_num`` random stripes of a single sample in place.

        e: (channels, time_steps, freq_bins)
        """
        # torch.randint(0, drop_width) raises for drop_width <= 0; there is
        # nothing to drop in that case, so leave the sample untouched.
        if self.drop_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = int(torch.randint(low=0, high=self.drop_width, size=(1,))[0])
            # A stripe can never be wider than the axis itself.
            distance = min(distance, total_width)
            # Valid starts are 0 .. total_width - distance inclusive, hence the +1.
            # Without it the last frame/bin could never be dropped, and a stripe
            # as wide as the axis made randint raise.
            bgn = int(torch.randint(low=0, high=total_width - distance + 1, size=(1,))[0])

            if self.dim == 2:
                e[:, bgn: bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn: bgn + distance] = 0
|
||
|
|
||
|
|
||
|
class MixStripes:
    """Average random stripes of each sample with the same stripes of another
    sample drawn from the mini-batch (in place)."""

    def __init__(self, dim, mix_width, stripes_num):
        """ Mix stripes

        args:
            dim: int, dimension along which to mix (2: time, 3: frequency)
            mix_width: int, exclusive upper bound on stripe width; each stripe's
                width is drawn uniformly from [0, mix_width) and clipped to the
                size of the axis. A value <= 0 disables mixing.
            stripes_num: int, how many stripes to mix per sample
        """
        super(MixStripes, self).__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Mix stripes of every sample in place and return ``input``.

        input: (batch_size, channel, time_steps, freq_bins)
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Indexing with a permutation tensor returns a copy, so rand_sample keeps
        # the un-augmented values while ``input`` is modified in place below.
        # The permutation may map a sample to itself, which leaves it unchanged.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Replace ``self.stripes_num`` random stripes of one sample with the
        0.5/0.5 average of itself and ``random_sample`` (in place).

        input, random_sample: (channel, time_steps, freq_bins)
        """
        # torch.randint(0, mix_width) raises for mix_width <= 0; nothing to mix.
        if self.mix_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = int(torch.randint(low=0, high=self.mix_width, size=(1,))[0])
            # A stripe can never be wider than the axis itself.
            distance = min(distance, total_width)
            # Valid starts are 0 .. total_width - distance inclusive, hence the +1.
            bgn = int(torch.randint(low=0, high=total_width - distance + 1, size=(1,))[0])
            end = bgn + distance

            if self.dim == 2:
                input[:, bgn:end, :] = 0.5 * input[:, bgn:end, :] + \
                    0.5 * random_sample[:, bgn:end, :]
            elif self.dim == 3:
                input[:, :, bgn:end] = 0.5 * input[:, :, bgn:end] + \
                    0.5 * random_sample[:, :, bgn:end]
|
||
|
|
||
|
|
||
|
class CutStripes:
    """Overwrite random stripes of each sample with the same stripes of another
    sample drawn from the mini-batch (in place)."""

    def __init__(self, dim, cut_width, stripes_num):
        """ Cutting stripes with another randomly selected sample in mini-batch.

        args:
            dim: int, dimension along which to cut (2: time, 3: frequency)
            cut_width: int, exclusive upper bound on stripe width; each stripe's
                width is drawn uniformly from [0, cut_width) and clipped to the
                size of the axis. A value <= 0 disables cutting.
            stripes_num: int, how many stripes to cut per sample
        """
        super(CutStripes, self).__init__()

        assert dim in [2, 3]

        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Cut stripes into every sample in place and return ``input``.

        input: (batch_size, channel, time_steps, freq_bins)
        """
        assert input.ndimension() == 4

        batch_size = input.shape[0]
        total_width = input.shape[self.dim]

        # Indexing with a permutation tensor returns a copy, so rand_sample keeps
        # the un-augmented values while ``input`` is modified in place below.
        # The permutation may map a sample to itself, which leaves it unchanged.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Copy ``self.stripes_num`` random stripes of ``random_sample`` into
        one sample (in place).

        input, random_sample: (channel, time_steps, freq_bins)
        """
        # torch.randint(0, cut_width) raises for cut_width <= 0; nothing to cut.
        if self.cut_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = int(torch.randint(low=0, high=self.cut_width, size=(1,))[0])
            # A stripe can never be wider than the axis itself.
            distance = min(distance, total_width)
            # Valid starts are 0 .. total_width - distance inclusive, hence the +1.
            bgn = int(torch.randint(low=0, high=total_width - distance + 1, size=(1,))[0])
            end = bgn + distance

            if self.dim == 2:
                input[:, bgn:end, :] = random_sample[:, bgn:end, :]
            elif self.dim == 3:
                input[:, :, bgn:end] = random_sample[:, :, bgn:end]
|
||
|
|
||
|
|
||
|
class SpecAugmentation:

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """Spec augmentation and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.

        Builds one stripe augmentator for the time axis (dim 2) and one for the
        frequency axis (dim 3), both of the class selected by ``mask_type``.

        Args:
            time_drop_width: int, width argument of the time-axis augmentator
            time_stripes_num: int, stripe count of the time-axis augmentator
            freq_drop_width: int, width argument of the frequency-axis augmentator
            freq_stripes_num: int, stripe count of the frequency-axis augmentator
            mask_type: str, mask type in SpecAugment++:
                'zero_value' -> DropStripes, 'mixture' -> MixStripes,
                'cutting' -> CutStripes

        Raises:
            NameError: if ``mask_type`` is not one of the names above.
        """
        super(SpecAugmentation, self).__init__()

        def install(stripe_cls, width_kw):
            # Time axis first, then frequency axis; each class names its
            # width argument differently, hence the keyword indirection.
            self.time_augmentator = stripe_cls(
                dim=2, stripes_num=time_stripes_num, **{width_kw: time_drop_width})
            self.freq_augmentator = stripe_cls(
                dim=3, stripes_num=freq_stripes_num, **{width_kw: freq_drop_width})

        if mask_type == 'zero_value':
            install(DropStripes, 'drop_width')
        elif mask_type == 'mixture':
            install(MixStripes, 'mix_width')
        elif mask_type == 'cutting':
            install(CutStripes, 'cut_width')
        else:
            raise NameError('No such mask type in SpecAugment++')

    def __call__(self, inputs):
        """Apply the time-axis augmentator, then the frequency-axis one.

        inputs: (batch_size, channel, time_steps, freq_bins)
        """
        return self.freq_augmentator(self.time_augmentator(inputs))
|