update env name to minigptv, add bug report and feature request back

This commit is contained in:
Deyao Zhu 2023-10-23 20:42:36 +03:00
parent b552d23fff
commit 8c2297750f
7 changed files with 24 additions and 19 deletions

View File

@ -178,7 +178,6 @@ class MiniGPTBase(BaseModel):
answers = [self.llama_tokenizer(a + self.end_sym,
return_tensors="pt",
add_special_tokens=False).to(self.device) for a in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):
@ -226,8 +225,6 @@ class MiniGPTBase(BaseModel):
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

View File

@ -12,6 +12,7 @@ import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
import minigpt4.tasks as tasks
from minigpt4.common.config import Config
@ -30,7 +31,6 @@ from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
import wandb
def parse_args():
@ -44,12 +44,10 @@ def parse_args():
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
parser.add_argument("--wandb_log", default=False)
parser.add_argument("--job_name",default="minigpt_v2",type=str)
parser.add_argument("--job_name", default="minigpt_v2",type=str)
args = parser.parse_args()
return args
@ -80,16 +78,13 @@ def main():
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
job_id = now()
args = parse_args()
cfg = Config(parse_args())
cfg = Config(args)
init_distributed_mode(cfg.run_cfg)
setup_seeds(cfg)
# set after init_distributed_mode() to only log on master.
setup_logger()
cfg.pretty_print()
task = tasks.setup_task(cfg)
@ -98,10 +93,9 @@ def main():
if cfg.run_cfg.wandb_log:
wandb.login()
wandb.init(project="minigptv2",name=args.job_name)
wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
wandb.watch(model)
runner = get_runner_class(cfg)(
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
)

View File

@ -52,4 +52,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
distributed: True
wandb_log: True
job_name: minigpt4_llama2_pretrain

View File

@ -46,4 +46,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
distributed: True
wandb_log: True
job_name: minigpt4_llama2_finetune

View File

@ -52,4 +52,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
distributed: True
wandb_log: True
job_name: minigpt4_pretrain

View File

@ -46,4 +46,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
distributed: True
wandb_log: True
job_name: minigpt4_finetune

View File

@ -276,7 +276,6 @@ run:
init_lr: 1e-5
min_lr: 8e-5
warmup_lr: 1e-6
wandb_log: True
weight_decay: 0.05
max_epoch: 50
@ -296,4 +295,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
distributed: True
wandb_log: True
job_name: minigptv2_finetune