mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-08 03:50:46 +00:00)

commit 045a1d0602 (parent 90b7b00268)

    modularize minigpt4 code

@@ -5,19 +5,26 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """

-import contextlib
-import logging
 import os
+import logging
+import contextlib

+from omegaconf import OmegaConf
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer
+from transformers import BertTokenizer, LlamaTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)

 from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
 from minigpt4.common.utils import get_abs_path, is_url
-from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from minigpt4.models.eva_vit import create_eva_vit_g
-from omegaconf import OmegaConf


 class BaseModel(nn.Module):

@@ -121,12 +128,6 @@ class BaseModel(nn.Module):
         else:
             return tot

-    @classmethod
-    def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
-        return tokenizer
-
     def maybe_autocast(self, dtype=torch.float16):
         # if on cpu, don't use autocast
         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16

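Note: the BLIP-style BERT tokenizer helper init_tokenizer is dropped from BaseModel here. If any
other code still needs that tokenizer, the removed helper amounted to the snippet below (a sketch
restating the deleted lines, not part of this diff):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_special_tokens({"bos_token": "[DEC]"})
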
@@ -137,33 +138,74 @@ class BaseModel(nn.Module):
         else:
             return contextlib.nullcontext()

-    @classmethod
-    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
-        encoder_config.encoder_width = vision_width
-        # insert cross-attention layer every other block
-        encoder_config.add_cross_attention = True
-        encoder_config.cross_attention_freq = cross_attention_freq
-        encoder_config.query_length = num_query_token
-        Qformer = BertLMHeadModel(config=encoder_config)
-        query_tokens = nn.Parameter(
-            torch.zeros(1, num_query_token, encoder_config.hidden_size)
-        )
-        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
-        return Qformer, query_tokens
-
     @classmethod
     def init_vision_encoder(
-        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
     ):
+        logging.info('Loading VIT')
+
         assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+        if not freeze:
+            precision = "fp32"  # fp16 is not for training
+
         visual_encoder = create_eva_vit_g(
             img_size, drop_path_rate, use_grad_checkpoint, precision
         )

         ln_vision = LayerNorm(visual_encoder.num_features)

+        if freeze:
+            for name, param in visual_encoder.named_parameters():
+                param.requires_grad = False
+            visual_encoder = visual_encoder.eval()
+            visual_encoder.train = disabled_train
+            for name, param in ln_vision.named_parameters():
+                param.requires_grad = False
+            ln_vision = ln_vision.eval()
+            ln_vision.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        logging.info('Loading VIT Done')
         return visual_encoder, ln_vision

+    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
+                 **lora_kargs):
+        logging.info('Loading LLAMA')
+        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
+        llama_tokenizer.pad_token = "$$"
+
+        if low_resource:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map={'': low_res_device}
+            )
+        else:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+            )
+
+        if lora_r > 0:
+            llama_model = prepare_model_for_int8_training(llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                bias="none",
+                task_type="CAUSAL_LM",
+                **lora_kargs
+            )
+            llama_model = get_peft_model(llama_model, loraconfig)
+
+            llama_model.print_trainable_parameters()
+
+        else:
+            for name, param in llama_model.named_parameters():
+                param.requires_grad = False
+        logging.info('Loading LLAMA Done')
+        return llama_model, llama_tokenizer
+
     def load_from_pretrained(self, url_or_filename):
         if is_url(url_or_filename):
             cached_file = download_cached_file(

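Note: the two helpers added above replace per-model copies of the same logic. init_vision_encoder
now also handles freezing (and forces fp32 when the ViT is left trainable), while init_llm loads
the LLaMA checkpoint, optionally in 8-bit on a single device, and forwards any extra keyword
arguments to PEFT's LoraConfig whenever lora_r > 0. A minimal sketch of how a subclass calls them
(the checkpoint path is a placeholder; lora_r=0 keeps the LLM fully frozen, as MiniGPT-4 v1 does):

    # inside a BaseModel subclass __init__ -- illustrative only
    self.llama_model, self.llama_tokenizer = self.init_llm(
        llama_model_path="/path/to/vicuna-7b",   # placeholder checkpoint directory
        low_resource=False,
        low_res_device=0,
        lora_r=0,                                # 0 = no LoRA; every LLaMA weight stays frozen
    )
    self.visual_encoder, self.ln_vision = self.init_vision_encoder(
        "eva_clip_g", img_size=224, drop_path_rate=0,
        use_grad_checkpoint=False, precision="fp16", freeze=True,
    )
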
@@ -185,136 +227,6 @@ class BaseModel(nn.Module):
         return msg


-class BaseEncoder(nn.Module):
-    """
-    Base class for primitive encoders, such as ViT, TimeSformer, etc.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward_features(self, samples, **kwargs):
-        raise NotImplementedError
-
-    @property
-    def device(self):
-        return list(self.parameters())[0].device
-
-
-class SharedQueueMixin:
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.queue_ptr)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
-
-        if idxs is not None:
-            idxs = concat_all_gather(idxs)
-            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
-
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-        self.queue_ptr[0] = ptr
-
-
-class MomentumDistilationMixin:
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not update by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data = param_m.data * self.momentum + param.data * (
-                    1.0 - self.momentum
-                )
-
-
-class GatherLayer(torch.autograd.Function):
-    """
-    Gather tensors from all workers with support for backward propagation:
-    This implementation does not cut the gradients as torch.distributed.all_gather does.
-    """
-
-    @staticmethod
-    def forward(ctx, x):
-        output = [
-            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
-        ]
-        torch.distributed.all_gather(output, x)
-        return tuple(output)
-
-    @staticmethod
-    def backward(ctx, *grads):
-        all_gradients = torch.stack(grads)
-        torch.distributed.all_reduce(all_gradients)
-        return all_gradients[torch.distributed.get_rank()]
-
-
-def all_gather_with_grad(tensors):
-    """
-    Performs all_gather operation on the provided tensors.
-    Graph remains connected for backward grad computation.
-    """
-    # Queue the gathered tensors
-    world_size = torch.distributed.get_world_size()
-    # There is no need for reduction in the single-proc case
-    if world_size == 1:
-        return tensors
-
-    # tensor_all = GatherLayer.apply(tensors)
-    tensor_all = GatherLayer.apply(tensors)
-
-    return torch.cat(tensor_all, dim=0)
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    # if use distributed training
-    if not is_dist_avail_and_initialized():
-        return tensor
-
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def tile(x, dim, n_tile):
-    init_dim = x.size(dim)
-    repeat_idx = [1] * x.dim()
-    repeat_idx[dim] = n_tile
-    x = x.repeat(*(repeat_idx))
-    order_index = torch.LongTensor(
-        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
-    )
-    return torch.index_select(x, dim, order_index.to(x.device))
-
-
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""

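Note: everything removed in this hunk (BaseEncoder, the contrastive-queue and momentum-distillation
mixins, the distributed gather helpers, tile) is BLIP-2/LAVIS-era machinery that the MiniGPT-4
model classes do not reference; only disabled_train survives, because the freezing logic above
relies on it. A small self-contained sketch of that idiom (not from this diff):

    import torch.nn as nn

    def disabled_train(self, mode=True):
        """Keep a frozen module in eval mode no matter what mode is requested."""
        return self

    encoder = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5)).eval()
    encoder.train = disabled_train          # same assignment as in init_vision_encoder
    wrapper = nn.Sequential(encoder)        # stand-in for the full model
    wrapper.train()                         # recursive .train(mode) hits the override and stops
    print(encoder.training)                 # False: dropout stays disabled
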
@@ -8,6 +8,7 @@ import torch.nn as nn
 from minigpt4.common.registry import registry
 from minigpt4.models.base_model import BaseModel, disabled_train
 from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer

@@ -31,6 +32,99 @@ class MiniGPT4(MiniGPTBase):
         "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
     }

+    def __init__(
+            self,
+            vit_model="eva_clip_g",
+            q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+            img_size=224,
+            drop_path_rate=0,
+            use_grad_checkpoint=False,
+            vit_precision="fp16",
+            freeze_vit=True,
+            has_qformer=True,
+            freeze_qformer=True,
+            num_query_token=32,
+            llama_model="",
+            prompt_path="",
+            prompt_template="",
+            max_txt_len=32,
+            end_sym='\n',
+            low_resource=False,  # use 8 bit and put vit in cpu
+            device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
+    ):
+        super().__init__(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            llama_model=llama_model,
+            max_txt_len=max_txt_len,
+            end_sym=end_sym,
+            low_resource=low_resource,
+            device_8bit=device_8bit,
+        )
+
+        self.has_qformer = has_qformer
+        if self.has_qformer:
+            print('Loading Q-Former')
+            self.Qformer, self.query_tokens = self.init_Qformer(
+                num_query_token, self.visual_encoder.num_features, freeze_qformer
+            )
+            self.load_from_pretrained(url_or_filename=q_former_model)  # load q-former weights here
+
+            img_f_dim = self.Qformer.config.hidden_size
+            print('Loading Q-Former Done')
+        else:
+            img_f_dim = self.visual_encoder.num_features * 4
+            print('Do not use Q-Former here.')
+
+        self.llama_proj = nn.Linear(
+            img_f_dim, self.llama_model.config.hidden_size
+        )
+
+        if prompt_path:
+            with open(prompt_path, 'r') as f:
+                raw_prompts = f.read().splitlines()
+            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
+            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
+            print('Load {} training prompts'.format(len(self.prompt_list)))
+            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+        else:
+            self.prompt_list = []
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, freeze):
+        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = 2
+        encoder_config.query_length = num_query_token
+        Qformer = BertLMHeadModel(config=encoder_config)
+        query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_token, encoder_config.hidden_size)
+        )
+        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+
+        Qformer.cls = None
+        Qformer.bert.embeddings.word_embeddings = None
+        Qformer.bert.embeddings.position_embeddings = None
+        for layer in Qformer.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+
+        if freeze:
+            for name, param in Qformer.named_parameters():
+                param.requires_grad = False
+            Qformer = Qformer.eval()
+            Qformer.train = disabled_train
+            query_tokens.requires_grad = False
+            logging.info("freeze Qformer")
+
+        return Qformer, query_tokens
+
     def encode_img(self, image):
         device = image.device

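Note: the Q-Former construction that previously lived in the shared base classes now belongs to
MiniGPT4 itself, including the trimming (cls head and text embeddings set to None) that turns the
BERT model into a pure visual-token resampler, plus the optional freezing. A hedged, standalone
sketch of what the new classmethod produces; the 1408-dim, 257-token dummy input mimics EVA ViT-g
output at 224 px and is an assumption, not part of this diff:

    import torch
    from minigpt4.models.minigpt4 import MiniGPT4

    Qformer, query_tokens = MiniGPT4.init_Qformer(
        num_query_token=32, vision_width=1408, freeze=True
    )
    image_embeds = torch.randn(1, 257, 1408)          # stand-in for ln_vision(visual_encoder(image))
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
    out = Qformer.bert(
        query_embeds=query_tokens.expand(image_embeds.shape[0], -1, -1),
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=True,
    )
    print(out.last_hidden_state.shape)                # torch.Size([1, 32, 768]), then fed to llama_proj
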
@@ -82,9 +176,6 @@ class MiniGPT4(MiniGPTBase):
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')

-        lora_r = cfg.get("lora_r", 0)
-        lora_alpha = cfg.get("lora_alpha", 32)
-
         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,

@@ -103,8 +194,6 @@ class MiniGPT4(MiniGPTBase):
             end_sym=end_sym,
             low_resource=low_resource,
             device_8bit=device_8bit,
-            lora_r=lora_r,
-            lora_alpha=lora_alpha,
         )

         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4

@@ -13,9 +13,7 @@ from transformers import LlamaTokenizer
 from peft import (
     LoraConfig,
     get_peft_model,
-    get_peft_model_state_dict,
     prepare_model_for_int8_training,
-    set_peft_model_state_dict,
 )

@@ -27,131 +25,41 @@ class MiniGPTBase(BaseModel):
     def __init__(
             self,
             vit_model="eva_clip_g",
-            q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
             img_size=224,
             drop_path_rate=0,
             use_grad_checkpoint=False,
             vit_precision="fp16",
             freeze_vit=True,
-            has_qformer=True,
-            freeze_qformer=True,
-            num_query_token=32,
             llama_model="",
-            prompt_path="",
-            prompt_template="",
             max_txt_len=32,
             end_sym='\n',
             low_resource=False,  # use 8 bit and put vit in cpu
             device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
-            lora_r=0,
+            lora_r=0,  # lora_r means lora is not used
             lora_target_modules=["q_proj", "v_proj"],
             lora_alpha=16,
             lora_dropout=0.05,
     ):
         super().__init__()

-        self.tokenizer = self.init_tokenizer()
-        self.low_resource = low_resource
+        self.llama_model, self.llama_tokenizer = self.init_llm(
+            llama_model_path=llama_model,
+            low_resource=low_resource,
+            low_res_device=device_8bit,
+            lora_r=lora_r,
+            lora_target_modules=lora_target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )

-        print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
         )
-        if freeze_vit:
-            for name, param in self.visual_encoder.named_parameters():
-                param.requires_grad = False
-            self.visual_encoder = self.visual_encoder.eval()
-            self.visual_encoder.train = disabled_train
-            for name, param in self.ln_vision.named_parameters():
-                param.requires_grad = False
-            self.ln_vision = self.ln_vision.eval()
-            self.ln_vision.train = disabled_train
-            logging.info("freeze vision encoder")
-        print('Loading VIT Done')
-
-        self.has_qformer = has_qformer
-        if self.has_qformer:
-            print('Loading Q-Former')
-            self.Qformer, self.query_tokens = self.init_Qformer(
-                num_query_token, self.visual_encoder.num_features
-            )
-            self.Qformer.cls = None
-            self.Qformer.bert.embeddings.word_embeddings = None
-            self.Qformer.bert.embeddings.position_embeddings = None
-            for layer in self.Qformer.bert.encoder.layer:
-                layer.output = None
-                layer.intermediate = None
-            self.load_from_pretrained(url_or_filename=q_former_model)
-
-            if freeze_qformer:
-                for name, param in self.Qformer.named_parameters():
-                    param.requires_grad = False
-                self.Qformer = self.Qformer.eval()
-                self.Qformer.train = disabled_train
-                self.query_tokens.requires_grad = False
-                logging.info("freeze Qformer")
-
-            img_f_dim = self.Qformer.config.hidden_size
-            print('Loading Q-Former Done')
-        else:
-            img_f_dim = self.visual_encoder.num_features * 4
-            print('Do not use Q-Former here.')
-
-        print('Loading LLAMA')
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
-        self.llama_tokenizer.pad_token = "$$"
-
-        if self.low_resource:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-                load_in_8bit=True,
-                device_map={'': device_8bit}
-            )
-        else:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-            )
-
-        if lora_r > 0:
-            self.llama_model = prepare_model_for_int8_training(self.llama_model)
-            loraconfig = LoraConfig(
-                r=lora_r,
-                lora_alpha=lora_alpha,
-                target_modules=lora_target_modules,
-                lora_dropout=lora_dropout,
-                bias="none",
-                task_type="CAUSAL_LM"
-            )
-            self.llama_model = get_peft_model(self.llama_model, loraconfig)
-
-            # if ckpt_path:
-            #     print('load the llm under lora')
-            #     ckpt = torch.load(ckpt_path)
-            #     set_peft_model_state_dict(self.llama_model,ckpt)
-            self.llama_model.print_trainable_parameters()
-
-        else:
-            for name, param in self.llama_model.named_parameters():
-                param.requires_grad = False
-        print('Loading LLAMA Done')
-
-        self.llama_proj = nn.Linear(
-            img_f_dim, self.llama_model.config.hidden_size
-        )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym

-        if prompt_path:
-            with open(prompt_path, 'r') as f:
-                raw_prompts = f.read().splitlines()
-            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
-            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
-            print('Load {} training prompts'.format(len(self.prompt_list)))
-            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
-        else:
-            self.prompt_list = []
+        self.prompt_list = []

     def vit_to_cpu(self):
         self.ln_vision.to("cpu")

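Note: after this refactor the construction flow is layered: MiniGPTBase.__init__ builds the LLaMA
side via BaseModel.init_llm and the ViT side via BaseModel.init_vision_encoder, while
MiniGPT4.__init__ adds the (optionally frozen) Q-Former, the llama_proj layer, and the prompt list
on top. A hedged sketch of direct construction with the new signatures (the checkpoint path is a
placeholder; training and demo code normally go through the config-driven from_config path shown
in the hunks above):

    from minigpt4.models.minigpt4 import MiniGPT4

    model = MiniGPT4(
        vit_model="eva_clip_g",
        img_size=224,
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="/path/to/vicuna-7b",   # placeholder HF checkpoint directory
        prompt_path="",                     # no prompt file, so prompt_list stays empty
        low_resource=False,
    )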