mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00
update env name to minigptv, add bug report and feature request back
This commit is contained in:
parent
b552d23fff
commit
8c2297750f
@ -178,7 +178,6 @@ class MiniGPTBase(BaseModel):
|
|||||||
answers = [self.llama_tokenizer(a + self.end_sym,
|
answers = [self.llama_tokenizer(a + self.end_sym,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_special_tokens=False).to(self.device) for a in answers]
|
add_special_tokens=False).to(self.device) for a in answers]
|
||||||
|
|
||||||
cur_id = []
|
cur_id = []
|
||||||
cur_target = []
|
cur_target = []
|
||||||
for i in range(len(questions)):
|
for i in range(len(questions)):
|
||||||
@ -226,8 +225,6 @@ class MiniGPTBase(BaseModel):
|
|||||||
|
|
||||||
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
|
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
|
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
|
||||||
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
|
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
|
||||||
|
|
||||||
|
14
train.py
14
train.py
@ -12,6 +12,7 @@ import random
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
|
import wandb
|
||||||
|
|
||||||
import minigpt4.tasks as tasks
|
import minigpt4.tasks as tasks
|
||||||
from minigpt4.common.config import Config
|
from minigpt4.common.config import Config
|
||||||
@ -30,7 +31,6 @@ from minigpt4.models import *
|
|||||||
from minigpt4.processors import *
|
from minigpt4.processors import *
|
||||||
from minigpt4.runners import *
|
from minigpt4.runners import *
|
||||||
from minigpt4.tasks import *
|
from minigpt4.tasks import *
|
||||||
import wandb
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -44,12 +44,10 @@ def parse_args():
|
|||||||
"in xxx=yyy format will be merged into config file (deprecate), "
|
"in xxx=yyy format will be merged into config file (deprecate), "
|
||||||
"change to --cfg-options instead.",
|
"change to --cfg-options instead.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--wandb_log", default=False)
|
parser.add_argument("--job_name", default="minigpt_v2",type=str)
|
||||||
parser.add_argument("--job_name",default="minigpt_v2",type=str)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@ -80,16 +78,13 @@ def main():
|
|||||||
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
|
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
|
||||||
job_id = now()
|
job_id = now()
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
cfg = Config(args)
|
||||||
cfg = Config(parse_args())
|
|
||||||
|
|
||||||
init_distributed_mode(cfg.run_cfg)
|
init_distributed_mode(cfg.run_cfg)
|
||||||
|
|
||||||
setup_seeds(cfg)
|
setup_seeds(cfg)
|
||||||
|
|
||||||
# set after init_distributed_mode() to only log on master.
|
# set after init_distributed_mode() to only log on master.
|
||||||
setup_logger()
|
setup_logger()
|
||||||
|
|
||||||
cfg.pretty_print()
|
cfg.pretty_print()
|
||||||
|
|
||||||
task = tasks.setup_task(cfg)
|
task = tasks.setup_task(cfg)
|
||||||
@ -98,10 +93,9 @@ def main():
|
|||||||
|
|
||||||
if cfg.run_cfg.wandb_log:
|
if cfg.run_cfg.wandb_log:
|
||||||
wandb.login()
|
wandb.login()
|
||||||
wandb.init(project="minigptv2",name=args.job_name)
|
wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
|
||||||
wandb.watch(model)
|
wandb.watch(model)
|
||||||
|
|
||||||
|
|
||||||
runner = get_runner_class(cfg)(
|
runner = get_runner_class(cfg)(
|
||||||
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
|
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
|
||||||
)
|
)
|
||||||
|
@ -53,3 +53,6 @@ run:
|
|||||||
world_size: 1
|
world_size: 1
|
||||||
dist_url: "env://"
|
dist_url: "env://"
|
||||||
distributed: True
|
distributed: True
|
||||||
|
|
||||||
|
wandb_log: True
|
||||||
|
job_name: minigpt4_llama2_pretrain
|
@ -47,3 +47,6 @@ run:
|
|||||||
world_size: 1
|
world_size: 1
|
||||||
dist_url: "env://"
|
dist_url: "env://"
|
||||||
distributed: True
|
distributed: True
|
||||||
|
|
||||||
|
wandb_log: True
|
||||||
|
job_name: minigpt4_llama2_finetune
|
@ -53,3 +53,6 @@ run:
|
|||||||
world_size: 1
|
world_size: 1
|
||||||
dist_url: "env://"
|
dist_url: "env://"
|
||||||
distributed: True
|
distributed: True
|
||||||
|
|
||||||
|
wandb_log: True
|
||||||
|
job_name: minigpt4_pretrain
|
@ -47,3 +47,6 @@ run:
|
|||||||
world_size: 1
|
world_size: 1
|
||||||
dist_url: "env://"
|
dist_url: "env://"
|
||||||
distributed: True
|
distributed: True
|
||||||
|
|
||||||
|
wandb_log: True
|
||||||
|
job_name: minigpt4_finetune
|
@ -276,7 +276,6 @@ run:
|
|||||||
init_lr: 1e-5
|
init_lr: 1e-5
|
||||||
min_lr: 8e-5
|
min_lr: 8e-5
|
||||||
warmup_lr: 1e-6
|
warmup_lr: 1e-6
|
||||||
wandb_log: True
|
|
||||||
|
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
max_epoch: 50
|
max_epoch: 50
|
||||||
@ -297,3 +296,6 @@ run:
|
|||||||
world_size: 1
|
world_size: 1
|
||||||
dist_url: "env://"
|
dist_url: "env://"
|
||||||
distributed: True
|
distributed: True
|
||||||
|
|
||||||
|
wandb_log: True
|
||||||
|
job_name: minigptv2_finetune
|
Loading…
Reference in New Issue
Block a user