mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 02:50:47 +00:00
46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
from torch import nn, Tensor
|
|
|
|
from typing import Union, Optional, Tuple
|
|
|
|
|
|
class BaseProjector(nn.Module):
    """Common parent for the projector modules defined in this file.

    Concrete projectors subclass this and override :meth:`forward`; the base
    implementation itself cannot be run.
    """

    def __init__(self):
        # No parameters of its own; only the nn.Module machinery is set up.
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x``; subclasses must provide the implementation.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class LinearProjector(BaseProjector):
    """Single affine projection from ``in_dim`` to ``out_dim`` features.

    The projection is one ``nn.Linear`` (with bias) stored as ``self.fc``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layer = nn.Linear(in_dim, out_dim)
        self.fc = layer

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear layer to ``x`` (its last dimension must be ``in_dim``)."""
        projected = self.fc(x)
        return projected
|
|
|
|
|
|
class AdapterProjector(BaseProjector):
    """Bias-free two-layer MLP: Linear -> ReLU -> Linear -> ReLU.

    Maps ``in_dim`` -> ``mid_dim`` -> ``out_dim`` features. Because the last
    layer is a ReLU, every output value is non-negative.
    """

    def __init__(self, in_dim, mid_dim, out_dim):
        super().__init__()
        # Layer order fixes the Sequential indices, so the learnable weights
        # live under state_dict keys ``fc.0.weight`` and ``fc.2.weight``.
        layers = [
            nn.Linear(in_dim, mid_dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, out_dim, bias=False),
            # NOTE(review): a trailing ReLU clamps the projected features to
            # >= 0; confirm this is intended for the downstream consumer.
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the MLP (its last dimension must be ``in_dim``)."""
        projected = self.fc(x)
        return projected
|
|
|
|
|
|
def create_projectors(dims):
    """Build a projector module from a sequence of layer widths.

    Args:
        dims: Sequence of ints selecting the projector type:
            - empty (or ``None``): no projection, returns ``nn.Identity()``;
            - ``(in_dim, out_dim)``: a ``LinearProjector``;
            - ``(in_dim, mid_dim, out_dim)``: an ``AdapterProjector``.

    Returns:
        An ``nn.Module`` that projects inputs according to ``dims``, or passes
        them through unchanged in the identity case.

    Raises:
        NotImplementedError: if ``dims`` has any other length.
    """
    # A missing config is treated like an empty one: no projection at all.
    if dims is None:
        return nn.Identity()
    n_dims = len(dims)
    if n_dims == 0:
        return nn.Identity()
    if n_dims == 2:
        return LinearProjector(*dims)
    if n_dims == 3:
        return AdapterProjector(*dims)
    # Same exception type as before, but with a message that makes a bad
    # config diagnosable.
    raise NotImplementedError(
        f"create_projectors supports 0, 2 or 3 dims, got {n_dims}: {dims!r}"
    )
|