diff --git a/minigpt4/datasets/data_utils.py b/minigpt4/datasets/data_utils.py index cf6497f..773b10f 100644 --- a/minigpt4/datasets/data_utils.py +++ b/minigpt4/datasets/data_utils.py @@ -96,7 +96,7 @@ def prepare_sample(samples, cuda_enabled=True): return samples -def reorg_datasets_by_split(datasets): +def reorg_datasets_by_split(datasets, batch_sizes): """ Organizes datasets by split. @@ -110,16 +110,19 @@ def reorg_datasets_by_split(datasets): # return datasets[list(datasets.keys())[0]] # else: reorg_datasets = dict() + reorg_batch_sizes = dict() # reorganize by split - for _, dataset in datasets.items(): + for dataset_name, dataset in datasets.items(): for split_name, dataset_split in dataset.items(): if split_name not in reorg_datasets: reorg_datasets[split_name] = [dataset_split] + reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]] else: reorg_datasets[split_name].append(dataset_split) + reorg_batch_sizes[split_name].append(batch_sizes[dataset_name]) - return reorg_datasets + return reorg_datasets, reorg_batch_sizes def concat_datasets(datasets): diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py index ccb5706..bc8dc1d 100644 --- a/minigpt4/runners/runner_base.py +++ b/minigpt4/runners/runner_base.py @@ -88,7 +88,7 @@ class RunnerBase: if self.use_distributed: if self._wrapped_model is None: self._wrapped_model = DDP( - self._model, device_ids=[self.config.run_cfg.gpu] + self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True ) else: self._wrapped_model = self._model @@ -206,7 +206,9 @@ class RunnerBase: "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." ) - datasets = reorg_datasets_by_split(self.datasets) + batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size + for dataset_name in self.datasets.keys()} + datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes) self.datasets = datasets # self.datasets = concat_datasets(datasets) @@ -247,14 +249,10 @@ class RunnerBase: split_names = sorted(self.datasets.keys()) datasets = [self.datasets[split] for split in split_names] + batch_sizes = [batch_sizes[split] for split in split_names] is_trains = [split in self.train_splits for split in split_names] - batch_sizes = [ - self.config.run_cfg.batch_size_train - if split == "train" - else self.config.run_cfg.batch_size_eval - for split in split_names - ] + print("batch sizes", batch_sizes) collate_fns = [] for dataset in datasets: @@ -349,6 +347,7 @@ class RunnerBase: lib_root = Path(registry.get_path("library_root")) output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id + # output_dir = lib_root / self.config.run_cfg.output_dir result_dir = output_dir / "result" output_dir.mkdir(parents=True, exist_ok=True) @@ -519,6 +518,7 @@ class RunnerBase: else: # map-style dataset are concatenated together # setup distributed sampler + if self.use_distributed: sampler = DistributedSampler( dataset, @@ -559,7 +559,7 @@ class RunnerBase: dataset_ratios = [d.sample_ratio for d in dataset] loader = MultiIterLoader( loaders=[ - _create_loader(d, num_workers, bsz, is_train, collate_fn[i]) + _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i]) for i, d in enumerate(dataset) ], ratios=dataset_ratios, @@ -634,13 +634,14 @@ class RunnerBase: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] - self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) + message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) self.optimizer.load_state_dict(checkpoint["optimizer"]) if self.scaler and "scaler" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler"]) self.start_epoch = checkpoint["epoch"] + 1 + print("resume the checkpoint") logging.info("Resume checkpoint from {}".format(url_or_filename)) @main_process diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py index 7ceee96..e5eb32b 100644 --- a/minigpt4/tasks/base_task.py +++ b/minigpt4/tasks/base_task.py @@ -21,12 +21,14 @@ class BaseTask: super().__init__() self.inst_id_key = "instance_id" + self.cfg = "" @classmethod def setup_task(cls, **kwargs): return cls() def build_model(self, cfg): + self.cfg = cfg model_config = cfg.model_cfg model_cls = registry.get_model_class(model_config.arch) diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml index f3981b8..c13d31f 100644 --- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml +++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml @@ -5,6 +5,7 @@ model: datasets: laion: + batch_size: 64 vis_processor: train: name: "blip2_image_train" @@ -14,6 +15,7 @@ datasets: name: "blip_caption" sample_ratio: 115 cc_sbu: + batch_size: 64 vis_processor: train: name: "blip2_image_train" @@ -34,8 +36,6 @@ run: weight_decay: 0.05 max_epoch: 4 - batch_size_train: 64 - batch_size_eval: 64 num_workers: 4 warmup_steps: 5000 iters_per_epoch: 5000 diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml index fa2b578..8c138ae 100644 --- a/train_configs/minigpt4_llama2_stage2_finetune.yaml +++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml @@ -11,6 +11,7 @@ model: datasets: cc_sbu_align: + batch_size: 12 vis_processor: train: name: "blip2_image_train" @@ -30,8 +31,6 @@ run: weight_decay: 0.05 max_epoch: 5 iters_per_epoch: 200 - batch_size_train: 12 - batch_size_eval: 12 num_workers: 4 warmup_steps: 200 diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml index be87b77..ce8bc87 100644 --- a/train_configs/minigpt4_stage1_pretrain.yaml +++ b/train_configs/minigpt4_stage1_pretrain.yaml @@ -5,6 +5,7 @@ model: datasets: laion: + batch_size: 64 vis_processor: train: name: "blip2_image_train" @@ -14,6 +15,7 @@ datasets: name: "blip_caption" sample_ratio: 115 cc_sbu: + batch_size: 64 vis_processor: train: name: "blip2_image_train" @@ -34,8 +36,6 @@ run: weight_decay: 0.05 max_epoch: 4 - batch_size_train: 64 - batch_size_eval: 64 num_workers: 4 warmup_steps: 5000 iters_per_epoch: 5000 diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml index 404dfd6..531a3a0 100644 --- a/train_configs/minigpt4_stage2_finetune.yaml +++ b/train_configs/minigpt4_stage2_finetune.yaml @@ -11,6 +11,7 @@ model: datasets: cc_sbu_align: + batch_size: 12 vis_processor: train: name: "blip2_image_train" @@ -30,8 +31,6 @@ run: weight_decay: 0.05 max_epoch: 5 iters_per_epoch: 200 - batch_size_train: 12 - batch_size_eval: 12 num_workers: 4 warmup_steps: 200