update 0304
94
evaluation/coco_caption.py
Normal file
@ -0,0 +1,94 @@
import os
import json
import pandas as pd
from tqdm import tqdm

from pycocoevalcap.eval import COCOEvalCap
from collections import defaultdict


class COCO_Annotation:
    # Minimal stand-in for pycocotools' COCO object: COCOEvalCap only needs imgToAnns and getImgIds().

    def __init__(self, annotation_file):
        self.coco_cn_file = annotation_file
        self.imgToAnns = self.build_imgToAnns()

    def build_imgToAnns(self):
        imgToAnns = defaultdict(list)
        with open(self.coco_cn_file, "r", encoding="UTF-8") as fin:
            for line in fin:
                line = line.strip()
                temp = eval(line)  # each line is a dict literal; trusted input assumed
                annotations = temp['annotations']
                for ann in annotations:
                    image_id = str(ann['image_id']).zfill(6)
                    imgToAnns[image_id].append({'image_id': image_id, 'caption': ann['caption'], 'image': ann['image_id']})
        return imgToAnns

    def getImgIds(self):
        return self.imgToAnns.keys()


class COCO_Result:
    # Wraps the model's result JSON in the same imgToAnns interface expected by COCOEvalCap.

    def __init__(self, result_file):
        self.coco_cn_file = result_file
        self.imgToAnns = self.build_imgToAnns()

    def build_imgToAnns(self):
        imgToAnns = dict()
        data = json.load(open(self.coco_cn_file, "r"))
        for d in data:
            tmp = {
                'image_id': d['question_id'][-6:],
                'caption': d['answer']
            }
            imgToAnns[d['question_id'][-6:]] = [tmp]
        return imgToAnns


def coco_caption_eval(results_file, split_name):
    files = {
        "val": "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json",
        "test": "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json"
    }

    # create coco object and coco_result object
    annotation_file = files[split_name]
    coco = COCO_Annotation(annotation_file)
    coco_result = COCO_Result(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by setting
    # coco_eval.params['image_id'] = coco_result.getImgIds()
    # please remove this line when evaluating the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval


def main():
    result_file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_cap_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0302/20240302231/result/val_vqa_result_coco_cap.json"
    split_name = "val"
    coco_val = coco_caption_eval(result_file, split_name)

    agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]

    # log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}
    # with open(
    #     os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
    # ) as f:
    #     f.write(json.dumps(log_stats) + "\n")

    coco_res = {k: v for k, v in coco_val.eval.items()}
    coco_res["agg_metrics"] = agg_metrics

    print(coco_res)


if __name__ == "__main__":
    main()
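For context, COCO_Result above expects each entry of the results JSON to carry the generated caption under "answer" and a "question_id" whose last six characters are the zero-padded COCO image id (mirroring the zfill(6) in COCO_Annotation), and coco_caption_eval reads its ground-truth captions from the hard-coded paths in files. A minimal sketch of preparing and scoring such a file (the question_id prefix and output file name below are illustrative, not mandated by this script):

import json

from evaluation.coco_caption import coco_caption_eval

results = [
    # only the last six characters of question_id are used to match the ground truth
    {"question_id": "coco_cap_391895", "answer": "a man riding a motorcycle on a dirt road"},
]
with open("val_vqa_result_coco_cap.json", "w") as f:
    json.dump(results, f)

coco_val = coco_caption_eval("val_vqa_result_coco_cap.json", "val")
print(coco_val.eval["CIDEr"], coco_val.eval["Bleu_4"])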
[Binary image diffs omitted: this update also deletes 32 image files from the repository, roughly 25 KiB to 1.5 MiB each.]
@ -16,11 +16,16 @@ datasets:
             - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json
           storage:
             - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_train.json
-        # val:
-        #   url:
-        #     - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json
-        #   storage:
-        #     - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_val.json
+        val:
+          url:
+            - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json
+          storage:
+            - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_val.json
+        test:
+          url:
+            - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json
+          storage:
+            - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_val.json
         # test:
         #   url:
         #     - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json
@ -17,14 +17,14 @@ datasets:
         # md5: aa31ac474cf6250ebb81d18348a07ed8
         storage:
           - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json
-      # val:
-      #   url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
-      #   storage:
-      #     - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
-      # test:
-      #   url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
-      #   storage:
-      #     - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
+      val:
+        url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
+        storage:
+          - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
+      test:
+        url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
+        storage:
+          - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json

       images:
         storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO
26
minigpt4/configs/datasets/coco/caption_eval.yaml
Normal file
@ -0,0 +1,26 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  coco_caption: # name of the dataset builder
    # dataset_card: dataset_card/coco_caption.md
    # data_dir: ${env.data_dir}/datasets
    data_type: images # [images|videos|features]

    build_info:
      # Be careful not to append minus sign (-) before split to avoid itemizing
      annotations:
        val:
          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
          storage:
            - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
        test:
          url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
          storage:
            - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json

      images:
        storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO
@ -14,7 +14,7 @@ from minigpt4.datasets.datasets.flickr import GroundedDetailDataset,CaptionToObj
 from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
 from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset
 from minigpt4.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset
-from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
+from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset
 from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset
 from minigpt4.datasets.datasets.ok_vqa_datasets import OKVQADataset, OKVQAEvalDataset
 from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset

@ -384,7 +384,7 @@ class OKVQABuilder(COCOVQABuilder):
 @registry.register_builder("aok_vqa")
 class AOKVQABuilder(BaseDatasetBuilder):
     train_dataset_cls = AOKVQADataset
-    eval_dataset_cls = AOKVQADataset
+    eval_dataset_cls = AOKVQAEvalDataset

     DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"}

@ -584,6 +584,7 @@ class COCOCapBuilder(BaseDatasetBuilder):

     DATASET_CONFIG_DICT = {
         "default": "configs/datasets/coco/caption.yaml",
+        "coco_cap_eval": "configs/datasets/coco/caption_eval.yaml",
     }
@ -13,7 +13,7 @@ import torch

 from PIL import Image

-from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset
+from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset


 class __DisplMixin:

@ -37,11 +37,11 @@ class AOKVQADataset(VQADataset, __DisplMixin):
         super().__init__(vis_processor, text_processor, vis_root, ann_paths)

         self.instruction_pool =[
-            '{}',
-            'Q: {} A: ',
-            'Based on the image, respond to this question with a short answer: {}',
-            '{} A short answer to the question is ',
-            'Question: {} Short answer:',
+            '{} Choose from {}.',
+            'Q: {} Multi Choices: {} A: ',
+            'Question: {} Multi Choices: {} Answer: ',
+            "{} Choose one from the following possible answers: {}. ",
+            '{} Choose from {}. The answer is',
         ]

         exist_annotation = []

@ -63,25 +63,19 @@ class AOKVQADataset(VQADataset, __DisplMixin):
         image = self.vis_processor(image)
         question = self.text_processor(ann["question"])

-        answer_key = "direct_answers"
-
-        answer_weight = {}
-        for answer in ann[answer_key]:
-            if answer in answer_weight.keys():
-                answer_weight[answer] += 1 / len(ann[answer_key])
-            else:
-                answer_weight[answer] = 1 / len(ann[answer_key])
-
-        answers = list(answer_weight.keys())
-        weights = list(answer_weight.values())
-
-        answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights
+        answer_lst = ann["choices"]
+        direct_answers = ann["direct_answers"]
+        final_answer = random.choices(direct_answers, k=1)[0]
+        for answer in answer_lst:
+            if answer in direct_answers:
+                final_answer = answer

         return {
             "image": image,
             "image_id": ann["image"],
             "question": question,
-            "answer": answer,
+            "answer": final_answer,
+            "choices": ", ".join(answer_lst)
         }

     def __getitem__(self, index):

@ -90,7 +84,7 @@ class AOKVQADataset(VQADataset, __DisplMixin):

         answer = self.text_processor(data['answer'])
         q_input = question
-        llm_input = random.choice(self.instruction_pool).format(question)
+        llm_input = random.choice(self.instruction_pool).format(question, data["choices"])

         return {
             "image": data['image'],

@ -104,25 +98,103 @@ class AOKVQADataset(VQADataset, __DisplMixin):
         }


-class AOKVQGDataset(AOKVQADataset):
-
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
-        self.instruction_pool = [
-            'Given the image, generate a question whose answer is: {}',
-            'Based on the image, provide a question with the answer: {}',
-            'Given the visual representation, create a question for which the answer is "{}"',
-            'From the image provided, craft a question that leads to the reply: {}',
-            'Considering the picture, come up with a question where the answer is: {}',
-            'Taking the image into account, generate an question that has the answer: {}'
-        ]
-
-    def __getitem__(self, index):
-        data = self.get_data(index)
-        instruction = random.choice(self.instruction_pool).format(data['answer'])
+class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin):
+
+    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        """
+        self.vis_root = vis_root
+        self.annotation = json.load(open(ann_paths[0]))
+
+        self.instruction_pool =[
+            '{} Choose from {}.',
+            'Q: {} Multi Choices: {} A: ',
+            'Question: {} Multi Choices: {} Answer: ',
+            "{} Choose one from the following possible answers: {}. ",
+            '{} Choose from {}. The answer is',
+        ]
+
+        try:
+            self.coco_fmt_qust_file = ann_paths[2]
+            self.coco_fmt_anno_file = ann_paths[3]
+        except IndexError:
+            self.coco_fmt_qust_file = None
+            self.coco_fmt_anno_file = None
+
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+        self.source = 'aokvqa'
+
+    def collater(self, samples):
+        (
+            image_list,
+            question_list,
+            question_id_list,
+            choices_list,
+            correct_choice_idx_list,
+            direct_answers_list,
+            llm_input_list,
+            q_input_list,
+            source_list,
+        ) = ([], [], [], [], [], [], [], [], [])
+
+        for sample in samples:
+            image_list.append(sample["image"])
+            question_list.append(sample["text_input"])
+            question_id_list.append(sample["question_id"])
+            choices_list.append(sample["choices"])
+            correct_choice_idx_list.append(sample["correct_choice_idx"])
+            direct_answers_list.append(sample["direct_answers"])
+            llm_input_list.append(sample["llm_input"])
+            q_input_list.append(sample["q_input"])
+            source_list.append(sample["source"])

         return {
-            "image": data['image'],
-            "instruction_input": instruction,
-            "answer": data['question'],
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "question_id": question_id_list,
+            "choices": choices_list,
+            "correct_choice_idx": correct_choice_idx_list,
+            "direct_answers": direct_answers_list,
+            "llm_input": llm_input_list,
+            "q_input": q_input_list,
+            "source": source_list,
         }
+
+    def __getitem__(self, index):
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.vis_root, ann["image"])
+        image = Image.open(image_path).convert("RGB")
+
+        image = self.vis_processor(image)
+        question = self.text_processor(ann["question"])
+
+        choices = ann["choices"]
+        if "correct_choice_idx" in ann:
+            correct_choice_idx = ann["correct_choice_idx"]
+        else:
+            correct_choice_idx = None
+
+        if "direct_answers" in ann:
+            direct_answers = ann["direct_answers"]
+        else:
+            direct_answers = None
+
+        llm_input = random.choice(self.instruction_pool).format(question, ", ".join(choices))
+
+        return {
+            "image": image,
+            "q_input": question,
+            "llm_input": llm_input,
+            "text_input": question,
+            "question_id": ann["question_id"],
+            "choices": choices,
+            "correct_choice_idx": correct_choice_idx,
+            "direct_answers": direct_answers,
+            "source": 'aokvqa',
+        }
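For reference, a minimal illustration of how the multiple-choice templates above expand at evaluation time (the question and choices below are made-up placeholders):

question = "What is the person holding?"
choices = ["umbrella", "kite", "bat", "racket"]
template = '{} Choose from {}.'
llm_input = template.format(question, ", ".join(choices))
# -> "What is the person holding? Choose from umbrella, kite, bat, racket."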
@ -59,83 +59,7 @@ class CaptionDataset(BaseDataset, __DisplMixin):
             "text_input": caption,
             "image_id": self.img_ids[ann["image_id"]],
         }


-class COCOCaptionDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
-        """
-        vis_root (string): Root directory of images (e.g. coco/images/)
-        ann_root (string): directory to store the annotation file
-        """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
-
-        self.img_ids = {}
-        n = 0
-
-        self.filter_anntation = []
-
-        for ann in self.annotation:
-            if "train" in ann["image"]:
-                self.filter_anntation.append(ann)
-        self.annotation = self.filter_anntation
-
-        for ann in self.annotation:
-            img_id = ann["image_id"]
-            if img_id not in self.img_ids.keys():
-                self.img_ids[img_id] = n
-                n += 1
-
-        self.instruction_pool = [
-            'Briefly describe this image.',
-            'Provide a concise depiction of this image.',
-            'Present a short description of this image.',
-            'Summarize this image in a few words.',
-            'A short image caption:',
-            'A short image description:',
-            'A photo of ',
-            'An image that shows ',
-            'Write a short description for the image. ',
-            'Write a description for the photo.',
-            'Provide a description of what is presented in the photo.',
-            'Briefly describe the content of the image.',
-            'Can you briefly explain what you see in the image?',
-            'Could you use a few words to describe what you perceive in the photo?',
-            'Please provide a short depiction of the picture.',
-            'Using language, provide a short account of the image.',
-            'Use a few words to illustrate what is happening in the picture.',
-        ]
-        self.source = 'coco_cap'
-
-    def __getitem__(self, index):
-
-        # TODO this assumes image input, not general enough
-        ann = self.annotation[index]
-
-        # img_file = ann["image"].split("/")[-1]
-        img_file = ann["image"]
-        image_path = os.path.join(self.vis_root, img_file)
-        image = Image.open(image_path).convert("RGB")
-
-        image = self.vis_processor(image)
-        caption = self.text_processor(ann["caption"])
-
-        # instruction = random.choice(self.instruction_pool)
-        # instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
-        q_input = ""
-        llm_input = random.choice(self.instruction_pool)
-
-        return {
-            "image": image,
-            "image_id": ann["image"],
-            "answer": caption,
-            "q_input": q_input,
-            "llm_input": llm_input,
-            "text_input": llm_input,
-            "text_output": caption,
-            "source": 'coco_cap',
-        }
-
 class CaptionEvalDataset(BaseDataset, __DisplMixin):
     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
         """

@ -151,7 +75,7 @@ class CaptionEvalDataset(BaseDataset, __DisplMixin):

         image_path = os.path.join(self.vis_root, ann["image"])
         image = Image.open(image_path).convert("RGB")

         image = self.vis_processor(image)

         return {

@ -159,3 +83,4 @@ class CaptionEvalDataset(BaseDataset, __DisplMixin):
             "image_id": ann["image_id"],
             "instance_id": ann["instance_id"],
         }
+
@ -9,18 +9,102 @@ import os
 import json
 import torch
 import numpy as np
+import random

 from PIL import Image
 from PIL import ImageFile
+from collections import OrderedDict

 ImageFile.LOAD_TRUNCATED_IMAGES = True

-from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset

-COCOCapDataset = COCOCaptionDataset
+class __DisplMixin:
+    def displ_item(self, index):
+        sample, ann = self.__getitem__(index), self.annotation[index]
+
+        return OrderedDict(
+            {
+                "file": ann["image"],
+                "caption": ann["caption"],
+                "image": sample["image"],
+            }
+        )
+
+
+class COCOCapDataset(BaseDataset, __DisplMixin):
+    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        """
+        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+        self.img_ids = {}
+        n = 0
+
+        self.filter_anntation = []
+
+        for ann in self.annotation:
+            if "train" in ann["image"]:
+                self.filter_anntation.append(ann)
+        self.annotation = self.filter_anntation
+
+        for ann in self.annotation:
+            img_id = ann["image_id"]
+            if img_id not in self.img_ids.keys():
+                self.img_ids[img_id] = n
+                n += 1
+
+        self.instruction_pool = [
+            'Briefly describe this image.',
+            'Provide a concise depiction of this image.',
+            'Present a short description of this image.',
+            'Summarize this image in a few words.',
+            'A short image caption:',
+            'A short image description:',
+            'A photo of ',
+            'An image that shows ',
+            'Write a short description for the image. ',
+            'Write a description for the photo.',
+            'Provide a description of what is presented in the photo.',
+            'Briefly describe the content of the image.',
+            'Can you briefly explain what you see in the image?',
+            'Could you use a few words to describe what you perceive in the photo?',
+            'Please provide a short depiction of the picture.',
+            'Using language, provide a short account of the image.',
+            'Use a few words to illustrate what is happening in the picture.',
+        ]
+        self.source = 'coco_cap'
+
+    def __getitem__(self, index):
+
+        # TODO this assumes image input, not general enough
+        ann = self.annotation[index]
+
+        # img_file = ann["image"].split("/")[-1]
+        img_file = ann["image"]
+        image_path = os.path.join(self.vis_root, img_file)
+        image = Image.open(image_path).convert("RGB")
+
+        image = self.vis_processor(image)
+        caption = self.text_processor(ann["caption"])
+
+        instruction = random.choice(self.instruction_pool)
+        # q_input = ""
+        q_input = instruction
+        llm_input = instruction
+
+        return {
+            "image": image,
+            "image_id": ann["image"],
+            "answer": caption,
+            "q_input": q_input,
+            "llm_input": llm_input,
+            "text_input": llm_input,
+            "text_output": caption,
+            "source": 'coco_cap',
+        }


 class COCOCapEvalDataset(CaptionEvalDataset):

@ -31,6 +115,26 @@ class COCOCapEvalDataset(CaptionEvalDataset):
         split (string): val or test
         """
         super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+        self.instruction_pool = [
+            'Briefly describe this image.',
+            'Provide a concise depiction of this image.',
+            'Present a short description of this image.',
+            'Summarize this image in a few words.',
+            'A short image caption:',
+            'A short image description:',
+            'A photo of ',
+            'An image that shows ',
+            'Write a short description for the image. ',
+            'Write a description for the photo.',
+            'Provide a description of what is presented in the photo.',
+            'Briefly describe the content of the image.',
+            'Can you briefly explain what you see in the image?',
+            'Could you use a few words to describe what you perceive in the photo?',
+            'Please provide a short depiction of the picture.',
+            'Using language, provide a short account of the image.',
+            'Use a few words to illustrate what is happening in the picture.',
+        ]
         self.source = 'coco_cap'

     def __getitem__(self, index):

@ -38,15 +142,25 @@ class COCOCapEvalDataset(CaptionEvalDataset):

         image_path = os.path.join(self.vis_root, ann["image"])
         image = Image.open(image_path).convert("RGB")
-        image = self.vis_processor(image)
+        try:
+            image = self.vis_processor(image)
+        except Exception as e:
+            print(e)
+            print(image_path)

         img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
+        instruction = random.choice(self.instruction_pool)
+        # q_input = ""
+        q_input = instruction
+        llm_input = instruction

         return {
             "image": image,
             "image_id": img_id,
-            "instance_id": ann["instance_id"],
+            "text_input":llm_input,
+            "q_input": q_input,
+            "llm_input": llm_input,
+            "source": self.source,
         }
@ -149,7 +149,6 @@ class OKVQAEvalDataset(VQAEvalDataset, __DisplMixin):

         self.source = 'okvqa'
         self.annotation_add = self.get_data()
-        self._add_instance_ids()

     def get_data(self):
         ann_instruct = list()

@ -180,7 +179,6 @@ class OKVQAEvalDataset(VQAEvalDataset, __DisplMixin):
             "image_id": ann["image"],
             'image_path': image_path,
             "question_id": ann["question_id"],
-            # "instance_id": ann["instance_id"],
             "question": question,
             "q_input": q_input,
             "llm_input": llm_input,
@ -45,7 +45,6 @@ from transformers.utils import logging
 from transformers.models.bert.configuration_bert import BertConfig

 from minigpt4.models.moe.utils import (
-    FeedForward,
     MoEModelOutput,
     MoEModelOutputWithPooling,
     use_experts,
1276
minigpt4/models/QformerMoELN.py
Normal file
@ -389,17 +389,23 @@ class BertOutput(nn.Module): # Add & Norm


 class FeedForward(nn.Module):
+    # remove LayerNorm
     def __init__(self, config):
-        nn.Module.__init__(self)
-        # first layer
-        self.intermediate_query = BertIntermediate(config)
-        # second layer
-        self.output_query = BertOutput(config)
+        super().__init__()
+        self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2
+        # self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2

     def forward(self, hidden_states: Tensor):
-        input_tensor = hidden_states
-        intermediate_output = self.intermediate_query(hidden_states)
-        hidden_states = self.output_query(intermediate_output, input_tensor)
+        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.dense2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
         return hidden_states


@ -433,7 +439,6 @@ class BertLayer(nn.Module):
         self.layer_judge = moe_layer_judge(layer_num)
         self.num_beams = config.moebert_num_beams
         ffn = FeedForward(config)
-
         if self.use_experts:
             self.experts = RouteMoELayer(
                 hidden_size=config.hidden_size,

@ -446,8 +451,7 @@ class BertLayer(nn.Module):
             )
         else:
             self.experts = ffn
-        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
+        self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(
         self,

@ -538,7 +542,7 @@ class BertLayer(nn.Module):
         if self.layer_judge == 'first' and self.num_beams>1:
             # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
             # adjust the dimension of layer_output_text to bz*num_beams
-            layer_output_text = self.adjust_layer_output_text(layer_output_text)
+            layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text)

         if self.layer_judge == 'mid' and self.num_beams > 1:
             # layer_output_text [bz*num_beams, len, hidden_size]

@ -575,11 +579,11 @@ class BertLayer(nn.Module):
             attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len])
         return attention_mask

-    def adjust_layer_output_text(self, layer_output_text):
-        batch_size, text_length, hidden_size = layer_output_text.shape
-        tmp_text = layer_output_text.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size)
-        layer_output_text = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768]
-        return layer_output_text
+    def adjust_hidden_states_by_num_beams(self, hidden_states):
+        batch_size, text_length, hidden_size = hidden_states.shape
+        tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size)
+        hidden_states = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768]
+        return hidden_states

     def route_moe_last_layer_top1(self, layer_output, layer_output_text):
         batch_size = layer_output[0].shape[0]

@ -602,20 +606,21 @@ class BertLayer(nn.Module):
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        # layer_output = self.LayerNorm(layer_output + attention_output)
         return layer_output

     def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
         if not self.use_experts:
-            layer_output = self.experts(attention_output)
-            # layer_output = self.LayerNorm(layer_output + attention_output)
+            hidden_states = self.experts(attention_output)
+            layer_output = self.expert_ln(hidden_states + attention_output)
             return layer_output, None, None, None, 0.0

-        layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
+        hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
             attention_output, expert_attention_mask, beam_scores, expert_route
         )
-        # layer_output = self.LayerNorm(layer_output + attention_output)
+        if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1:
+            attention_output = self.adjust_hidden_states_by_num_beams(attention_output)
+        layer_output = self.expert_ln(hidden_states + attention_output)
+
         return layer_output, beam_scores, expert_route, beam_idx, importance_loss

 class BertEncoder(nn.Module):

@ -722,7 +727,7 @@ class BertEncoder(nn.Module):
                 ]
                 if v is not None
             )

         return MoEModelOutput(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
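Net effect of the FeedForward and BertLayer hunks above: each expert becomes a plain two-layer MLP (linear, activation, linear, dropout) with no Add & Norm of its own, and the residual connection plus LayerNorm are applied once per layer through the new expert_ln after the expert outputs are mixed. A condensed reading sketch of the resulting flow in feed_forward_query_moe (names follow the diff; this is an aid, not extra behaviour):

# hidden_states: expert-mixture output; attention_output: the pre-FFN residual branch
hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
    attention_output, expert_attention_mask, beam_scores, expert_route
)
# routing can expand the batch from bz to bz*num_beams, so expand the residual to match
if hidden_states.shape[0] == attention_output.shape[0] * self.num_beams and self.num_beams > 1:
    attention_output = self.adjust_hidden_states_by_num_beams(attention_output)
# a single post-residual LayerNorm replaces the Add & Norm that used to live inside each expert
layer_output = self.expert_ln(hidden_states + attention_output)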
1367
minigpt4/models/QformerRouteMoELN.py
Normal file
@ -22,6 +22,7 @@ from minigpt4.common.logger import MetricLogger
 from minigpt4.models.base_model import BaseModel
 from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from minigpt4.models.QformerMoE import BertMoELMHeadModel
+from minigpt4.models.QformerMoELN import BertMoELMHeadModelLNIn
 from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel
 from minigpt4.models.eva_vit import create_eva_vit_g
 from transformers import BertTokenizer

@ -88,7 +89,7 @@ class Blip2Base(BaseModel):

     @classmethod
-    def init_QformerMoE(cls, num_query_token, vision_width, moebert_expert_num, moebert_route_method, moebert_load_balance, moe_topk=1, use_balance_loss=True, moe_weight_type='l2_norm', cross_attention_freq=2):
+    def init_QformerMoE(cls, num_query_token, vision_width, moebert_expert_num, moebert_route_method, moebert_load_balance, moe_topk=1, use_balance_loss=True, moe_weight_type='l2_norm', cross_attention_freq=2, ln_position="out"):
         moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")

         moe_encoder_config.encoder_width = vision_width

@ -104,9 +105,14 @@ class Blip2Base(BaseModel):
         moe_encoder_config.use_balance_loss = use_balance_loss
         moe_encoder_config.moe_weight_type = moe_weight_type

-        MoEQformer = BertMoELMHeadModel.from_pretrained(
-            "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
-        )
+        if ln_position == "out":
+            MoEQformer = BertMoELMHeadModel.from_pretrained(
+                "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
+            )
+        elif ln_position == "in":
+            MoEQformer = BertMoELMHeadModelLNIn.from_pretrained(
+                "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
+            )
         query_tokens = nn.Parameter(
             torch.zeros(1, num_query_token, moe_encoder_config.hidden_size)
         )
@ -65,6 +65,8 @@ class Blip2VicunaInstruct(Blip2Base):
         use_balance_loss = True,
         moe_weight_type = "l2_norm",
         gate_save_path = None,
+        bal_loss_decay_epoch = 3,
+        ln_position = "out",
     ):
         super().__init__()
         transformers_version = version.parse(transformers.__version__)

@ -112,7 +114,8 @@ class Blip2VicunaInstruct(Blip2Base):
                 moe_topk=moe_topk,
                 use_balance_loss=use_balance_loss,
                 moe_weight_type=moe_weight_type,
-                cross_attention_freq=2
+                cross_attention_freq=2,
+                ln_position=ln_position,
             )
         else:
             self.Qformer, self.query_tokens = self.init_Qformer(

@ -221,6 +224,7 @@ class Blip2VicunaInstruct(Blip2Base):
         self.moebert_num_beams = moebert_num_beams

         self.gate_save_path = gate_save_path
+        self.bal_loss_decay_epoch = bal_loss_decay_epoch
         # if self.gate_save_path != None:
         #     import os
         #     if not os.path.exists(self.gate_save_path):

@ -392,9 +396,12 @@ class Blip2VicunaInstruct(Blip2Base):
                 return_dict=True,
                 labels=targets,
             )

         if self.use_moeqformer:
-            loss = outputs.loss + self.moebert_load_balance * gate_loss
+            if samples['epoch'] > self.bal_loss_decay_epoch:
+                loss = outputs.loss
+            else:
+                loss = outputs.loss + self.moebert_load_balance * gate_loss
         else:
             loss = outputs.loss

@ -512,6 +519,16 @@ class Blip2VicunaInstruct(Blip2Base):

         with self.maybe_autocast():
             inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+
+            # path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/embedding/"
+            # np.save(os.join(path, "inputs_llm.npy"), inputs_llm.cpu().numpy)
+            # np.save(os.join(path, "inputs_llm.npy"), self.llm_model.get_input_embeddings().weight.cpu().numpy)
+            # samples_copy = samples.copy()
+            # samples_copy.pop('image', None)
+            # with open(os.path.join(path, 'test_samples.json'),'a+') as f:
+            #     f.write(f"{json.dumps(samples_copy)}\n")
+
             inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
             attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)

@ -654,6 +671,8 @@ class Blip2VicunaInstruct(Blip2Base):
         use_balance_loss = cfg.get("use_balance_loss", True)
         moe_weight_type = cfg.get("moe_weight_type",'l2_norm')
         gate_save_path = cfg.get("gate_save_path", None)
+        bal_loss_decay_epoch = cfg.get("bal_loss_decay_epoch", 3)
+        ln_position = cfg.get("ln_position","out")

         model = cls(
             vit_model=vit_model,

@ -683,6 +702,8 @@ class Blip2VicunaInstruct(Blip2Base):
             use_balance_loss=use_balance_loss,
             moe_weight_type=moe_weight_type,
             gate_save_path=gate_save_path,
+            bal_loss_decay_epoch=bal_loss_decay_epoch,
+            ln_position=ln_position,
         )

         # if qformer_text_input:
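The training objective now drops the load-balancing term after a configurable number of epochs. A compact restatement as a standalone sketch, assuming samples['epoch'] holds the current epoch index and that the default weight of 0.05 is only illustrative:

def moe_training_loss(lm_loss, gate_loss, epoch, load_balance=0.05, bal_loss_decay_epoch=3):
    # keep the balancing term only for the first bal_loss_decay_epoch epochs
    if epoch > bal_loss_decay_epoch:
        return lm_loss
    return lm_loss + load_balance * gate_loss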
@ -165,7 +165,7 @@ class RouteMoELayer(nn.Module):
         self.route_method = route_method
         if self.route_method == "pre-route":
             self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
-        elif self.route_method == "post-route":
+        elif self.route_method in ["post-route", "post-route-dp"]:
             gate = nn.Linear(hidden_size, 1, bias=False).float()
             self.gate = gate
             # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])

@ -252,6 +252,53 @@ class RouteMoELayer(nn.Module):

         return beam_scores, expert_route, beam_idx

+    def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+        if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'post-route-dp']:
+            # current_scores_log torch.Size([bz, num_experts])
+            assert beam_scores==None and expert_route==None
+            current_scores = torch.exp(current_scores_log)
+            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
+            beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
+            expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
+            beam_idx = torch.tensor(range(self.num_beams * batch_size))
+
+        else:
+            batch_size = int(batch_size // self.num_beams)
+            next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # in log space, probabilities can be combined by addition
+            next_scores_exp = torch.exp(next_scores_raw)
+            import pdb;pdb.set_trace()
+
+            next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True)
+            next_scores = next_scores_raw.view(batch_size, self.num_beams)
+            next_experts = next_experts_raw.view(batch_size, self.num_beams)
+            # next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equivalent
+            # next_scores torch.Size([bz * num_beams, 1])
+            # next_tokens torch.Size([bz * num_beams, 1])
+
+            next_batch_beam = list()
+            for batch_idx in range(batch_size):
+                next_sent_beam = list()
+                expert_id = next_experts[batch_idx]
+                expert_score = next_scores[batch_idx]
+                values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True)
+                for i in range(self.num_beams):
+                    beam_id = index[i].item()
+                    ex_id = expert_id[beam_id].item()
+                    effective_beam_id = batch_idx*self.num_beams + beam_id
+                    next_sent_beam.append((values[i], ex_id, effective_beam_id))
+                next_batch_beam.extend(next_sent_beam)
+
+            import pdb;pdb.set_trace()
+
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+            beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+            pre_route = expert_route[beam_idx,:]
+            expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
+
+        return beam_scores, expert_route, beam_idx
+
     def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
         if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
             # current_scores_log torch.Size([bz, num_experts])
|
|||||||
batch_size = int(batch_size // self.num_beams)
|
batch_size = int(batch_size // self.num_beams)
|
||||||
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率
|
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率
|
||||||
next_scores_exp = torch.exp(next_scores_raw)
|
next_scores_exp = torch.exp(next_scores_raw)
|
||||||
|
import pdb;pdb.set_trace()
|
||||||
|
|
||||||
next_scores_raw1 = next_scores_exp.view(
|
next_scores_raw1 = next_scores_exp.view(
|
||||||
batch_size, self.num_beams * self.num_experts
|
batch_size, self.num_beams * self.num_experts
|
||||||
) # torch.Size([bz, num_beams*num_experts])
|
) # torch.Size([bz, num_beams*num_experts])
|
||||||
@ -289,7 +338,7 @@ class RouteMoELayer(nn.Module):
|
|||||||
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
|
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
|
||||||
next_batch_beam.extend(next_sent_beam)
|
next_batch_beam.extend(next_sent_beam)
|
||||||
|
|
||||||
# import pdb;pdb.set_trace()
|
import pdb;pdb.set_trace()
|
||||||
|
|
||||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||||
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
|
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
|
||||||
@ -301,8 +350,6 @@ class RouteMoELayer(nn.Module):
|
|||||||
|
|
||||||
return beam_scores, expert_route, beam_idx
|
return beam_scores, expert_route, beam_idx
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward_expert_ffn(self, x, expert_select, current_scores):
|
def forward_expert_ffn(self, x, expert_select, current_scores):
|
||||||
"""
|
"""
|
||||||
x_repeat : [bz*num_beams, 32,768]
|
x_repeat : [bz*num_beams, 32,768]
|
||||||
@ -343,6 +390,7 @@ class RouteMoELayer(nn.Module):
|
|||||||
|
|
||||||
batch_size, num_tokens = x.shape[0], x.shape[1]
|
batch_size, num_tokens = x.shape[0], x.shape[1]
|
||||||
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
|
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
|
||||||
|
|
||||||
current_expert_select = expert_route[:,-1]
|
current_expert_select = expert_route[:,-1]
|
||||||
|
|
||||||
import pdb;pdb.set_trace()
|
import pdb;pdb.set_trace()
|
||||||
@ -368,7 +416,6 @@ class RouteMoELayer(nn.Module):
|
|||||||
output_x = self.experts[expert_idx].forward(input_x)
|
output_x = self.experts[expert_idx].forward(input_x)
|
||||||
return output_x
|
return output_x
|
||||||
|
|
||||||
import pdb; pdb.set_trace()
|
|
||||||
outputs = list()
|
outputs = list()
|
||||||
logits_gate_lst = list()
|
logits_gate_lst = list()
|
||||||
for expert_idx in range(self.num_experts):
|
for expert_idx in range(self.num_experts):
|
||||||
@ -392,10 +439,14 @@ class RouteMoELayer(nn.Module):
|
|||||||
# importance loss
|
# importance loss
|
||||||
importance_loss = self._importance_auxiliary_loss(current_scores)
|
importance_loss = self._importance_auxiliary_loss(current_scores)
|
||||||
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
|
|
||||||
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
|
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
|
||||||
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
|
import pdb; pdb.set_trace()
|
||||||
|
|
||||||
|
if self.route_method == 'post-route':
|
||||||
|
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
|
||||||
|
elif self.route_method == 'post-route-dp':
|
||||||
|
beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size)
|
||||||
|
|
||||||
# beam_scores torch.Size([bz*num_beam])
|
# beam_scores torch.Size([bz*num_beam])
|
||||||
# expert_route torch.Size([bz*num_beam, layer_n])
|
# expert_route torch.Size([bz*num_beam, layer_n])
|
||||||
current_select_expert = expert_route[:,-1]
|
current_select_expert = expert_route[:,-1]
|
||||||
@ -431,7 +482,7 @@ class RouteMoELayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if self.route_method == 'pre-route':
|
if self.route_method == 'pre-route':
|
||||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
|
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
|
||||||
elif self.route_method == "post-route":
|
elif self.route_method in ['post-route', 'post-route-dp']:
|
||||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
|
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
|
||||||
|
|
||||||
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
|
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||||
@@ -467,10 +518,11 @@ if __name__ == '__main__':
    batch_size = 4
    x = torch.randn(batch_size, 32, 768)
    beam_scores, expert_route = None, None

    x1 = x
    x2 = x
+   x3 = x
    beam_scores1, expert_route1 = None, None
+   beam_scores2, expert_route2 = None, None

    for layer_num in [6, 8, 10]:
        layer_judge = moe_layer_judge(layer_num)
@@ -494,25 +546,41 @@ if __name__ == '__main__':
        # print(importance_loss)
        # x = hidden_states1

-       gate1 = nn.Linear(768, 1, bias=False).float()
+       # experts_post = RouteMoELayer(
+       #     hidden_size=768,
+       #     expert=ffn,
+       #     num_experts=config.moebert_expert_num,
+       #     num_beams=config.moebert_num_beams,
+       #     layer_judge = layer_judge,
+       #     route_method = "post-route",
+       #     weight_type="ffn_prob"
+       # )
+       # layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
+       # hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output

+       # print(beam_scores1)
+       # print(expert_route1)
+       # print(beam_idx)
+       # print(importance_loss)
+       # x1 = hidden_states2

        experts_post = RouteMoELayer(
            hidden_size=768,
            expert=ffn,
            num_experts=config.moebert_expert_num,
            num_beams=config.moebert_num_beams,
            layer_judge = layer_judge,
-           route_method = "post-route",
+           route_method = "post-route-dp",
            weight_type="ffn_prob"
        )
-       layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
-       hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output
+       layer_output = experts_post(x2, None, beam_scores2, expert_route2, False)
+       hidden_states3, beam_scores2, expert_route2, beam_idx2, importance_loss2 = layer_output

-       print(beam_scores1)
-       print(expert_route1)
-       print(beam_idx)
-       print(importance_loss)
-       x1 = hidden_states2
+       print(beam_scores2)
+       print(expert_route2)
+       print(beam_idx2)
+       print(importance_loss2)
+       x2 = hidden_states3

        # gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
        # experts_moe = MoELayer(
@@ -526,12 +594,12 @@ if __name__ == '__main__':
        #     weight_type=config.moe_weight_type,
        # )
        # attn_mask = torch.ones([batch_size, 32])
-       # layer_output = experts_moe(x2, attn_mask)
-       # hidden_states3, select_prob_gate, gate_load,_ = layer_output
+       # layer_output = experts_moe(x3, attn_mask)
+       # hidden_states4, select_prob_gate, gate_load,_ = layer_output

        # print(select_prob_gate)
        # print(gate_load)
-       # x2 = hidden_states3
+       # x3 = hidden_states4

        print("------------------------------------")
        import pdb; pdb.set_trace()
|
@ -18,7 +18,7 @@ class RouteMoELayer(nn.Module):
|
|||||||
self.route_method = route_method
|
self.route_method = route_method
|
||||||
if self.route_method == "pre-route":
|
if self.route_method == "pre-route":
|
||||||
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
|
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
|
||||||
elif self.route_method == "post-route":
|
elif self.route_method in ["post-route", "post-route-dp"]:
|
||||||
gate = nn.Linear(hidden_size, 1, bias=False).float()
|
gate = nn.Linear(hidden_size, 1, bias=False).float()
|
||||||
self.gate = gate
|
self.gate = gate
|
||||||
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
|
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
|
||||||
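For both post-route variants the layer keeps one shared nn.Linear(hidden_size, 1) gate that scores each expert's output, rather than one gate per expert. A minimal sketch of how such a shared scalar gate can turn per-expert outputs into routing probabilities; the pooling choice and tensor shapes here are assumptions for illustration, not the repo's exact wiring:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_experts, bz, seq = 768, 2, 4, 32
gate = nn.Linear(hidden_size, 1, bias=False)

# Pretend each expert already produced an output for the same input.
expert_outputs = [torch.randn(bz, seq, hidden_size) for _ in range(num_experts)]

# Score each expert by pooling its output and passing it through the shared gate,
# then normalise the scores across experts.
logits = torch.cat([gate(out.mean(dim=1)) for out in expert_outputs], dim=-1)  # [bz, num_experts]
prob_gate = F.softmax(logits, dim=-1)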
@@ -47,26 +47,67 @@ class RouteMoELayer(nn.Module):
        prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
        return prob_gate

-   def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
-       if self.layer_judge=='first' and self.route_method=='pre-route':
-           assert beam_scores==None and expert_route==None
-           current_scores = torch.exp(current_scores_log)
-           topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: expert assigned to each sample, torch.Size([bz, topk])
-           beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
-           expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
-           beam_idx = torch.tensor(range(self.num_beams * batch_size))
-
-       else:
-           if self.layer_judge=='first' and self.route_method == 'post-route':
-               batch_size = batch_size
-               next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_experts])
-           else:
-               batch_size = int(batch_size // self.num_beams)
-               next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
-               next_scores_exp = torch.exp(next_scores_raw)
-               next_scores_raw1 = next_scores_exp.view(
-                   batch_size, self.num_beams * self.num_experts
-               ) # torch.Size([bz, num_beams*num_experts])
+   def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+       if self.layer_judge=='first' and self.route_method in ['post-route-dp']:
+           # current_scores_log torch.Size([bz, num_experts])
+           assert beam_scores==None and expert_route==None
+           current_scores = torch.exp(current_scores_log)
+           topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: expert assigned to each sample, torch.Size([bz, topk])
+           beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
+           expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
+           beam_idx = torch.tensor(range(self.num_beams * batch_size))
+
+       else:
+           batch_size = int(batch_size // self.num_beams)
+           next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
+           next_scores_exp = torch.exp(next_scores_raw)
+
+           next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True)
+           next_scores = next_scores_raw.view(batch_size, self.num_beams)
+           next_experts = next_experts_raw.view(batch_size, self.num_beams)
+           # next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equivalent
+           # next_scores torch.Size([bz * num_beams, 1])
+           # next_tokens torch.Size([bz * num_beams, 1])
+
+           next_batch_beam = list()
+           for batch_idx in range(batch_size):
+               next_sent_beam = list()
+               expert_id = next_experts[batch_idx]
+               expert_score = next_scores[batch_idx]
+               values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True)
+               for i in range(self.num_beams):
+                   beam_id = index[i].item()
+                   ex_id = expert_id[beam_id].item()
+                   effective_beam_id = batch_idx*self.num_beams + beam_id
+                   next_sent_beam.append((values[i], ex_id, effective_beam_id))
+               next_batch_beam.extend(next_sent_beam)
+
+           beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+           beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+           beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+           pre_route = expert_route[beam_idx,:]
+           expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
+
+       return beam_scores, expert_route, beam_idx
+
+   def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+       if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
+           # current_scores_log torch.Size([bz, num_experts])
+           assert beam_scores==None and expert_route==None
+           current_scores = torch.exp(current_scores_log)
+           topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: expert assigned to each sample, torch.Size([bz, topk])
+           beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
+           expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
+           beam_idx = torch.tensor(range(self.num_beams * batch_size))
+
+       else:
+           batch_size = int(batch_size // self.num_beams)
+           next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
+           next_scores_exp = torch.exp(next_scores_raw)
+
+           next_scores_raw1 = next_scores_exp.view(
+               batch_size, self.num_beams * self.num_experts
+           ) # torch.Size([bz, num_beams*num_experts])

            next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
            # next_scores torch.Size([bz, num_beams])
@@ -86,19 +127,11 @@ class RouteMoELayer(nn.Module):
                next_sent_beam.append((expert_score, ex_id, effective_beam_id))
            next_batch_beam.extend(next_sent_beam)

-       if self.layer_judge=='first' and self.route_method == 'post-route':
-           beam_scores = next_scores.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
-           expert_route = next_experts.view(self.num_beams * batch_size)
-           beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
-           beam_experts = expert_route.new([x[1] for x in next_batch_beam]).unsqueeze(-1)
-           beam_idx = expert_route.new([int(x[2]/self.num_beams) for x in next_batch_beam])
-           expert_route = beam_experts
-       else:
-           beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
-           beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
-           beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
-           pre_route = expert_route[beam_idx,:]
-           expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
+       beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+       beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+       beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+       pre_route = expert_route[beam_idx,:]
+       expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)

        return beam_scores, expert_route, beam_idx

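The key difference introduced here: dp_search extends every existing beam with its single best expert (a greedy top-1 choice per beam), while beam_search re-ranks all num_beams x num_experts continuations of a sample jointly. A self-contained toy comparison of the two selection rules on log-probability scores; the tensor names below are illustrative, not the layer's internal state:

import torch

num_beams, num_experts = 2, 3
# Per-beam log-scores for extending each beam with each expert.
scores = torch.log_softmax(torch.randn(num_beams, num_experts), dim=-1)
prev = torch.tensor([-0.1, -0.7])              # running beam log-scores

joint = scores + prev.unsqueeze(1)             # [num_beams, num_experts]

# dp_search-style: each beam greedily keeps only its best expert.
dp_scores, dp_experts = joint.max(dim=1)       # one continuation per beam

# beam_search-style: rank all beam x expert continuations jointly.
flat = joint.view(-1)
bs_scores, flat_idx = flat.topk(num_beams)
bs_beams, bs_experts = flat_idx // num_experts, flat_idx % num_experts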
@@ -153,7 +186,6 @@ class RouteMoELayer(nn.Module):
        # import pdb;pdb.set_trace()
        return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
-

    def forward_post_route(self, x, beam_scores, expert_route, use_log=True):

        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
@@ -187,7 +219,12 @@ class RouteMoELayer(nn.Module):
        importance_loss = self._importance_auxiliary_loss(current_scores)

        batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
-       beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+
+       if self.route_method == 'post-route':
+           beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+       elif self.route_method == 'post-route-dp':
+           beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size)
+
        # beam_scores torch.Size([bz*num_beam])
        # expert_route torch.Size([bz*num_beam, layer_n])
        current_select_expert = expert_route[:,-1]
@@ -218,7 +255,7 @@ class RouteMoELayer(nn.Module):
        """
        if self.route_method == 'pre-route':
            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
-       elif self.route_method == "post-route":
+       elif self.route_method in ['post-route', 'post-route-dp']:
            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)

        return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
@@ -13,7 +13,7 @@ from typing import Optional, Tuple, List
def use_experts(layer_idx):
    # if layer_idx % 2 == 0:
    # use moe_ffn after cross_attns
-   if int(layer_idx) in [6,8,10]:
+   if int(layer_idx) in [6,7,8,9,10,11]:
        # layer 6/8/10
        return True
    else:
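The hunk above widens use_experts from alternate layers [6, 8, 10] to every layer from 6 through 11 (the "# layer 6/8/10" comment is now stale). A toy illustration of how such a predicate is typically consumed when deciding which transformer layers get an MoE FFN; the list-building loop below is illustrative only, not the repo's layer-construction code:

def use_experts(layer_idx: int) -> bool:
    return int(layer_idx) in [6, 7, 8, 9, 10, 11]

# Hypothetical 12-layer stack: the top six layers use the MoE FFN.
ffn_types = ["moe" if use_experts(i) else "dense" for i in range(12)]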
@@ -0,0 +1,114 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: True
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_1610k_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0301/20240301223/checkpoint_best.pth"
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # T5
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: True
  use_route_moe: False
  moebert_expert_num: 3
  moebert_route_method: "gate-sentence-post"
  moe_weight_type: "raw_prob"
  moebert_load_balance: 0.05
  moe_topk: 1
  use_balance_loss: False
  ln_position: "out"

datasets:
  gqa:
    type: balanced_sft_raw_eval
    batch_size: 4
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      eval:
        name: "blip_caption"

  ok_vqa: # train, valid (9009, 5046)
    type: ok_vqa_eval
    batch_size: 4
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      eval:
        name: "blip_caption"

  coco_vqa: # 658104
    type: vqa_v2_eval
    batch_size: 4
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      eval:
        name: "blip_caption"

  aok_vqa: # train: 17056, val: 1145
    batch_size: 4
    vis_processor:
      eval:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      eval:
        name: "blip_caption"

run:
  task: instruction_tuning
  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/eval/mix_coco_gqa_1610k_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0301/"
  num_workers: 4

  amp: True
  resume_ckpt_path: null

  evaluate: True
  test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
@@ -10,7 +10,7 @@ model:
  load_finetuned: True
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
- finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/20240112212/checkpoint_best.pth"
+ finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_001_5e5lr_top6layer_textinqf_epo8_0128/20240128142/checkpoint_best.pth"
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
@@ -39,27 +39,18 @@ model:
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
- moebert_load_balance: 0
+ moebert_load_balance: 0.01
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
- gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
+ use_balance_loss: False
+ bal_loss_decay_epoch: 8
+ gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_001_5e5lr_top6layer_textinqf_epo8_0128/"

datasets:
  gqa:
    type: balanced_sft_raw_eval
-   batch_size: 32
-   vis_processor:
-     eval:
-       name: "blip2_image_eval"
-       image_size: 224
-   text_processor:
-     eval:
-       name: "blip_caption"
-
- ok_vqa: # train, valid (9009, 5046)
-   type: ok_vqa_eval
-   batch_size: 32
+   batch_size: 64
    vis_processor:
      eval:
        name: "blip2_image_eval"
@@ -70,6 +61,17 @@ datasets:

  coco_vqa: # 658104
    type: vqa_v2_eval
+   batch_size: 64
+   vis_processor:
+     eval:
+       name: "blip2_image_eval"
+       image_size: 224
+   text_processor:
+     eval:
+       name: "blip_caption"
+
+ coco_caption: # 414113 train
+   type: coco_cap_eval
    batch_size: 32
    vis_processor:
      eval:
@@ -78,7 +80,18 @@ datasets:
    text_processor:
      eval:
        name: "blip_caption"
+
+ ok_vqa: # train, valid (9009, 5046)
+   type: ok_vqa_eval
+   batch_size: 64
+   vis_processor:
+     eval:
+       name: "blip2_image_eval"
+       image_size: 224
+   text_processor:
+     eval:
+       name: "blip_caption"

run:
  task: instruction_tuning
  # optimizer
@@ -96,7 +109,7 @@ run:
  iters_per_epoch: 3000

  seed: 42
- output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
+ output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_001_5e5lr_top6layer_textinqf_epo8_0128/"

  amp: True
  resume_ckpt_path: null
@@ -38,17 +38,17 @@ model:
  # moe
  use_moeqformer: True
  use_route_moe: True
- moebert_route_method: "post-route"
+ moebert_route_method: "post-route-dp"
- moebert_load_balance: 0
+ moebert_load_balance: 0.05
- moebert_expert_num: 3
+ moebert_expert_num: 2
- moebert_num_beams: 3
+ moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa: # train: 943000, 12578, 12578)
    type: balanced_sft_raw
-   batch_size: 16
+   batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
@@ -64,7 +64,7 @@ datasets:
    sample_ratio: 10

  ok_vqa: # train, valid (9009, 5046)
-   batch_size: 16
+   batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
@@ -80,7 +80,7 @@ datasets:
    sample_ratio: 1

  coco_vqa: # 658104
-   batch_size: 16
+   batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
@@ -112,7 +112,7 @@ run:
  iters_per_epoch: 5000

  seed: 42
- output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_3ex_3beam_1loss_5e5lr_top6layer_textinqf_epo8_0117/"
+ output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_DP_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_005_5e5lr_top6layer_textinqf_epo8_0121/"

  amp: True
  resume_ckpt_path: null
|
@ -38,14 +38,17 @@ model:
|
|||||||
# moe
|
# moe
|
||||||
use_moeqformer: True
|
use_moeqformer: True
|
||||||
use_route_moe: True
|
use_route_moe: True
|
||||||
moebert_expert_num: 5
|
moebert_route_method: "post-route"
|
||||||
moebert_num_beams: 1
|
moebert_load_balance: 0
|
||||||
# gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
|
moebert_expert_num: 2
|
||||||
|
moebert_num_beams: 2
|
||||||
|
moe_weight_type: 'ffn_prob'
|
||||||
|
use_balance_loss: False
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
gqa: # train: 943000, 12578, 12578)
|
gqa: # train: 943000, 12578, 12578)
|
||||||
type: balanced_sft_raw
|
type: balanced_sft_raw
|
||||||
batch_size: 4
|
batch_size: 32
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -61,7 +64,7 @@ datasets:
|
|||||||
sample_ratio: 10
|
sample_ratio: 10
|
||||||
|
|
||||||
ok_vqa: # train, valid (9009, 5046)
|
ok_vqa: # train, valid (9009, 5046)
|
||||||
batch_size: 4
|
batch_size: 32
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -77,7 +80,7 @@ datasets:
|
|||||||
sample_ratio: 1
|
sample_ratio: 1
|
||||||
|
|
||||||
coco_vqa: # 658104
|
coco_vqa: # 658104
|
||||||
batch_size: 4
|
batch_size: 32
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -96,20 +99,20 @@ run:
|
|||||||
task: instruction_tuning
|
task: instruction_tuning
|
||||||
# optimizer
|
# optimizer
|
||||||
lr_sched: "linear_warmup_cosine_lr"
|
lr_sched: "linear_warmup_cosine_lr"
|
||||||
init_lr: 2e-5
|
init_lr: 5e-5
|
||||||
min_lr: 1e-6
|
min_lr: 1e-6
|
||||||
warmup_lr: 1e-6
|
warmup_lr: 1e-6
|
||||||
log_freq: 5
|
log_freq: 5
|
||||||
save_freq: 1500
|
save_freq: 1500
|
||||||
|
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 6
|
max_epoch: 8
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
warmup_steps: 600
|
warmup_steps: 600
|
||||||
iters_per_epoch: 5000
|
iters_per_epoch: 5000
|
||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1212_Test/"
|
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_1loss_5e5lr_top6layer_textinqf_epo8_0123/"
|
||||||
|
|
||||||
amp: True
|
amp: True
|
||||||
resume_ckpt_path: null
|
resume_ckpt_path: null
|
@@ -0,0 +1,145 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: False
  use_route_moe: False
  moebert_route_method: "post-route"
  moebert_load_balance: 0
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa: # train: 943000, 12578, 12578)
    type: balanced_sft_raw
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 10

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 1

  coco_vqa: # 658104
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 9

  coco_caption: # 414113 train
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 7

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  # output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_005_5e5lr_top6layer_textinqf_epo8_0122/"
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Base_top6layer_textinqf_epo8_0124/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]
  # test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
@@ -0,0 +1,145 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
  moebert_load_balance: 0.01
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False
  bal_loss_decay_epoch: 3

datasets:
  gqa: # train: 943000, 12578, 12578)
    type: balanced_sft_raw
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 10

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 1

  coco_vqa: # 658104
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 9

  coco_caption: # 414113 train
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 7

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_001_loss_decay_5e5lr_top6layer_textinqf_epo8_0129/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]
  # test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
@@ -0,0 +1,188 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
  moebert_load_balance: 0.05
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa:
    type: balanced_sft_raw_eval
    batch_size: 16
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      eval:
        name: "blip_caption"

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 8

  coco_vqa: # 658104
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 15

  aok_vqa: # train: 17056, val: 1145
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 12

  ocrvqa: # train 207572
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 30

  llava_reason: # 76643
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 80

  llava_conversation: # 56681
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 30

  llava_detail: # 23240
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 20

  coco_caption: # 414113 train
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 10

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_2ex_2beam_2loss_5e5lr_top6layer_textinqf_epo8_0118/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]
  # test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
@@ -0,0 +1,128 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route-dp"
  moebert_load_balance: 0.05
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa: # train: 943000, 12578, 12578)
    type: balanced_sft_raw
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 10

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 1

  coco_vqa: # 658104
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 9

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_DP_Route_Post_ffn_prob_linear_1gate_2ex_2beam_2loss_5e5lr_top6layer_textinqf_epo8_0118/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]
  # test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
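This config switches the routing MoE to the new post-route-dp method with two experts and two beams. A hedged sketch of how these moebert_* keys would plausibly feed the RouteMoELayer constructor shown earlier in this diff; the cfg dict and the commented wiring below are illustrative, not the actual model loader:

cfg = {
    "moebert_route_method": "post-route-dp",
    "moebert_expert_num": 2,
    "moebert_num_beams": 2,
    "moe_weight_type": "ffn_prob",
}

# Assumed wiring, mirroring the __main__ example in the diff (ffn, layer_judge,
# and RouteMoELayer come from the surrounding code, so this stays commented):
# experts_post = RouteMoELayer(
#     hidden_size=768,
#     expert=ffn,
#     num_experts=cfg["moebert_expert_num"],
#     num_beams=cfg["moebert_num_beams"],
#     layer_judge=layer_judge,
#     route_method=cfg["moebert_route_method"],
#     weight_type=cfg["moe_weight_type"],
# )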
@@ -53,7 +53,7 @@ class InstructionTask(BaseTask):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
-       max_len = run_cfg.get("max_len", 20)
+       max_len = run_cfg.get("max_len", 30)
        min_len = run_cfg.get("min_len", 1)

        evaluate = run_cfg.get("evaluate", False)
@@ -112,22 +112,33 @@ class InstructionTask(BaseTask):
        )
        pred_qa_pairs = []

-       question_id = samples["question_id"]
-       question = samples["text_input"]
+       text_inputs = samples["text_input"]
        sources = samples["source"]
+       source = samples["source"][0]
+
+       if source in ['vqav2','okvqa','gqa']:
+           sample_ids = [int(sample_id.item()) for sample_id in samples["question_id"]]
+       elif source in ['aokvqa']:
+           sample_ids = [sample_id for sample_id in samples["question_id"]]
+       elif source in ['coco_cap']:
+           sample_ids = samples["image_id"]

        # For GQA
-       full_answers = samples.get("fullAnswer", ["" for i in range(len(question_id))])
-       gt_answers = samples.get("gt_answers", ["" for i in range(len(question_id))])
+       full_answers = samples.get("fullAnswer", ["" for i in range(len(sample_ids))])
+       gt_answers = samples.get("gt_answers", ["" for i in range(len(sample_ids))])

-       for answer, ques_id, ques, full_answer, gt_answer, source in zip(answers, question_id, question, full_answers, gt_answers, sources):
-           ques_id = int(ques_id.item())
+       # For AOKVQA
+       choices = samples.get("choices", ["" for i in range(len(sample_ids))])
+
+       for answer, sample_id, text_input, full_answer, gt_answer, choice, source in zip(answers, sample_ids, text_inputs, full_answers, gt_answers, choices, sources):
            pred_qa_pairs.append({
-               "question_id": ques_id,
-               "question": ques,
+               "question_id": sample_id,
+               "question": text_input,
                "full_answer": full_answer,
                "answer": answer,
                "gt_ans": gt_answer,
+               "choice": choice,
                "source": source})
        return pred_qa_pairs

@@ -140,9 +151,7 @@ class InstructionTask(BaseTask):
        total_results = list()
        for sub_data_loader in data_loader.loaders:
            results = []
-           ques_ids = []
            for samples in metric_logger.log_every(sub_data_loader, print_freq, header):
-               ques_ids.extend(samples['question_id'].tolist())

                samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
                eval_output = self.valid_step(model=model, samples=samples)
@@ -168,6 +177,7 @@ class InstructionTask(BaseTask):
                filename=f"{split_name}_vqa_result_{source}",
                remove_duplicate="question_id",
            )
+
            if source in ['vqav2','okvqa']:
                try:
                    metrics = self._report_metrics_coco_vqa(result_file=result_file, split=split_name, source=source)
@@ -180,7 +190,18 @@ class InstructionTask(BaseTask):
                except Exception as e:
                    metrics = None
                    print(f"Report Metrics {source} Error: {e}")
+           elif source in ['aokvqa']:
+               try:
+                   metrics = self._report_metrics_aokvqa(result_file=result_file, source=source)
+               except Exception as e:
+                   metrics = None
+                   print(f"Report Metrics {source} Error: {e}")
+           elif source in ['coco_cap']:
+               try:
+                   metrics = self._report_metrics_caption(result_file=result_file, split_name=split_name, source=source)
+               except Exception as e:
+                   metrics = None
+                   print(f"Report Metrics {source} Error: {e}")
            else:
                metrics = None
            final_metrics[source] = metrics
@@ -234,10 +255,46 @@ class InstructionTask(BaseTask):

        return metrics

+   @dist_utils.main_process
+   def _report_metrics_aokvqa(self, result_file, source='aokvqa'):
+       """
+       Validation of aokvqa
+       """
+       # measuring accuracy compared to answer
+       results = json.load(open(result_file, "r"))
+       acc = []
+       vqa_tool = VQAEval()
+
+       for res in results:
+
+           gt_ans = res["choice"]
+           pred = res["answer"]
+
+           pred = vqa_tool.processPunctuation(pred)
+           pred = vqa_tool.processDigitArticle(pred)
+
+           # vqa_acc = 1 if pred == gt_ans else 0
+           vqa_acc = 1 if pred in gt_ans else 0
+
+           acc.append(vqa_acc)
+
+       accuracy = sum(acc) / len(acc) * 100
+       metrics = {"agg_metrics": accuracy, "acc": accuracy}
+
+       with open(
+           os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a"
+       ) as f:
+           f.write(json.dumps(metrics) + "\n")
+
+       logging.info(metrics)
+
+       return metrics
+
    @dist_utils.main_process
    def _report_metrics_gqa(self, result_file, source='gqa'):
        """
-       Validation of GQA/VQAv2
+       Validation of GQA
        """
        # measuring accuracy compared to answer
        results = json.load(open(result_file, "r"))
@@ -274,3 +331,90 @@ class InstructionTask(BaseTask):

        return metrics

+   @dist_utils.main_process
+   def _report_metrics_caption(self, result_file, split_name, source='coco_cap'):
+       """
+       Use official COCO Cap evaluation script to report metrics.
+       """
+       coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
+       coco_val = coco_caption_eval(coco_gt_root, result_file, split_name)
+
+       agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
+       log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}
+
+       with open(
+           os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+       ) as f:
+           f.write(json.dumps(log_stats) + "\n")
+
+       coco_res = {k: v for k, v in coco_val.eval.items()}
+       coco_res["agg_metrics"] = agg_metrics
+
+       return coco_res
+
+from collections import defaultdict
+from pycocoevalcap.eval import COCOEvalCap
+class COCO_Annotation:
+   def __init__(self, annotation_file):
+       self.coco_cn_file = annotation_file
+       self.imgToAnns = self.build_imgToAnns()
+
+   def build_imgToAnns(self):
+       imgToAnns = defaultdict(list)
+       with open(self.coco_cn_file, "r", encoding="UTF-8") as fin:
+           for line in fin:
+               line = line.strip()
+               temp = eval(line)
+               annotations = temp['annotations']
+               for ann in annotations:
+                   image_id = str(ann['image_id']).zfill(6)
+                   imgToAnns[image_id].append({'image_id':image_id,'caption':ann['caption'],'image': ann['image_id']})
+       return imgToAnns
+
+   def getImgIds(self):
+       return self.imgToAnns.keys()
+
+class COCO_Result:
+   def __init__(self,result_file):
+       self.coco_cn_file = result_file
+       self.imgToAnns = self.build_imgToAnns()
+
+   def build_imgToAnns(self):
+       imgToAnns = dict()
+       data = json.load(open(self.coco_cn_file, "r"))
+       for d in data:
+           tmp = {
+               'image_id':d['question_id'][-6:],
+               'caption':d['answer']
+           }
+           imgToAnns[d['question_id'][-6:]] = [tmp]
+       return imgToAnns
+
+def coco_caption_eval(coco_gt_root, results_file, split_name):
+   files = {
+       "val":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json",
+       "test":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json"
+   }
+
+   # create coco object and coco_result object
+   annotation_file = files[split_name]
+   coco = COCO_Annotation(annotation_file)
+   coco_result = COCO_Result(results_file)
+
+   # create coco_eval object by taking coco and coco_result
+   coco_eval = COCOEvalCap(coco, coco_result)
+
+   # evaluate on a subset of images by setting
+   # coco_eval.params['image_id'] = coco_result.getImgIds()
+   # please remove this line when evaluating the full validation set
+   # coco_eval.params['image_id'] = coco_result.getImgIds()
+
+   # evaluate results
+   # SPICE will take a few minutes the first time, but speeds up due to caching
+   coco_eval.evaluate()
+
+   # print output evaluation scores
+   for metric, score in coco_eval.eval.items():
+       print(f"{metric}: {score:.3f}")
+
+   return coco_eval
@@ -1,4 +0,0 @@
-<Img><ImageHere></Img> Describe this image in detail.
-<Img><ImageHere></Img> Take a look at this image and describe what you notice.
-<Img><ImageHere></Img> Please provide a detailed description of the picture.
-<Img><ImageHere></Img> Could you describe the contents of this image for me?
58
test/datasets/test_dataset.py
Normal file
@@ -0,0 +1,58 @@
import datasets
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import random
from tqdm import tqdm

# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/models/cmrc2018_trial.json"
# dataset = load_dataset("json", data_files=[path], field="data", split="train")
# tokenizer = AutoTokenizer.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
# def preprocess_function(example):
#     import pdb; pdb.set_trace()
#     model_inputs = tokenizer(example["content"], max_length=512, truncation=True)
#     labels = tokenizer(example["title"], max_length=32, truncation=True)
#     # the labels are the tokenized title
#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs
# processed_datasets = dataset.map(preprocess_function)

dataset = load_dataset("/mnt/pfs-guan-ssai/nlu/wanghanzi/data/alpaca_20k")
train_dataset = dataset['train']


for i in tqdm(range(1, len(train_dataset))):
    import pdb; pdb.set_trace()

    idx = random.randint(0,i)
    memory = train_dataset[idx]
    memory_text = f"Instruction: {memory['instruction']}\n Answer: {memory['output']} \n"
    train_dataset[i]['text'] = f"{memory_text} Instruction:{train_dataset[i]['instruction']}"


import pdb; pdb.set_trace()


model_path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/opt_350m"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)


def formatting_prompts_func(example):
    import pdb; pdb.set_trace()
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=train_dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)
trainer.train()
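One caveat on the memory-augmentation loop in the new test file: indexing a Hugging Face Dataset returns a copy, so the assignment to train_dataset[i]['text'] will not persist. A minimal sketch of the same idea expressed with Dataset.map (with_indices=True is the documented way to receive the row index); the field names follow the alpaca_20k example above and the random earlier-example choice is kept:

import random
from datasets import load_dataset

dataset = load_dataset("/mnt/pfs-guan-ssai/nlu/wanghanzi/data/alpaca_20k")
train_dataset = dataset["train"]

def add_memory(example, idx):
    # Sample an earlier example as "memory" and prepend it to the prompt.
    mem = train_dataset[random.randint(0, max(idx - 1, 0))]
    memory_text = f"Instruction: {mem['instruction']}\n Answer: {mem['output']} \n"
    example["text"] = f"{memory_text} Instruction: {example['instruction']}"
    return example

train_dataset = train_dataset.map(add_memory, with_indices=True)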