Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git, synced 2025-04-04 18:10:47 +00:00

modularize minigpt4 code

commit 045a1d0602, parent 90b7b00268
minigpt4/models/base_model.py
@@ -5,19 +5,26 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

-import contextlib
-import logging
import os
+import logging
+import contextlib

+from omegaconf import OmegaConf
import numpy as np
import torch
import torch.nn as nn
-from transformers import BertTokenizer
+from transformers import BertTokenizer, LlamaTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)

from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from minigpt4.common.utils import get_abs_path, is_url
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
from minigpt4.models.eva_vit import create_eva_vit_g
-from omegaconf import OmegaConf


class BaseModel(nn.Module):

@@ -121,12 +128,6 @@ class BaseModel(nn.Module):
        else:
            return tot

    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
@@ -137,33 +138,74 @@ class BaseModel(nn.Module):
        else:
            return contextlib.nullcontext()
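
A minimal usage sketch (the attribute names below are illustrative of a typical subclass): wrapping a vision forward pass so it runs under fp16 autocast on GPU and falls back to a no-op context on CPU:

    # inside a BaseModel subclass; `image` is a batch of pixel tensors
    with self.maybe_autocast():
        image_embeds = self.ln_vision(self.visual_encoder(image))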

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
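
A short usage sketch of the classmethod above; 1408 is used here as an illustrative vision width (the EVA ViT-g feature dimension), and the hidden size comes from bert-base-uncased:

    Qformer, query_tokens = BaseModel.init_Qformer(
        num_query_token=32, vision_width=1408, cross_attention_freq=2
    )
    print(query_tokens.shape)  # torch.Size([1, 32, 768]); expanded per batch at run time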

    @classmethod
    def init_vision_encoder(
-        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
    ):
        logging.info('Loading VIT')

        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        if not freeze:
            precision = "fp32"  # fp16 is not for training

        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)

        if freeze:
            for name, param in visual_encoder.named_parameters():
                param.requires_grad = False
            visual_encoder = visual_encoder.eval()
            visual_encoder.train = disabled_train
            for name, param in ln_vision.named_parameters():
                param.requires_grad = False
            ln_vision = ln_vision.eval()
            ln_vision.train = disabled_train
            logging.info("freeze vision encoder")

        logging.info('Loading VIT Done')
        return visual_encoder, ln_vision
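
A sketch of calling the updated signature; note that freeze=False forces fp32, since the fp16 weights are treated as inference-only:

    visual_encoder, ln_vision = BaseModel.init_vision_encoder(
        "eva_clip_g", img_size=224, drop_path_rate=0,
        use_grad_checkpoint=False, precision="fp16", freeze=True,
    )
    assert not any(p.requires_grad for p in visual_encoder.parameters())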

    @classmethod
    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
                 **lora_kargs):
        logging.info('Loading LLAMA')
        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
        llama_tokenizer.pad_token = "$$"

        if low_resource:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': low_res_device}
            )
        else:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            llama_model = prepare_model_for_int8_training(llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                bias="none",
                task_type="CAUSAL_LM",
                **lora_kargs
            )
            llama_model = get_peft_model(llama_model, loraconfig)

            llama_model.print_trainable_parameters()

        else:
            for name, param in llama_model.named_parameters():
                param.requires_grad = False
        logging.info('Loading LLAMA Done')
        return llama_model, llama_tokenizer
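
A hedged sketch of enabling LoRA through this helper; the checkpoint path is a placeholder, and the extra keywords are forwarded into LoraConfig via **lora_kargs:

    llama_model, llama_tokenizer = BaseModel.init_llm(
        "path/to/llama-checkpoint",       # placeholder path
        lora_r=8,                         # lora_r > 0 switches LoRA on
        lora_alpha=16,                    # forwarded to LoraConfig
        target_modules=["q_proj", "v_proj"],
    )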


    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
@@ -185,136 +227,6 @@ class BaseModel(nn.Module):
        return msg



class BaseEncoder(nn.Module):
    """
    Base class for primitive encoders, such as ViT, TimeSformer, etc.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        raise NotImplementedError

    @property
    def device(self):
        return list(self.parameters())[0].device


class SharedQueueMixin:
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = concat_all_gather(idxs)
            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr
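
_dequeue_and_enqueue assumes the host model has registered the queues and pointer as buffers; a minimal setup sketch (the class name and sizes are illustrative, following the usual ALBEF/BLIP convention):

    class ContrastiveModel(SharedQueueMixin, nn.Module):
        def __init__(self, embed_dim=256, queue_size=57600):
            super().__init__()
            self.queue_size = queue_size
            # one column per cached sample, so queues are (embed_dim, queue_size)
            self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
            self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))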


class MomentumDistilationMixin:
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )
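
The mixin expects the host class to define self.model_pairs and self.momentum; a sketch of that contract (the module names and the 0.995 coefficient are illustrative):

    # inside the host model's __init__
    self.momentum = 0.995
    self.model_pairs = [
        [self.visual_encoder, self.visual_encoder_m],  # online / momentum copy
    ]
    self.copy_params()  # momentum copies start as frozen exact clones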


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # if use distributed training
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
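
The two gather helpers differ only in gradient handling, so the choice is about whether the gathered tensors enter the loss; a sketch (to be run under torchrun with a process group initialized):

    # feats: (B, D) embeddings computed on this rank
    negatives = concat_all_gather(feats)      # detached: fine for negative keys
    all_feats = all_gather_with_grad(feats)   # keeps the graph: fine for loss terms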


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))
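
tile repeats each slice of x along dim consecutively, i.e. it behaves like torch.repeat_interleave; a worked example:

    x = torch.tensor([[1, 2], [3, 4]])
    tile(x, dim=0, n_tile=3)
    # tensor([[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]])
    # same result as x.repeat_interleave(3, dim=0)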


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
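
Assigning disabled_train over a frozen module's train method makes later model.train() sweeps no-ops for that submodule, so a trainer cannot accidentally flip it back to training mode:

    encoder = encoder.eval()
    encoder.train = disabled_train  # subsequent .train() calls leave eval mode intact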

minigpt4/models/minigpt4.py
@@ -8,6 +8,7 @@ import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel, disabled_train
+from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

@@ -31,6 +32,99 @@ class MiniGPT4(MiniGPTBase):
        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
        )

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features, freeze_qformer
            )
            self.load_from_pretrained(url_or_filename=q_former_model)  # load q-former weights here

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []
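
A sketch of the prompt plumbing above; the file line and template here are illustrative, not the shipped ones. Each usable line of prompt_path must contain the <ImageHere> placeholder, and prompt_template wraps it into the conversation format:

    raw_prompt = "Take a look at this image <ImageHere> and describe what you notice."
    prompt_template = "###Human: {} ###Assistant: "
    print(prompt_template.format(raw_prompt))
    # ###Human: Take a look at this image <ImageHere> and describe what you notice. ###Assistant: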

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, freeze):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 2
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        Qformer.cls = None
        Qformer.bert.embeddings.word_embeddings = None
        Qformer.bert.embeddings.position_embeddings = None
        for layer in Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        if freeze:
            for name, param in Qformer.named_parameters():
                param.requires_grad = False
            Qformer = Qformer.eval()
            Qformer.train = disabled_train
            query_tokens.requires_grad = False
            logging.info("freeze Qformer")

        return Qformer, query_tokens
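
Unlike the BaseModel version, this override strips the text-only parts of the Q-Former (LM head, word/position embeddings, per-layer feed-forward blocks), since only the query branch with cross-attention is used, and can freeze everything in place. A usage sketch (1408 again stands in for the EVA ViT-g feature width):

    Qformer, query_tokens = MiniGPT4.init_Qformer(32, 1408, freeze=True)
    assert not query_tokens.requires_grad
    assert all(not p.requires_grad for p in Qformer.parameters())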

    def encode_img(self, image):
        device = image.device

@@ -82,9 +176,6 @@ class MiniGPT4(MiniGPTBase):
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

-        lora_r = cfg.get("lora_r", 0)
-        lora_alpha = cfg.get("lora_alpha", 32)

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
@@ -103,8 +194,6 @@ class MiniGPT4(MiniGPTBase):
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
-            lora_r=lora_r,
-            lora_alpha=lora_alpha,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4

minigpt4/models/minigpt_base.py
@@ -13,9 +13,7 @@ from transformers import LlamaTokenizer
from peft import (
    LoraConfig,
    get_peft_model,
-    get_peft_model_state_dict,
    prepare_model_for_int8_training,
-    set_peft_model_state_dict,
)


@@ -27,131 +25,41 @@ class MiniGPTBase(BaseModel):
    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
-        lora_r=0,
+        lora_r=0,  # lora_r means lora is not used
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource
        self.llama_model, self.llama_tokenizer = self.init_llm(
            llama_model_path=llama_model,
            low_resource=low_resource,
            low_res_device=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features
            )
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.load_from_pretrained(url_or_filename=q_former_model)

            if freeze_qformer:
                for name, param in self.Qformer.named_parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.Qformer.train = disabled_train
                self.query_tokens.requires_grad = False
                logging.info("freeze Qformer")

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit}
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            self.llama_model = prepare_model_for_int8_training(self.llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.llama_model = get_peft_model(self.llama_model, loraconfig)

            # if ckpt_path:
            #     print('load the llm under lora')
            #     ckpt = torch.load(ckpt_path)
            #     set_peft_model_state_dict(self.llama_model, ckpt)
            self.llama_model.print_trainable_parameters()

        else:
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
-            self.prompt_list = []
+        self.prompt_list = []

    def vit_to_cpu(self):
        self.ln_vision.to("cpu")