mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 02:50:47 +00:00
141 lines
3.9 KiB
Python
141 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class Normalize(nn.Module):
    """L2-normalize the input along a fixed dimension.

    Args:
        dim: Dimension along which ``x`` is rescaled to unit Euclidean norm
            (via ``torch.nn.functional.normalize``, which floors the norm at
            its default ``eps``).
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` divided by its L2 norm along ``self.dim``."""
        normalize = nn.functional.normalize
        return normalize(x, p=2, dim=self.dim)
|
|
|
|
|
|
class LearnableLogitScaling(nn.Module):
    """Scale inputs by ``exp(log_logit_scale)``, clipped from above.

    The scale is stored in log space as ``log_logit_scale`` (a 0-dim tensor)
    and capped at ``max_logit_scale`` each time ``forward`` runs.

    Args:
        logit_scale_init: Initial (non-log) value of the scale.
        learnable: If True, ``log_logit_scale`` is an ``nn.Parameter``;
            otherwise it is registered as a buffer and not trained.
        max_logit_scale: Upper bound applied to ``exp(log_logit_scale)``.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        # 0-dim tensor holding log(logit_scale_init).
        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            # A buffer follows the module across .to()/.cuda() calls but is
            # not returned by parameters(), so optimizers leave it fixed.
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x):
        """Return ``x * min(exp(log_logit_scale), max_logit_scale)``."""
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self):
        # Every field is separated by ", " so the printed repr is uniform.
        return (
            f"logit_scale_init={self.logit_scale_init}, "
            f"learnable={self.learnable}, "
            f"max_logit_scale={self.max_logit_scale}"
        )
|
|
|
|
|
|
class EinOpsRearrange(nn.Module):
    """Module wrapper that applies one fixed ``einops.rearrange`` pattern.

    Args:
        rearrange_expr: einops pattern string passed to ``einops.rearrange``.
        **kwargs: Extra keyword arguments (axis lengths) forwarded to
            ``einops.rearrange`` on every call.
    """

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        """Rearrange tensor ``x`` with the stored pattern and axis lengths."""
        assert isinstance(x, torch.Tensor)
        pattern, axes_lengths = self.rearrange_expr, self.kwargs
        return einops.rearrange(x, pattern, **axes_lengths)
|
|
|
|
|
|
class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.

    ``extra_repr`` lists only the parameters and buffers registered directly
    on this module. Tensors owned by child modules are left to the children's
    own repr lines.
    """

    @staticmethod
    def get_readable_tensor_repr(
        name: str, tensor: "torch.Tensor | tuple[str, torch.Tensor]"
    ) -> str:
        """Format one tensor as ``(name): tensor(shape, requires_grad=...)``.

        Args:
            name: Label printed inside the leading parentheses.
            tensor: Either a tensor, or a ``(name, tensor)`` pair as yielded
                by ``named_parameters()`` / ``named_buffers()``.

        Returns:
            A single newline-terminated line.
        """
        # Existing callers pass the (name, tensor) pair; a bare tensor is also
        # accepted, as the annotation always promised. Previously a bare
        # tensor would have been indexed and reported the shape of its row 1.
        if not isinstance(tensor, torch.Tensor):
            tensor = tensor[1]
        return (
            f"({name}): tensor({tuple(tensor.shape)}, "
            f"requires_grad={tensor.requires_grad})\n"
        )

    def extra_repr(self) -> str:
        """Describe this module's own parameters, then its own buffers."""
        # recurse=False keeps both listings to tensors registered on this
        # module itself. Filtering only parameters (and not buffers) would
        # print a child's buffer such as "bn.running_mean" as "(bn): ...".
        own_tensors = list(self.named_parameters(recurse=False))
        own_tensors += list(self.named_buffers(recurse=False))
        return "".join(
            self.get_readable_tensor_repr(name, tensor) for name, tensor in own_tensors
        )
|
|
|
|
|
|
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast ``tensor`` to ``tgt_dtype`` only if its dtype is ``src_dtype``.

    Returns:
        A ``(tensor, updated)`` tuple. ``updated`` is True exactly when the
        dtype matched and ``.to(dtype=tgt_dtype)`` was applied. Otherwise the
        input tensor object is returned as-is.
    """
    if tensor.dtype != src_dtype:
        return tensor, False
    return tensor.to(dtype=tgt_dtype), True
|
|
|
|
|
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
|
|
|
|
|
|
class SelectElement(nn.Module):
    """Index dimension 1 of an input that has at least 3 dimensions.

    Args:
        index: Any index accepted by tensor subscripting (int, slice, ...),
            applied to dim 1 with all trailing dims kept.
    """

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x):
        assert x.ndim >= 3
        # Equivalent to x[:, self.index, ...].
        selector = (slice(None), self.index, Ellipsis)
        return x[selector]
|
|
|
|
|
|
class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP: gather one token per sequence, then project.
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        """Gather ``x[b, seq_len[b]]`` for every batch row and apply ``self.proj``.

        Args:
            x: Rank-3 input, treated as B x L x D.
            seq_len: Per-sample position along L to gather. Presumably the
                index of each sequence's end-of-text token; TODO confirm
                against the caller.
        """
        assert x.ndim == 3
        rows = torch.arange(x.shape[0])
        pooled = x[rows, seq_len]
        return self.proj(pooled)