# MiniGPT-4/test/datasets/test_pretrain_task_dataset.py

"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
# import wandb
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.logger import setup_logger
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
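
# This script exercises the pretrain-task dataset pipeline end to end:
# config -> task -> datasets -> model -> runner -> dataloaders -> a sample batch.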


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    # parser.add_argument("-f", help="jupyter notebook")
    parser.add_argument(
        "--cfg-path",
        # default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/minigpt/train/minigptv2_finetune_vqa.yaml",
        default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/prompt_moe/train/mix_coco_gqa_prompt_moe_post_blip2.yaml",
        help="path to configuration file.",
    )
    parser.add_argument(
        "--gpu-id",
        type=int,
        default=5,
        help="specify the gpu to load the model.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into the config file (deprecated; "
        "change to --cfg-options instead).",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
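
# Note: cudnn.benchmark=False together with cudnn.deterministic=True favors
# run-to-run reproducibility over autotuned-kernel speed.
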
# Test: building the task and its datasets
# build config
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")  # note: not used below
cfg = Config(parse_args())
setup_seeds(cfg)
print(cfg._convert_node_to_json(cfg.config))
setup_logger()
cfg.pretty_print()
task = tasks.setup_task(cfg)
datasets = task.build_datasets(cfg)
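# `datasets` maps dataset name -> {split_name: Dataset}; the exact keys (here,
# presumably a COCO-VQA / GQA mix from the default cfg) depend on the config.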


def get_runner_class(cfg):
    """
    Get runner class from config. Default to epoch-based runner.
    """
    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
    return runner_cls
job_id = now()
model = task.build_model(cfg)
# model = None
# task.build_tensorboard(cfg)
runner = get_runner_class(cfg)(
    cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
)
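
# The runner ties task, model, and datasets together. The code below rebuilds
# its dataloaders step by step (mirroring what the runner normally does
# lazily) so each intermediate structure can be inspected.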
# from minigpt4.common.vqa_tools.vqa import VQA
# from minigpt4.common.vqa_tools.vqa_eval import VQAEval
# split = 'val'
# source = 'vqav2'
# print(task.anno_files[split][source],task.ques_files[split][source])
# vqa = VQA(task.anno_files[split][source], task.ques_files[split][source])
# result_file = '/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1101/20231031224/result/val_vqa_result_2.json'
# print('result_file: ',result_file)
# vqa_result = vqa.loadRes(resFile=result_file, quesFile=task.ques_files[split][source])
# vqa_scorer = VQAEval(vqa, vqa_result, n=2)
# vqa_scorer.evaluate()
# overall_acc = vqa_scorer.accuracy["overall"]
# perAnswerType = vqa_scorer.accuracy["perAnswerType"]
# print(overall_acc)
# print(perAnswerType)
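
# (The commented-out block above is kept as a reference for scoring a saved
# VQA result file offline with the VQA / VQAEval tools.)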
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
import webdataset as wds
import logging
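
# Collect each dataset's batch size from the config, then regroup the
# name-keyed datasets dict into a split-keyed one (train/val/test).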
batch_sizes = {
    dataset_name: getattr(runner.config.datasets_cfg, dataset_name).batch_size
    for dataset_name in runner.datasets.keys()
}
print(batch_sizes)
datasets, batch_sizes = reorg_datasets_by_split(runner.datasets, batch_sizes)
runner.datasets = datasets
# self.datasets = concat_datasets(datasets)
print(runner.datasets.keys()) # dict_keys(['train', 'val', 'test'])
# print dataset statistics after concatenation/chaining
for split_name in runner.datasets:
    if isinstance(runner.datasets[split_name], (tuple, list)):
        # mixed wds.DataPipeline and torch.utils.data.Dataset
        num_records = sum(
            [
                len(d) if type(d) not in [wds.DataPipeline, ChainDataset] else 0
                for d in runner.datasets[split_name]
            ]
        )
    else:
        if hasattr(runner.datasets[split_name], "__len__"):
            # a single map-style dataset
            num_records = len(runner.datasets[split_name])
        else:
            # a single wds.DataPipeline
            num_records = -1
            logging.info(
                "Only a single wds.DataPipeline dataset, no __len__ attribute."
            )

    if num_records >= 0:
        logging.info(
            "Loaded {} records for {} split from the dataset.".format(
                num_records, split_name
            )
        )
split_names = sorted(runner.datasets.keys())
datasets = [runner.datasets[split] for split in split_names]
batch_sizes = [batch_sizes[split] for split in split_names]
is_trains = [split in runner.train_splits for split in split_names]
print("split_names: ",split_names)
print("is_trains: ",is_trains)
print("batch sizes: ", batch_sizes)
collate_fns = []
for dataset in datasets:
    if isinstance(dataset, (tuple, list)):
        collate_fns.append([getattr(d, "collater", None) for d in dataset])
    else:
        collate_fns.append(getattr(dataset, "collater", None))
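
# Each dataset supplies its own `collater`; a None entry falls back to the
# default PyTorch collate inside the created DataLoader.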
dataloaders = runner.create_loaders(
    datasets=datasets,
    num_workers=runner.config.run_cfg.num_workers,
    batch_sizes=batch_sizes,
    is_trains=is_trains,
    collate_fns=collate_fns,
)
_dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
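
# The train loader is a multi-dataset loader: it holds one sub-loader per
# dataset (`loader.loaders`) and sampling weights (`loader.ratios`), which the
# random choice below draws from.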
loader = _dataloaders['train']
loader_idx = random.choices(range(len(loader.loaders)), loader.ratios, k=1)[0]
print(loader_idx)
# fetch a single batch from the chosen sub-loader and inspect two of its fields
batch = next(loader.loaders[loader_idx])
print(batch['question_id'], batch['source'])
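
# A minimal smoke test over a few more draws (a sketch; it assumes every
# sub-loader yields dict batches carrying 'question_id' and 'source' keys,
# as the single draw above suggests):
for _ in range(3):
    idx = random.choices(range(len(loader.loaders)), loader.ratios, k=1)[0]
    batch = next(loader.loaders[idx])
    print(idx, batch['source'], len(batch['question_id']))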