mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 18:10:47 +00:00
Merge pull request #390 from TsuTikgiau/main
update the code for batchsize setting
This commit is contained in:
commit
10f61a4dd8
@ -96,7 +96,7 @@ def prepare_sample(samples, cuda_enabled=True):
|
||||
return samples
|
||||
|
||||
|
||||
def reorg_datasets_by_split(datasets):
|
||||
def reorg_datasets_by_split(datasets, batch_sizes):
|
||||
"""
|
||||
Organizes datasets by split.
|
||||
|
||||
@ -110,16 +110,19 @@ def reorg_datasets_by_split(datasets):
|
||||
# return datasets[list(datasets.keys())[0]]
|
||||
# else:
|
||||
reorg_datasets = dict()
|
||||
reorg_batch_sizes = dict()
|
||||
|
||||
# reorganize by split
|
||||
for _, dataset in datasets.items():
|
||||
for dataset_name, dataset in datasets.items():
|
||||
for split_name, dataset_split in dataset.items():
|
||||
if split_name not in reorg_datasets:
|
||||
reorg_datasets[split_name] = [dataset_split]
|
||||
reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
|
||||
else:
|
||||
reorg_datasets[split_name].append(dataset_split)
|
||||
reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])
|
||||
|
||||
return reorg_datasets
|
||||
return reorg_datasets, reorg_batch_sizes
|
||||
|
||||
|
||||
def concat_datasets(datasets):
|
||||
|
@ -88,7 +88,7 @@ class RunnerBase:
|
||||
if self.use_distributed:
|
||||
if self._wrapped_model is None:
|
||||
self._wrapped_model = DDP(
|
||||
self._model, device_ids=[self.config.run_cfg.gpu]
|
||||
self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
|
||||
)
|
||||
else:
|
||||
self._wrapped_model = self._model
|
||||
@ -206,7 +206,9 @@ class RunnerBase:
|
||||
"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
|
||||
)
|
||||
|
||||
datasets = reorg_datasets_by_split(self.datasets)
|
||||
batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
|
||||
for dataset_name in self.datasets.keys()}
|
||||
datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
|
||||
self.datasets = datasets
|
||||
# self.datasets = concat_datasets(datasets)
|
||||
|
||||
@ -247,14 +249,10 @@ class RunnerBase:
|
||||
split_names = sorted(self.datasets.keys())
|
||||
|
||||
datasets = [self.datasets[split] for split in split_names]
|
||||
batch_sizes = [batch_sizes[split] for split in split_names]
|
||||
is_trains = [split in self.train_splits for split in split_names]
|
||||
|
||||
batch_sizes = [
|
||||
self.config.run_cfg.batch_size_train
|
||||
if split == "train"
|
||||
else self.config.run_cfg.batch_size_eval
|
||||
for split in split_names
|
||||
]
|
||||
print("batch sizes", batch_sizes)
|
||||
|
||||
collate_fns = []
|
||||
for dataset in datasets:
|
||||
@ -349,6 +347,7 @@ class RunnerBase:
|
||||
lib_root = Path(registry.get_path("library_root"))
|
||||
|
||||
output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
|
||||
# output_dir = lib_root / self.config.run_cfg.output_dir
|
||||
result_dir = output_dir / "result"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -519,6 +518,7 @@ class RunnerBase:
|
||||
else:
|
||||
# map-style datasets are concatenated together
|
||||
# setup distributed sampler
|
||||
|
||||
if self.use_distributed:
|
||||
sampler = DistributedSampler(
|
||||
dataset,
|
||||
@ -559,7 +559,7 @@ class RunnerBase:
|
||||
dataset_ratios = [d.sample_ratio for d in dataset]
|
||||
loader = MultiIterLoader(
|
||||
loaders=[
|
||||
_create_loader(d, num_workers, bsz, is_train, collate_fn[i])
|
||||
_create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
|
||||
for i, d in enumerate(dataset)
|
||||
],
|
||||
ratios=dataset_ratios,
|
||||
@ -634,13 +634,14 @@ class RunnerBase:
|
||||
raise RuntimeError("checkpoint url or path is invalid")
|
||||
|
||||
state_dict = checkpoint["model"]
|
||||
self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
|
||||
message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
|
||||
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
if self.scaler and "scaler" in checkpoint:
|
||||
self.scaler.load_state_dict(checkpoint["scaler"])
|
||||
|
||||
self.start_epoch = checkpoint["epoch"] + 1
|
||||
print("resume the checkpoint")
|
||||
logging.info("Resume checkpoint from {}".format(url_or_filename))
|
||||
|
||||
@main_process
|
||||
|
@ -21,12 +21,14 @@ class BaseTask:
|
||||
super().__init__()
|
||||
|
||||
self.inst_id_key = "instance_id"
|
||||
self.cfg = ""
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, **kwargs):
|
||||
return cls()
|
||||
|
||||
def build_model(self, cfg):
|
||||
self.cfg = cfg
|
||||
model_config = cfg.model_cfg
|
||||
|
||||
model_cls = registry.get_model_class(model_config.arch)
|
||||
|
@ -5,6 +5,7 @@ model:
|
||||
|
||||
datasets:
|
||||
laion:
|
||||
batch_size: 64
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -14,6 +15,7 @@ datasets:
|
||||
name: "blip_caption"
|
||||
sample_ratio: 115
|
||||
cc_sbu:
|
||||
batch_size: 64
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -34,8 +36,6 @@ run:
|
||||
|
||||
weight_decay: 0.05
|
||||
max_epoch: 4
|
||||
batch_size_train: 64
|
||||
batch_size_eval: 64
|
||||
num_workers: 4
|
||||
warmup_steps: 5000
|
||||
iters_per_epoch: 5000
|
||||
|
@ -11,6 +11,7 @@ model:
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
batch_size: 12
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -30,8 +31,6 @@ run:
|
||||
weight_decay: 0.05
|
||||
max_epoch: 5
|
||||
iters_per_epoch: 200
|
||||
batch_size_train: 12
|
||||
batch_size_eval: 12
|
||||
num_workers: 4
|
||||
warmup_steps: 200
|
||||
|
||||
|
@ -5,6 +5,7 @@ model:
|
||||
|
||||
datasets:
|
||||
laion:
|
||||
batch_size: 64
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -14,6 +15,7 @@ datasets:
|
||||
name: "blip_caption"
|
||||
sample_ratio: 115
|
||||
cc_sbu:
|
||||
batch_size: 64
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -34,8 +36,6 @@ run:
|
||||
|
||||
weight_decay: 0.05
|
||||
max_epoch: 4
|
||||
batch_size_train: 64
|
||||
batch_size_eval: 64
|
||||
num_workers: 4
|
||||
warmup_steps: 5000
|
||||
iters_per_epoch: 5000
|
||||
|
@ -11,6 +11,7 @@ model:
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
batch_size: 12
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
@ -30,8 +31,6 @@ run:
|
||||
weight_decay: 0.05
|
||||
max_epoch: 5
|
||||
iters_per_epoch: 200
|
||||
batch_size_train: 12
|
||||
batch_size_eval: 12
|
||||
num_workers: 4
|
||||
warmup_steps: 200
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user