mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00

update minigpt_base.py except the initialization function
This commit is contained in:
parent
4ce9c5febe
commit
90b7b00268
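
For orientation, a minimal sketch, not taken from the commit, of how a model class sits on top of the new MiniGPTBase; the subclass name and registry key are hypothetical, while registry.register_model and PRETRAINED_MODEL_CONFIG_DICT follow the diff below.

# Hypothetical subclass sketch: model-specific classes now inherit the shared
# vision-encoder / LLaMA plumbing (prompt_wrap, forward, generate, ...) from
# MiniGPTBase and keep only their registration and default config mapping.
from minigpt4.common.registry import registry
from minigpt4.models.minigpt_base import MiniGPTBase


@registry.register_model("my_minigpt_variant")  # illustrative registry key
class MyMiniGPTVariant(MiniGPTBase):
    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }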
@@ -7,6 +7,7 @@ import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel, disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

@@ -20,7 +21,7 @@ from peft import (


@registry.register_model("mini_gpt4")
class MiniGPT4(BaseModel):
class MiniGPT4(MiniGPTBase):
    """
    MiniGPT-4 model
    """
@@ -30,146 +31,11 @@ class MiniGPT4(BaseModel):
        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features
            )
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.load_from_pretrained(url_or_filename=q_former_model)

            if freeze_qformer:
                for name, param in self.Qformer.named_parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.Qformer.train = disabled_train
                self.query_tokens.requires_grad = False
                logging.info("freeze Qformer")

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit}
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            self.llama_model = prepare_model_for_int8_training(self.llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.llama_model = get_peft_model(self.llama_model, loraconfig)

            # if ckpt_path:
            #     print('load the llm under lora')
            #     ckpt = torch.load(ckpt_path)
            #     set_peft_model_state_dict(self.llama_model,ckpt)
            self.llama_model.print_trainable_parameters()

        else:
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def encode_img(self, image):
        device = image.device
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
@@ -194,140 +60,6 @@ class MiniGPT4(BaseModel):
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def get_context_emb(self, prompt, img_list):
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts):
        if prompts:
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for each_img_embed, each_prompt in zip(img_embeds, prompts):
                p_before, p_after = each_prompt.split('<ImageHere>')

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
                wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
                emb_lists.append(wrapped_emb)
            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
            for i, emb in enumerate(emb_lists):
                wrapped_embs[i, :emb_lens[i]] = emb
                wrapped_atts[i, :emb_lens[i]] = 1
            return wrapped_embs, wrapped_atts
        else:
            return img_embeds, atts_img

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)

        if self.prompt_list:
            instruction = random.choice(self.prompt_list)
        else:
            instruction = samples["instruction_input"] if "instruction_input" in samples else None

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["answer"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(image.device)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, attention_mask], dim=1)

        part_targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        targets = (
            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                       dtype=torch.long).to(image.device).fill_(-100)
        )

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = outputs.loss

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
@@ -377,7 +109,7 @@ class MiniGPT4(BaseModel):

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
            print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)
minigpt4/models/minigpt_base.py (new file, 497 lines)
@@ -0,0 +1,497 @@
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel, disabled_train
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)


class MiniGPTBase(BaseModel):
    """
    Base class for MiniGPT-4 and MiniGPT-v2
    """

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features
            )
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.load_from_pretrained(url_or_filename=q_former_model)

            if freeze_qformer:
                for name, param in self.Qformer.named_parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.Qformer.train = disabled_train
                self.query_tokens.requires_grad = False
                logging.info("freeze Qformer")

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit}
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            self.llama_model = prepare_model_for_int8_training(self.llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.llama_model = get_peft_model(self.llama_model, loraconfig)

            # if ckpt_path:
            #     print('load the llm under lora')
            #     ckpt = torch.load(ckpt_path)
            #     set_peft_model_state_dict(self.llama_model,ckpt)
            self.llama_model.print_trainable_parameters()

        else:
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def get_context_emb(self, prompt, img_list):
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids  # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
        if prompts is None or len(prompts) == 0:
            # prompts is not provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompt is provided but there is no image embedding. return the prompt embedding in right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt
        else:
            # return the multi-modal embedding in right padding
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                pn = each_img_embed.shape[-2]
                if lengths is not None:
                    each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                    each_img_embed = each_img_embed[:lengths[idx] * pn]
                p_segs = each_prompt.split('<ImageHere>')
                interleave_emb = []
                for idx, seg in enumerate(p_segs[:-1]):
                    p_tokens = self.llama_tokenizer(
                        seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                    p_embed = self.embed_tokens(p_tokens.input_ids)
                    interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
                wrapped_emb = torch.cat(interleave_emb, dim=1)
                p_tokens = self.llama_tokenizer(
                    p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
                emb_lists.append(wrapped_emb)

            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

            max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
            wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)

            for i, emb in enumerate(emb_lists):
                length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
                wrapped_embs[i, :length] = emb[:, :length]
                wrapped_atts[i, :length] = 1
            return wrapped_embs, wrapped_atts

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def tokenize_conversation(self, conv_q, conv_a):
        """concatenate conversation and make sure the model is only trained to regress the answer"""

        to_regress_token_ids_list = []
        targets_list = []

        batch_size = len(conv_q)
        for batch_idx in range(batch_size):
            questions, answers = conv_q[batch_idx], conv_a[batch_idx]
            questions = [self.llama_tokenizer(q,
                                              return_tensors="pt",
                                              add_special_tokens=False).to(self.device) for q in questions[1:]]  # the first question is handled in the prompt wrap function, skip it
            answers = [self.llama_tokenizer(q,
                                            return_tensors="pt",
                                            add_special_tokens=False).to(self.device) for q in answers]
            cur_id = []
            cur_target = []
            for i in range(len(questions)):
                cur_id.append(answers[i].input_ids)
                cur_target.append(answers[i].input_ids)
                cur_id.append(questions[i].input_ids)
                cur_target.append(torch.ones_like(questions[i].input_ids) * -100)

            cur_id.append(answers[-1].input_ids)
            cur_target.append(answers[-1].input_ids)

            cur_id = torch.cat(cur_id, dim=1)
            cur_target = torch.cat(cur_target, dim=1)
            to_regress_token_ids_list.append(cur_id)
            targets_list.append(cur_target)

        max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
        to_regress_token_ids = torch.ones([batch_size, max_len],
                                          dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
        targets = torch.ones([batch_size, max_len],
                             dtype=cur_id.dtype, device=self.device) * -100
        for batch_idx in range(batch_size):
            cur_len = to_regress_token_ids_list[batch_idx].shape[1]
            to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
            targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]

        to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)

        return to_regress_token_ids, to_regress_token_attn, targets

    def preparing_embedding(self, samples):
        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None

        if 'conv_q' in samples:
            # handeling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym)for q in conv_q]
            conv_a = [a.split(connect_sym) for a in conv_a]

            conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]

            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
            regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

        else:
            if "instruction_input" in samples:
                instruction = samples["instruction_input"]
            elif self.prompt_list:
                instruction = random.choice(self.prompt_list)
            else:
                instruction = None

            if self.chat_template:
                instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]

            if 'length' in samples:
                # the input is a image train (like videos)
                bsz, pn, hs = img_embeds.shape
                img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
            else:
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

            ### prepare target tokens
            self.llama_tokenizer.padding_side = "right"
            text = [t + self.end_sym for t in samples["answer"]]

            regress_tokens = self.llama_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=False
            ).to(self.device)

            regress_token_ids = regress_tokens.input_ids
            regress_atts = regress_tokens.attention_mask
            part_targets = regress_token_ids.masked_fill(
                regress_token_ids == self.llama_tokenizer.pad_token_id, -100
            )

        regress_embeds = self.embed_tokens(regress_token_ids)

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets

    def forward(self, samples, reduction='mean'):
        # prepare the embedding to condition and the embedding to regress
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
            self.preparing_embedding(samples)

        # concat the embedding to condition and the embedding to regress
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

        # get bos token embedding
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = cond_atts[:, :1]

        # add bos token at the begining
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # ensemble the final targets
        targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                             dtype=torch.long).to(self.device).fill_(-100)

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction
            )
        loss = outputs.loss

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds

    @torch.no_grad()
    def generate(
        self,
        images,
        texts,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=[2],
    ):
        '''
        function for generate test use
        '''

        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
            stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])

        img_embeds, atts_img = self.encode_img(images.to(self.device))
        image_lists = [[image_emb[None]] for image_emb in img_embeds]

        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

        batch_size = len(batch_embs)
        max_len = max([emb.shape[1] for emb in batch_embs])
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

        with self.maybe_autocast():
            outputs = self.llama_model.generate(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                length_penalty=length_penalty,
                temperature=temperature,
                do_sample=do_sample,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty
                # stopping_criteria=stopping_criteria,
            )

        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
            output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
            output_texts = output_texts.replace("<s>", "")
            output_texts = output_texts.split(r'[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers

    @torch.no_grad()
    def multi_select(self, images, texts, answers, num_cand=None):
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer
            }
            loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            for i in range(all_losses.shape[0]):
                all_losses[i, num_cand[i]:] = 9999
        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()
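
Not part of the commit: a short usage sketch of the new MiniGPTBase.generate entry point, assuming `model` is an already-constructed and checkpoint-loaded subclass instance and `image` is a preprocessed 3x224x224 tensor.

import torch


@torch.no_grad()
def caption_one_image(model, image):
    # generate() takes a batch of images plus one prompt per image; each prompt
    # must contain the <ImageHere> placeholder consumed by get_context_emb().
    prompts = ["[INST] <ImageHere> Describe this image briefly. [/INST]"]
    images = image.unsqueeze(0).to(model.device)
    return model.generate(images, prompts, num_beams=1, max_new_tokens=40)[0]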