diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index b5dd9d4..cd051ec 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -178,7 +178,6 @@ class MiniGPTBase(BaseModel):
         answers = [self.llama_tokenizer(a + self.end_sym,
                                         return_tensors="pt",
                                         add_special_tokens=False).to(self.device) for a in answers]
-
         cur_id = []
         cur_target = []
         for i in range(len(questions)):
@@ -226,8 +225,6 @@ class MiniGPTBase(BaseModel):
 
             conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
-
-
 
             cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
             regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
 
diff --git a/train.py b/train.py
index 072a078..0013079 100644
--- a/train.py
+++ b/train.py
@@ -12,6 +12,7 @@ import random
 import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
+import wandb
 
 import minigpt4.tasks as tasks
 from minigpt4.common.config import Config
@@ -30,7 +31,6 @@ from minigpt4.models import *
 from minigpt4.processors import *
 from minigpt4.runners import *
 from minigpt4.tasks import *
-import wandb
 
 
 def parse_args():
@@ -44,12 +44,10 @@ def parse_args():
         "in xxx=yyy format will be merged into config file (deprecate), "
         "change to --cfg-options instead.",
     )
-    parser.add_argument("--wandb_log", default=False)
-    parser.add_argument("--job_name",default="minigpt_v2",type=str)
+    parser.add_argument("--job_name", default="minigpt_v2",type=str)
 
     args = parser.parse_args()
 
-
     return args
 
 
@@ -80,16 +78,13 @@ def main():
     # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
     job_id = now()
     args = parse_args()
-
-    cfg = Config(parse_args())
+    cfg = Config(args)
 
     init_distributed_mode(cfg.run_cfg)
-
     setup_seeds(cfg)
 
     # set after init_distributed_mode() to only log on master.
     setup_logger()
-
     cfg.pretty_print()
 
     task = tasks.setup_task(cfg)
@@ -98,10 +93,9 @@ def main():
     if cfg.run_cfg.wandb_log:
         wandb.login()
-        wandb.init(project="minigptv2",name=args.job_name)
+        wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
         wandb.watch(model)
 
-
     runner = get_runner_class(cfg)(
         cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
     )
 
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
index c13d31f..bcc458e 100644
--- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -52,4 +52,7 @@ run:
   device: "cuda"
   world_size: 1
   dist_url: "env://"
-  distributed: True
\ No newline at end of file
+  distributed: True
+
+  wandb_log: True
+  job_name: minigpt4_llama2_pretrain
\ No newline at end of file
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
index 8c138ae..29b5358 100644
--- a/train_configs/minigpt4_llama2_stage2_finetune.yaml
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -46,4 +46,7 @@ run:
   device: "cuda"
   world_size: 1
   dist_url: "env://"
-  distributed: True
\ No newline at end of file
+  distributed: True
+
+  wandb_log: True
+  job_name: minigpt4_llama2_finetune
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index ce8bc87..bd9a451 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -52,4 +52,7 @@ run:
   device: "cuda"
   world_size: 1
   dist_url: "env://"
-  distributed: True
\ No newline at end of file
+  distributed: True
+
+  wandb_log: True
+  job_name: minigpt4_pretrain
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
index 531a3a0..89d1100 100644
--- a/train_configs/minigpt4_stage2_finetune.yaml
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -46,4 +46,7 @@ run:
   device: "cuda"
   world_size: 1
   dist_url: "env://"
-  distributed: True
\ No newline at end of file
+  distributed: True
+
+  wandb_log: True
+  job_name: minigpt4_finetune
\ No newline at end of file
diff --git a/train_configs/minigpt_v2_finetune.yaml b/train_configs/minigpt_v2_finetune.yaml
index 7bf6bbf..4039ea6 100644
--- a/train_configs/minigpt_v2_finetune.yaml
+++ b/train_configs/minigpt_v2_finetune.yaml
@@ -276,7 +276,6 @@ run:
   init_lr: 1e-5
   min_lr: 8e-5
   warmup_lr: 1e-6
-  wandb_log: True
 
   weight_decay: 0.05
   max_epoch: 50
@@ -296,4 +295,7 @@ run:
   device: "cuda"
   world_size: 1
   dist_url: "env://"
-  distributed: True
\ No newline at end of file
+  distributed: True
+
+  wandb_log: True
+  job_name: minigptv2_finetune
\ No newline at end of file
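
Note: a minimal sketch of the config-gated W&B setup this patch converges on (wandb_log and job_name read from the run config rather than argparse flags). The helper name maybe_init_wandb is illustrative and not part of the patch, and the .get()-with-default accessor assumes an OmegaConf-style run config; cfg.run_cfg.wandb_log as written in train.py will raise on configs that omit the key entirely.

import wandb

def maybe_init_wandb(run_cfg, model):
    # Skip logging when the config omits or disables wandb_log;
    # .get() with a default is the assumed OmegaConf-style accessor.
    if run_cfg.get("wandb_log", False):
        wandb.login()
        # Project and run name mirror what train.py now reads from cfg.run_cfg.
        run = wandb.init(project="minigptv", name=run_cfg.job_name)
        wandb.watch(model)  # track gradients and parameters
        return run
    return None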