Merge pull request #390 from TsuTikgiau/main

update the code for batch size setting
ZhuDeyao 2023-10-21 22:13:43 +03:00 committed by GitHub
commit 10f61a4dd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 25 additions and 21 deletions

View File

@@ -96,7 +96,7 @@ def prepare_sample(samples, cuda_enabled=True):
     return samples

-def reorg_datasets_by_split(datasets):
+def reorg_datasets_by_split(datasets, batch_sizes):
     """
     Organizes datasets by split.
@@ -110,16 +110,19 @@ def reorg_datasets_by_split(datasets):
     #     return datasets[list(datasets.keys())[0]]
     # else:
     reorg_datasets = dict()
+    reorg_batch_sizes = dict()

     # reorganize by split
-    for _, dataset in datasets.items():
+    for dataset_name, dataset in datasets.items():
         for split_name, dataset_split in dataset.items():
             if split_name not in reorg_datasets:
                 reorg_datasets[split_name] = [dataset_split]
+                reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
             else:
                 reorg_datasets[split_name].append(dataset_split)
+                reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])

-    return reorg_datasets
+    return reorg_datasets, reorg_batch_sizes

 def concat_datasets(datasets):
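
A minimal sketch (not part of the commit) of the new return contract, assuming the helper is importable from minigpt4.datasets.data_utils and using strings as stand-in dataset objects: the batch-size dict is reorganized in lockstep with the datasets, so the two returned dicts stay index-aligned per split.

```python
# Hypothetical usage sketch; dataset objects are stand-in strings.
from minigpt4.datasets.data_utils import reorg_datasets_by_split  # assumed module path

datasets = {
    "laion": {"train": "laion_train_ds"},
    "cc_sbu": {"train": "cc_sbu_train_ds"},
}
batch_sizes = {"laion": 64, "cc_sbu": 64}

reorg_datasets, reorg_batch_sizes = reorg_datasets_by_split(datasets, batch_sizes)
# reorg_datasets    == {"train": ["laion_train_ds", "cc_sbu_train_ds"]}
# reorg_batch_sizes == {"train": [64, 64]}
```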

View File

@@ -88,7 +88,7 @@ class RunnerBase:
         if self.use_distributed:
             if self._wrapped_model is None:
                 self._wrapped_model = DDP(
-                    self._model, device_ids=[self.config.run_cfg.gpu]
+                    self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
                 )
         else:
             self._wrapped_model = self._model
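
find_unused_parameters=True lets DDP's gradient reducer skip parameters that receive no gradient in a given forward pass (useful when submodules are frozen or conditionally executed), at the cost of an extra traversal per iteration. A minimal single-process sketch, using a CPU gloo group purely for illustration:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process toy group for illustration; real runs launch one
# process per GPU (e.g. via torchrun) and pass device_ids as above.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
# Without find_unused_parameters=True, DDP errors at backward time if
# some parameters never received gradients during the forward pass.
wrapped = DDP(model, find_unused_parameters=True)
```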
@@ -206,7 +206,9 @@ class RunnerBase:
                     "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
                 )

-            datasets = reorg_datasets_by_split(self.datasets)
+            batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
+                           for dataset_name in self.datasets.keys()}
+            datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
             self.datasets = datasets

             # self.datasets = concat_datasets(datasets)
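
The batch size now comes from each dataset's own config block (datasets_cfg) rather than from a single run-level value. A self-contained sketch of the same getattr lookup, with SimpleNamespace standing in for the project's OmegaConf-style config:

```python
from types import SimpleNamespace

# Stand-in for self.config.datasets_cfg.
datasets_cfg = SimpleNamespace(
    laion=SimpleNamespace(batch_size=64),
    cc_sbu=SimpleNamespace(batch_size=64),
)
dataset_names = ["laion", "cc_sbu"]

# Same dict comprehension as the new runner code.
batch_sizes = {name: getattr(datasets_cfg, name).batch_size
               for name in dataset_names}
assert batch_sizes == {"laion": 64, "cc_sbu": 64}
```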
@@ -247,14 +249,10 @@ class RunnerBase:
             split_names = sorted(self.datasets.keys())

             datasets = [self.datasets[split] for split in split_names]
+            batch_sizes = [batch_sizes[split] for split in split_names]
             is_trains = [split in self.train_splits for split in split_names]

-            batch_sizes = [
-                self.config.run_cfg.batch_size_train
-                if split == "train"
-                else self.config.run_cfg.batch_size_eval
-                for split in split_names
-            ]
+            print("batch sizes", batch_sizes)

             collate_fns = []
             for dataset in datasets:
@@ -349,6 +347,7 @@ class RunnerBase:
         lib_root = Path(registry.get_path("library_root"))

         output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
+        # output_dir = lib_root / self.config.run_cfg.output_dir
         result_dir = output_dir / "result"

         output_dir.mkdir(parents=True, exist_ok=True)
@@ -519,6 +518,7 @@ class RunnerBase:
         else:
             # map-style dataset are concatenated together
             # setup distributed sampler
+
             if self.use_distributed:
                 sampler = DistributedSampler(
                     dataset,
@@ -559,7 +559,7 @@ class RunnerBase:
             dataset_ratios = [d.sample_ratio for d in dataset]
             loader = MultiIterLoader(
                 loaders=[
-                    _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+                    _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
                     for i, d in enumerate(dataset)
                 ],
                 ratios=dataset_ratios,
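
Indexing bsz[i] is what gives each sub-dataset of a split its own loader batch size. A simplified, self-contained sketch of the pattern, with vanilla DataLoader in place of the project's _create_loader/MultiIterLoader internals:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

datasets = [TensorDataset(torch.zeros(128, 4)),
            TensorDataset(torch.zeros(256, 4))]
bsz = [64, 12]  # one batch size per sub-dataset, as in the new code

loaders = [DataLoader(d, batch_size=bsz[i], shuffle=True)
           for i, d in enumerate(datasets)]
# Before this change a single scalar bsz was shared by every loader,
# so all datasets under a split trained with the same batch size.
```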
@@ -634,13 +634,14 @@ class RunnerBase:
             raise RuntimeError("checkpoint url or path is invalid")

         state_dict = checkpoint["model"]
-        self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
+        message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)

         self.optimizer.load_state_dict(checkpoint["optimizer"])
         if self.scaler and "scaler" in checkpoint:
             self.scaler.load_state_dict(checkpoint["scaler"])

         self.start_epoch = checkpoint["epoch"] + 1
+        print("resume the checkpoint")
         logging.info("Resume checkpoint from {}".format(url_or_filename))

     @main_process
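
Capturing the return value matters because load_state_dict(strict=False) does not raise on mismatched keys; it returns a named tuple listing them. A small sketch:

```python
import torch

src = torch.nn.Linear(4, 4)
dst = torch.nn.Sequential(torch.nn.Linear(4, 4))

# With strict=False, key mismatches are reported instead of raised.
message = dst.load_state_dict(src.state_dict(), strict=False)
print(message.missing_keys)     # params of dst absent from src: ['0.weight', '0.bias']
print(message.unexpected_keys)  # keys of src that dst did not use: ['weight', 'bias']
```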

View File

@@ -21,12 +21,14 @@ class BaseTask:
         super().__init__()

         self.inst_id_key = "instance_id"
+        self.cfg = ""

     @classmethod
     def setup_task(cls, **kwargs):
         return cls()

     def build_model(self, cfg):
+        self.cfg = cfg
         model_config = cfg.model_cfg

         model_cls = registry.get_model_class(model_config.arch)
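
Storing cfg on the task in build_model makes the full config reachable from any later task method, not just the model config. A hypothetical subclass (class and method names are illustrative, not from the repo):

```python
from minigpt4.tasks.base_task import BaseTask  # assumed module path

class MyTask(BaseTask):
    # Hypothetical method: once build_model has run, the whole run
    # config is available as self.cfg without re-threading arguments.
    def log_run_settings(self):
        print(self.cfg.run_cfg.output_dir)
```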

View File

@@ -5,6 +5,7 @@ model:

 datasets:
   laion:
+    batch_size: 64
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -14,6 +15,7 @@ datasets:
         name: "blip_caption"
     sample_ratio: 115
   cc_sbu:
+    batch_size: 64
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -34,8 +36,6 @@ run:
   weight_decay: 0.05
   max_epoch: 4
-  batch_size_train: 64
-  batch_size_eval: 64
   num_workers: 4
   warmup_steps: 5000
   iters_per_epoch: 5000
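
After this commit, batch_size lives under each dataset entry instead of run.batch_size_train / run.batch_size_eval, so laion and cc_sbu could in principle use different values. A sketch of reading the relocated keys with plain PyYAML (the config path is assumed, since the file name is not shown in this view):

```python
import yaml

# Path assumed; any of the updated train configs has this shape.
with open("train_configs/minigpt4_stage1_pretrain.yaml") as f:
    cfg = yaml.safe_load(f)

# batch_size moved from cfg["run"] into each dataset's own block.
for name, ds_cfg in cfg["datasets"].items():
    print(name, ds_cfg["batch_size"])  # e.g. laion 64, cc_sbu 64
assert "batch_size_train" not in cfg["run"]
```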

View File

@@ -11,6 +11,7 @@ model:

 datasets:
   cc_sbu_align:
+    batch_size: 12
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -30,8 +31,6 @@ run:
   weight_decay: 0.05
   max_epoch: 5
   iters_per_epoch: 200
-  batch_size_train: 12
-  batch_size_eval: 12
   num_workers: 4
   warmup_steps: 200

View File

@@ -5,6 +5,7 @@ model:

 datasets:
   laion:
+    batch_size: 64
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -14,6 +15,7 @@ datasets:
         name: "blip_caption"
     sample_ratio: 115
   cc_sbu:
+    batch_size: 64
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -34,8 +36,6 @@ run:
   weight_decay: 0.05
   max_epoch: 4
-  batch_size_train: 64
-  batch_size_eval: 64
   num_workers: 4
   warmup_steps: 5000
   iters_per_epoch: 5000

View File

@@ -11,6 +11,7 @@ model:

 datasets:
   cc_sbu_align:
+    batch_size: 12
     vis_processor:
       train:
         name: "blip2_image_train"
@@ -30,8 +31,6 @@ run:
   weight_decay: 0.05
   max_epoch: 5
   iters_per_epoch: 200
-  batch_size_train: 12
-  batch_size_eval: 12
   num_workers: 4
   warmup_steps: 200