mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
Merge pull request #390 from TsuTikgiau/main
update the code for batchsize setting
This commit is contained in:
commit
10f61a4dd8
@ -96,7 +96,7 @@ def prepare_sample(samples, cuda_enabled=True):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def reorg_datasets_by_split(datasets):
|
def reorg_datasets_by_split(datasets, batch_sizes):
|
||||||
"""
|
"""
|
||||||
Organizes datasets by split.
|
Organizes datasets by split.
|
||||||
|
|
||||||
@ -110,16 +110,19 @@ def reorg_datasets_by_split(datasets):
|
|||||||
# return datasets[list(datasets.keys())[0]]
|
# return datasets[list(datasets.keys())[0]]
|
||||||
# else:
|
# else:
|
||||||
reorg_datasets = dict()
|
reorg_datasets = dict()
|
||||||
|
reorg_batch_sizes = dict()
|
||||||
|
|
||||||
# reorganize by split
|
# reorganize by split
|
||||||
for _, dataset in datasets.items():
|
for dataset_name, dataset in datasets.items():
|
||||||
for split_name, dataset_split in dataset.items():
|
for split_name, dataset_split in dataset.items():
|
||||||
if split_name not in reorg_datasets:
|
if split_name not in reorg_datasets:
|
||||||
reorg_datasets[split_name] = [dataset_split]
|
reorg_datasets[split_name] = [dataset_split]
|
||||||
|
reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
|
||||||
else:
|
else:
|
||||||
reorg_datasets[split_name].append(dataset_split)
|
reorg_datasets[split_name].append(dataset_split)
|
||||||
|
reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])
|
||||||
|
|
||||||
return reorg_datasets
|
return reorg_datasets, reorg_batch_sizes
|
||||||
|
|
||||||
|
|
||||||
def concat_datasets(datasets):
|
def concat_datasets(datasets):
|
||||||
|
@ -88,7 +88,7 @@ class RunnerBase:
|
|||||||
if self.use_distributed:
|
if self.use_distributed:
|
||||||
if self._wrapped_model is None:
|
if self._wrapped_model is None:
|
||||||
self._wrapped_model = DDP(
|
self._wrapped_model = DDP(
|
||||||
self._model, device_ids=[self.config.run_cfg.gpu]
|
self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._wrapped_model = self._model
|
self._wrapped_model = self._model
|
||||||
@ -206,7 +206,9 @@ class RunnerBase:
|
|||||||
"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
|
"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
|
||||||
)
|
)
|
||||||
|
|
||||||
datasets = reorg_datasets_by_split(self.datasets)
|
batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
|
||||||
|
for dataset_name in self.datasets.keys()}
|
||||||
|
datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
|
||||||
self.datasets = datasets
|
self.datasets = datasets
|
||||||
# self.datasets = concat_datasets(datasets)
|
# self.datasets = concat_datasets(datasets)
|
||||||
|
|
||||||
@ -247,14 +249,10 @@ class RunnerBase:
|
|||||||
split_names = sorted(self.datasets.keys())
|
split_names = sorted(self.datasets.keys())
|
||||||
|
|
||||||
datasets = [self.datasets[split] for split in split_names]
|
datasets = [self.datasets[split] for split in split_names]
|
||||||
|
batch_sizes = [batch_sizes[split] for split in split_names]
|
||||||
is_trains = [split in self.train_splits for split in split_names]
|
is_trains = [split in self.train_splits for split in split_names]
|
||||||
|
|
||||||
batch_sizes = [
|
print("batch sizes", batch_sizes)
|
||||||
self.config.run_cfg.batch_size_train
|
|
||||||
if split == "train"
|
|
||||||
else self.config.run_cfg.batch_size_eval
|
|
||||||
for split in split_names
|
|
||||||
]
|
|
||||||
|
|
||||||
collate_fns = []
|
collate_fns = []
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
@ -349,6 +347,7 @@ class RunnerBase:
|
|||||||
lib_root = Path(registry.get_path("library_root"))
|
lib_root = Path(registry.get_path("library_root"))
|
||||||
|
|
||||||
output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
|
output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
|
||||||
|
# output_dir = lib_root / self.config.run_cfg.output_dir
|
||||||
result_dir = output_dir / "result"
|
result_dir = output_dir / "result"
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@ -519,6 +518,7 @@ class RunnerBase:
|
|||||||
else:
|
else:
|
||||||
# map-style dataset are concatenated together
|
# map-style dataset are concatenated together
|
||||||
# setup distributed sampler
|
# setup distributed sampler
|
||||||
|
|
||||||
if self.use_distributed:
|
if self.use_distributed:
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
dataset,
|
dataset,
|
||||||
@ -559,7 +559,7 @@ class RunnerBase:
|
|||||||
dataset_ratios = [d.sample_ratio for d in dataset]
|
dataset_ratios = [d.sample_ratio for d in dataset]
|
||||||
loader = MultiIterLoader(
|
loader = MultiIterLoader(
|
||||||
loaders=[
|
loaders=[
|
||||||
_create_loader(d, num_workers, bsz, is_train, collate_fn[i])
|
_create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
|
||||||
for i, d in enumerate(dataset)
|
for i, d in enumerate(dataset)
|
||||||
],
|
],
|
||||||
ratios=dataset_ratios,
|
ratios=dataset_ratios,
|
||||||
@ -634,13 +634,14 @@ class RunnerBase:
|
|||||||
raise RuntimeError("checkpoint url or path is invalid")
|
raise RuntimeError("checkpoint url or path is invalid")
|
||||||
|
|
||||||
state_dict = checkpoint["model"]
|
state_dict = checkpoint["model"]
|
||||||
self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
|
message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
|
||||||
|
|
||||||
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
if self.scaler and "scaler" in checkpoint:
|
if self.scaler and "scaler" in checkpoint:
|
||||||
self.scaler.load_state_dict(checkpoint["scaler"])
|
self.scaler.load_state_dict(checkpoint["scaler"])
|
||||||
|
|
||||||
self.start_epoch = checkpoint["epoch"] + 1
|
self.start_epoch = checkpoint["epoch"] + 1
|
||||||
|
print("resume the checkpoint")
|
||||||
logging.info("Resume checkpoint from {}".format(url_or_filename))
|
logging.info("Resume checkpoint from {}".format(url_or_filename))
|
||||||
|
|
||||||
@main_process
|
@main_process
|
||||||
|
@ -21,12 +21,14 @@ class BaseTask:
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.inst_id_key = "instance_id"
|
self.inst_id_key = "instance_id"
|
||||||
|
self.cfg = ""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_task(cls, **kwargs):
|
def setup_task(cls, **kwargs):
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
def build_model(self, cfg):
|
def build_model(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
model_config = cfg.model_cfg
|
model_config = cfg.model_cfg
|
||||||
|
|
||||||
model_cls = registry.get_model_class(model_config.arch)
|
model_cls = registry.get_model_class(model_config.arch)
|
||||||
|
@ -5,6 +5,7 @@ model:
|
|||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
laion:
|
laion:
|
||||||
|
batch_size: 64
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -14,6 +15,7 @@ datasets:
|
|||||||
name: "blip_caption"
|
name: "blip_caption"
|
||||||
sample_ratio: 115
|
sample_ratio: 115
|
||||||
cc_sbu:
|
cc_sbu:
|
||||||
|
batch_size: 64
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -34,8 +36,6 @@ run:
|
|||||||
|
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 4
|
max_epoch: 4
|
||||||
batch_size_train: 64
|
|
||||||
batch_size_eval: 64
|
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
warmup_steps: 5000
|
warmup_steps: 5000
|
||||||
iters_per_epoch: 5000
|
iters_per_epoch: 5000
|
||||||
|
@ -11,6 +11,7 @@ model:
|
|||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
cc_sbu_align:
|
cc_sbu_align:
|
||||||
|
batch_size: 12
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -30,8 +31,6 @@ run:
|
|||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 5
|
max_epoch: 5
|
||||||
iters_per_epoch: 200
|
iters_per_epoch: 200
|
||||||
batch_size_train: 12
|
|
||||||
batch_size_eval: 12
|
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
warmup_steps: 200
|
warmup_steps: 200
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ model:
|
|||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
laion:
|
laion:
|
||||||
|
batch_size: 64
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -14,6 +15,7 @@ datasets:
|
|||||||
name: "blip_caption"
|
name: "blip_caption"
|
||||||
sample_ratio: 115
|
sample_ratio: 115
|
||||||
cc_sbu:
|
cc_sbu:
|
||||||
|
batch_size: 64
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -34,8 +36,6 @@ run:
|
|||||||
|
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 4
|
max_epoch: 4
|
||||||
batch_size_train: 64
|
|
||||||
batch_size_eval: 64
|
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
warmup_steps: 5000
|
warmup_steps: 5000
|
||||||
iters_per_epoch: 5000
|
iters_per_epoch: 5000
|
||||||
|
@ -11,6 +11,7 @@ model:
|
|||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
cc_sbu_align:
|
cc_sbu_align:
|
||||||
|
batch_size: 12
|
||||||
vis_processor:
|
vis_processor:
|
||||||
train:
|
train:
|
||||||
name: "blip2_image_train"
|
name: "blip2_image_train"
|
||||||
@ -30,8 +31,6 @@ run:
|
|||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 5
|
max_epoch: 5
|
||||||
iters_per_epoch: 200
|
iters_per_epoch: 200
|
||||||
batch_size_train: 12
|
|
||||||
batch_size_eval: 12
|
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
warmup_steps: 200
|
warmup_steps: 200
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user