From 90b7b002685ebdd36d01a4d9ffa7939097fef14b Mon Sep 17 00:00:00 2001 From: Deyao Zhu Date: Thu, 12 Oct 2023 17:17:32 +0300 Subject: [PATCH] update minigpt_base.py expect the initailization function --- minigpt4/models/mini_gpt4.py | 280 +----------------- minigpt4/models/minigpt_base.py | 497 ++++++++++++++++++++++++++++++++ 2 files changed, 503 insertions(+), 274 deletions(-) create mode 100644 minigpt4/models/minigpt_base.py diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index 2575f32..d1ff071 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -7,6 +7,7 @@ import torch.nn as nn from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel, disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers import LlamaTokenizer @@ -20,7 +21,7 @@ from peft import ( @registry.register_model("mini_gpt4") -class MiniGPT4(BaseModel): +class MiniGPT4(MiniGPTBase): """ MiniGPT-4 model """ @@ -30,146 +31,11 @@ class MiniGPT4(BaseModel): "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", } - def __init__( - self, - vit_model="eva_clip_g", - q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", - img_size=224, - drop_path_rate=0, - use_grad_checkpoint=False, - vit_precision="fp16", - freeze_vit=True, - has_qformer=True, - freeze_qformer=True, - num_query_token=32, - llama_model="", - prompt_path="", - prompt_template="", - max_txt_len=32, - end_sym='\n', - low_resource=False, # use 8 bit and put vit in cpu - device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
- lora_r=0, - lora_target_modules=["q_proj", "v_proj"], - lora_alpha=16, - lora_dropout=0.05, - ): - super().__init__() - - self.tokenizer = self.init_tokenizer() - self.low_resource = low_resource - - print('Loading VIT') - self.visual_encoder, self.ln_vision = self.init_vision_encoder( - vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision - ) - if freeze_vit: - for name, param in self.visual_encoder.named_parameters(): - param.requires_grad = False - self.visual_encoder = self.visual_encoder.eval() - self.visual_encoder.train = disabled_train - for name, param in self.ln_vision.named_parameters(): - param.requires_grad = False - self.ln_vision = self.ln_vision.eval() - self.ln_vision.train = disabled_train - logging.info("freeze vision encoder") - print('Loading VIT Done') - - self.has_qformer = has_qformer - if self.has_qformer: - print('Loading Q-Former') - self.Qformer, self.query_tokens = self.init_Qformer( - num_query_token, self.visual_encoder.num_features - ) - self.Qformer.cls = None - self.Qformer.bert.embeddings.word_embeddings = None - self.Qformer.bert.embeddings.position_embeddings = None - for layer in self.Qformer.bert.encoder.layer: - layer.output = None - layer.intermediate = None - self.load_from_pretrained(url_or_filename=q_former_model) - - if freeze_qformer: - for name, param in self.Qformer.named_parameters(): - param.requires_grad = False - self.Qformer = self.Qformer.eval() - self.Qformer.train = disabled_train - self.query_tokens.requires_grad = False - logging.info("freeze Qformer") - - img_f_dim = self.Qformer.config.hidden_size - print('Loading Q-Former Done') - else: - img_f_dim = self.visual_encoder.num_features * 4 - print('Do not use Q-Former here.') - - print('Loading LLAMA') - self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) - self.llama_tokenizer.pad_token = "$$" - - if self.low_resource: - self.llama_model = LlamaForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float16, - load_in_8bit=True, - device_map={'': device_8bit} - ) - else: - self.llama_model = LlamaForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float16, - ) - - if lora_r > 0: - self.llama_model = prepare_model_for_int8_training(self.llama_model) - loraconfig = LoraConfig( - r=lora_r, - lora_alpha=lora_alpha, - target_modules=lora_target_modules, - lora_dropout=lora_dropout, - bias="none", - task_type="CAUSAL_LM" - ) - self.llama_model = get_peft_model(self.llama_model, loraconfig) - - # if ckpt_path: - # print('load the llm under lora') - # ckpt = torch.load(ckpt_path) - # set_peft_model_state_dict(self.llama_model,ckpt) - self.llama_model.print_trainable_parameters() - - else: - for name, param in self.llama_model.named_parameters(): - param.requires_grad = False - print('Loading LLAMA Done') - - self.llama_proj = nn.Linear( - img_f_dim, self.llama_model.config.hidden_size - ) - self.max_txt_len = max_txt_len - self.end_sym = end_sym - - if prompt_path: - with open(prompt_path, 'r') as f: - raw_prompts = f.read().splitlines() - filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] - self.prompt_list = [prompt_template.format(p) for p in filted_prompts] - print('Load {} training prompts'.format(len(self.prompt_list))) - print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) - else: - self.prompt_list = [] - - def vit_to_cpu(self): - self.ln_vision.to("cpu") - self.ln_vision.float() - self.visual_encoder.to("cpu") - self.visual_encoder.float() - def encode_img(self, 
image): device = image.device - if self.low_resource: - self.vit_to_cpu() - image = image.to("cpu") + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) @@ -194,140 +60,6 @@ class MiniGPT4(BaseModel): atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) return inputs_llama, atts_llama - def get_context_emb(self, prompt, img_list): - device = img_list[0].device - prompt_segs = prompt.split('') - assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." - seg_tokens = [ - self.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids - # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] - - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - - def prompt_wrap(self, img_embeds, atts_img, prompts): - if prompts: - emb_lists = [] - if isinstance(prompts, str): - prompts = [prompts] * len(img_embeds) - - for each_img_embed, each_prompt in zip(img_embeds, prompts): - p_before, p_after = each_prompt.split('') - - p_before_tokens = self.llama_tokenizer( - p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_after_tokens = self.llama_tokenizer( - p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_before_embed = self.embed_tokens(p_before_tokens.input_ids) - p_after_embed = self.embed_tokens(p_after_tokens.input_ids) - wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1) - emb_lists.append(wrapped_emb) - emb_lens = [emb.shape[1] for emb in emb_lists] - pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) - wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone() - wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device) - for i, emb in enumerate(emb_lists): - wrapped_embs[i, :emb_lens[i]] = emb - wrapped_atts[i, :emb_lens[i]] = 1 - return wrapped_embs, wrapped_atts - else: - return img_embeds, atts_img - - def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): - input_lens = [] - cat_embs = [] - cat_atts = [] - for i in range(input_embs.size(0)): - input_len = input_atts[i].sum() - input_lens.append(input_len) - cat_embs.append( - torch.cat([ - input_embs[i][:input_len], - output_embs[i], - input_embs[i][input_len:] - ]) - ) - cat_atts.append( - torch.cat([ - input_atts[i][:input_len], - output_atts[i], - input_atts[i][input_len:] - ]) - ) - cat_embs = torch.stack(cat_embs) - cat_atts = torch.stack(cat_atts) - return cat_embs, cat_atts, input_lens - - def forward(self, samples): - image = samples["image"] - img_embeds, atts_img = self.encode_img(image) - - if self.prompt_list: - instruction = random.choice(self.prompt_list) - else: - instruction = samples["instruction_input"] if "instruction_input" in samples else None - - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction) - - self.llama_tokenizer.padding_side = "right" - text = [t + self.end_sym for t in samples["answer"]] - - to_regress_tokens = self.llama_tokenizer( - text, - return_tensors="pt", - padding="longest", - truncation=True, - max_length=self.max_txt_len, - 
add_special_tokens=False - ).to(image.device) - - batch_size = img_embeds.shape[0] - bos = torch.ones([batch_size, 1], - dtype=to_regress_tokens.input_ids.dtype, - device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id - bos_embeds = self.embed_tokens(bos) - atts_bos = atts_img[:, :1] - - to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids) - inputs_embeds, attention_mask, input_lens = \ - self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask) - inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) - attention_mask = torch.cat([atts_bos, attention_mask], dim=1) - - part_targets = to_regress_tokens.input_ids.masked_fill( - to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 - ) - targets = ( - torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], - dtype=torch.long).to(image.device).fill_(-100) - ) - - for i, target in enumerate(part_targets): - targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos - - with self.maybe_autocast(): - outputs = self.llama_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - return_dict=True, - labels=targets, - ) - loss = outputs.loss - - return {"loss": loss} - - def embed_tokens(self, token_ids): - if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model - embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) - else: - embeds = self.llama_model.base_model.embed_tokens(token_ids) - return embeds - @classmethod def from_config(cls, cfg): vit_model = cfg.get("vit_model", "eva_clip_g") @@ -377,7 +109,7 @@ class MiniGPT4(BaseModel): ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 if ckpt_path: - print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) + print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) ckpt = torch.load(ckpt_path, map_location="cpu") msg = model.load_state_dict(ckpt['model'], strict=False) diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py new file mode 100644 index 0000000..7c4008d --- /dev/null +++ b/minigpt4/models/minigpt_base.py @@ -0,0 +1,497 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import BaseModel, disabled_train +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers import LlamaTokenizer + +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + prepare_model_for_int8_training, + set_peft_model_state_dict, +) + + +class MiniGPTBase(BaseModel): + """ + Base class for MiniGPT-4 and MiniGPT-v2 + """ + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + has_qformer=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
+ lora_r=0,
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.05,
+ ):
+ super().__init__()
+
+ self.tokenizer = self.init_tokenizer()
+ self.low_resource = low_resource
+
+ print('Loading VIT')
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+ )
+ if freeze_vit:
+ for name, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder = self.visual_encoder.eval()
+ self.visual_encoder.train = disabled_train
+ for name, param in self.ln_vision.named_parameters():
+ param.requires_grad = False
+ self.ln_vision = self.ln_vision.eval()
+ self.ln_vision.train = disabled_train
+ logging.info("freeze vision encoder")
+ print('Loading VIT Done')
+
+ self.has_qformer = has_qformer
+ if self.has_qformer:
+ print('Loading Q-Former')
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features
+ )
+ self.Qformer.cls = None
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ self.load_from_pretrained(url_or_filename=q_former_model)
+
+ if freeze_qformer:
+ for name, param in self.Qformer.named_parameters():
+ param.requires_grad = False
+ self.Qformer = self.Qformer.eval()
+ self.Qformer.train = disabled_train
+ self.query_tokens.requires_grad = False
+ logging.info("freeze Qformer")
+
+ img_f_dim = self.Qformer.config.hidden_size
+ print('Loading Q-Former Done')
+ else:
+ img_f_dim = self.visual_encoder.num_features * 4
+ print('Do not use Q-Former here.')
+
+ print('Loading LLAMA')
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
+ self.llama_tokenizer.pad_token = "$$"
+
+ if self.low_resource:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map={'': device_8bit}
+ )
+ else:
+ self.llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model,
+ torch_dtype=torch.float16,
+ )
+
+ if lora_r > 0:
+ self.llama_model = prepare_model_for_int8_training(self.llama_model)
+ loraconfig = LoraConfig(
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ target_modules=lora_target_modules,
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM"
+ )
+ self.llama_model = get_peft_model(self.llama_model, loraconfig)
+
+ # if ckpt_path:
+ # print('load the llm under lora')
+ # ckpt = torch.load(ckpt_path)
+ # set_peft_model_state_dict(self.llama_model,ckpt)
+ self.llama_model.print_trainable_parameters()
+
+ else:
+ for name, param in self.llama_model.named_parameters():
+ param.requires_grad = False
+ print('Loading LLAMA Done')
+
+ self.llama_proj = nn.Linear(
+ img_f_dim, self.llama_model.config.hidden_size
+ )
+ self.max_txt_len = max_txt_len
+ self.end_sym = end_sym
+
+ if prompt_path:
+ with open(prompt_path, 'r') as f:
+ raw_prompts = f.read().splitlines()
+ filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
+ self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
+ print('Load {} training prompts'.format(len(self.prompt_list)))
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+ else:
+ self.prompt_list = []
+
+ def vit_to_cpu(self):
+ self.ln_vision.to("cpu")
+ self.ln_vision.float()
+ self.visual_encoder.to("cpu")
+ self.visual_encoder.float()
+
+ def get_context_emb(self, prompt, img_list):
+ device = img_list[0].device
+ prompt_segs = prompt.split('<ImageHere>')
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+ seg_tokens = [
+ self.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
+
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
+ if prompts is None or len(prompts) == 0:
+ # prompts is not provided, just return the original image embedding
+ return img_embeds, atts_img
+ elif img_embeds is None:
+ # prompt is provided but there is no image embedding. return the prompt embedding in right padding
+ self.llama_tokenizer.padding_side = "right"
+ prompt_tokens = self.llama_tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding="longest",
+ add_special_tokens=False
+ ).to(self.device)
+ prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
+ atts_prompt = prompt_tokens.attention_mask
+ return prompt_embeds, atts_prompt
+ else:
+ # return the multi-modal embedding in right padding
+ emb_lists = []
+ if isinstance(prompts, str):
+ prompts = [prompts] * len(img_embeds)
+
+ for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
+ pn = each_img_embed.shape[-2]
+ if lengths is not None:
+ each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
+ each_img_embed = each_img_embed[:lengths[idx] * pn]
+ p_segs = each_prompt.split('<ImageHere>')
+ interleave_emb = []
+ for idx, seg in enumerate(p_segs[:-1]):
+ p_tokens = self.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_embed = self.embed_tokens(p_tokens.input_ids)
+ interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
+ wrapped_emb = torch.cat(interleave_emb, dim=1)
+ p_tokens = self.llama_tokenizer(
+ p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_embed = self.embed_tokens(p_tokens.input_ids)
+ wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
+ emb_lists.append(wrapped_emb)
+
+ emb_lens = [emb.shape[1] for emb in emb_lists]
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
+
+ max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
+ wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
+ wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
+
+ for i, emb in enumerate(emb_lists):
+ length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
+ wrapped_embs[i, :length] = emb[:, :length]
+ wrapped_atts[i, :length] = 1
+ return wrapped_embs, wrapped_atts
+
+
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
+ """
+ Concatenate the batched input embedding and batched output embedding together.
+ Both the input and the output embedding should be right padded.
+ """ + input_lens = [] + cat_embs = [] + cat_atts = [] + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def tokenize_conversation(self, conv_q, conv_a): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + to_regress_token_ids_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + questions = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it + answers = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in answers] + cur_id = [] + cur_target = [] + for i in range(len(questions)): + cur_id.append(answers[i].input_ids) + cur_target.append(answers[i].input_ids) + cur_id.append(questions[i].input_ids) + cur_target.append(torch.ones_like(questions[i].input_ids) * -100) + + cur_id.append(answers[-1].input_ids) + cur_target.append(answers[-1].input_ids) + + cur_id = torch.cat(cur_id, dim=1) + cur_target = torch.cat(cur_target, dim=1) + to_regress_token_ids_list.append(cur_id) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + to_regress_token_ids = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id + targets = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * -100 + for batch_idx in range(batch_size): + cur_len = to_regress_token_ids_list[batch_idx].shape[1] + to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len] + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int) + + return to_regress_token_ids, to_regress_token_attn, targets + + def preparing_embedding(self, samples): + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + else: + img_embeds = img_atts = None + + if 'conv_q' in samples: + # handeling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym)for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + + conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q] + + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) + regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) + + else: + if "instruction_input" in samples: + instruction = samples["instruction_input"] + elif self.prompt_list: + instruction = random.choice(self.prompt_list) + else: + instruction = None + + if self.chat_template: + instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction] + + if 'length' in samples: + # the input is a image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = 
img_embeds.reshape(len(samples['image']), -1, pn, hs) + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) + else: + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction) + + ### prepare target tokens + self.llama_tokenizer.padding_side = "right" + text = [t + self.end_sym for t in samples["answer"]] + + regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(self.device) + + regress_token_ids = regress_tokens.input_ids + regress_atts = regress_tokens.attention_mask + part_targets = regress_token_ids.masked_fill( + regress_token_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + regress_embeds = self.embed_tokens(regress_token_ids) + + return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets + + def forward(self, samples, reduction='mean'): + # prepare the embedding to condition and the embedding to regress + cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \ + self.preparing_embedding(samples) + + # concat the embedding to condition and the embedding to regress + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) + + # get bos token embedding + bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id + bos_embeds = self.embed_tokens(bos) + bos_atts = cond_atts[:, :1] + + # add bos token at the begining + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_atts, attention_mask], dim=1) + + # ensemble the final targets + targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(self.device).fill_(-100) + + for i, target in enumerate(part_targets): + targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + reduction=reduction + ) + loss = outputs.loss + + return {"loss": loss} + + def embed_tokens(self, token_ids): + if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + else: + embeds = self.llama_model.base_model.embed_tokens(token_ids) + return embeds + + + @torch.no_grad() + def generate( + self, + images, + texts, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + img_embeds, atts_img = self.encode_img(images.to(self.device)) + image_lists = [[image_emb[None]] for image_emb in img_embeds] + + batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] + + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, 
-emb_len:] = emb[0]
+ attn_mask[i, -emb_len:] = 1
+
+ with self.maybe_autocast():
+ outputs = self.llama_model.generate(
+ inputs_embeds=embs,
+ attention_mask=attn_mask,
+ max_new_tokens=max_new_tokens,
+ num_beams=num_beams,
+ length_penalty=length_penalty,
+ temperature=temperature,
+ do_sample=do_sample,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty
+ # stopping_criteria=stopping_criteria,
+ )
+
+ answers = []
+ for output_token in outputs:
+ if output_token[0] == 0:
+ output_token = output_token[1:]
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+ output_texts = output_texts.split('</s>')[0] # remove the stop sign
+ output_texts = output_texts.replace("<s>", "")
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
+ answers.append(output_texts)
+
+ return answers
+
+ @torch.no_grad()
+ def multi_select(self, images, texts, answers, num_cand=None):
+ all_losses = []
+ for answer in answers:
+ choice_samples = {
+ 'image': images,
+ 'instruction_input': texts,
+ 'answer': answer
+ }
+ loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
+ all_losses.append(loss)
+ torch.cuda.empty_cache()
+ all_losses = torch.cat(all_losses, dim=-1)
+ if num_cand is not None:
+ for i in range(all_losses.shape[0]):
+ all_losses[i, num_cand[i]:] = 9999
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
+ return output_class_ranks.tolist()
\ No newline at end of file
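
The new helpers are driven roughly as follows. This is a minimal sketch, not part of the patch: model, images, prompts, and candidates are illustrative names, the image batch is assumed to be preprocessed to the vision encoder's input size, and the model is assumed to be fully initialized (attributes such as max_context_len and chat_template that the new helpers read are expected to come from the initialization function, which this patch does not yet touch). The exact prompt template around the <ImageHere> placeholder comes from the model config or conversation template.

import torch

# model: a MiniGPT4 / MiniGPTBase instance built elsewhere, e.g. via
# MiniGPT4.from_config(cfg); images: a dummy, already-preprocessed batch.
images = torch.randn(2, 3, 224, 224).to(model.device)
prompts = [
    "[INST] <ImageHere> Describe the image. [/INST]",
    "[INST] <ImageHere> What is unusual about this image? [/INST]",
]

# generate(): each prompt is split on '<ImageHere>' by get_context_emb(), the
# projected image embeddings are interleaved with the text embeddings, the
# batch is left-padded, and the LLaMA decoder produces the answers.
answers = model.generate(images, prompts, max_new_tokens=50, do_sample=False)

# multi_select(): score a fixed set of candidate answers with the per-sample
# language-modeling loss (forward() with reduction='none') and return, for
# each image, the candidate indices ranked from lowest to highest loss.
candidates = [
    ["a cat", "a cat"],                  # candidate 0, one entry per image
    ["a traffic jam", "a traffic jam"],  # candidate 1, one entry per image
]
ranks = model.multi_select(images, prompts, candidates)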