mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 19:10:45 +00:00

commit d1527dd924: fix conflicts

.gitignore (vendored, 3 additions)

@@ -158,3 +158,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+
+.checkpoints/
+minigpt4/output/

arnold_before.sh (new file, 37 lines)

@@ -0,0 +1,37 @@
#!/bin/bash
# This script 1) installs dependencies and 2) aligns the internal cluster with standard practice

# Pip install
# export http_proxy=10.20.47.147:3128 https_proxy=10.20.47.147:3128 no_proxy=code.byted.org
pip3 install --upgrade pip
pip3 install -r requirements.txt
pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
pip3 install mmengine==0.7.3
# unset http_proxy && unset https_proxy && unset no_proxy

# # ----------------------------------------------------------------------------------------
# # setup environment variables
# # disable TF verbose logging
# TF_CPP_MIN_LOG_LEVEL=2
# # fix known issues for pytorch-1.5.1 according to
# # https://blog.exxactcorp.com/pytorch-1-5-1-bug-fix-release/
# MKL_THREADING_LAYER=GNU
# # set NCCL envs for distributed communication
# NCCL_IB_GID_INDEX=3
# NCCL_IB_DISABLE=0
# NCCL_DEBUG=INFO
# ARNOLD_FRAMEWORK=pytorch
# # get distributed training parameters
# METIS_WORKER_0_HOST=${METIS_WORKER_0_HOST:-"127.0.0.1"}
# NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
# ARNOLD_WORKER_GPU=${ARNOLD_WORKER_GPU:-$NV_GPUS}
# ARNOLD_WORKER_NUM=${ARNOLD_WORKER_NUM:-1}
# ARNOLD_ID=${ARNOLD_ID:-0}
# ARNOLD_PORT=${METIS_WORKER_0_PORT:-3343}

# export NNODES=$ARNOLD_WORKER_NUM
# export NODE_RANK=$ARNOLD_ID
# export MASTER_ADDR=$METIS_WORKER_0_HOST
# export MASTER_PORT=$ARNOLD_PORT
# export GPUS=$ARNOLD_WORKER_GPU

dist_train.sh (new file, 35 lines)

@@ -0,0 +1,35 @@
#!/bin/bash
set -x

NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
GPUS=${GPUS:-${ARNOLD_WORKER_GPU}}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-9909}
# ARNOLD_WORKER_0_PORT
# ARNOLD_WORKER_0_ADDR

# settings for torch log
export BYTED_TORCH_FX=O1
export BYTED_TORCH_BYTECCL=O1
export TOKENIZERS_PARALLELISM=false
export HADOOP_ROOT_LOGGER=error,console

# settings for DDP multi-node for lab.pytorch image >= 1.13
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_SOCKET_IFNAME=eth0
export NCCL_SHM_DISABLE=0

# start training
CONFIG=$1
torchrun --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --nproc_per_node=$GPUS \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    train.py \
    --cfg-path \
    $CONFIG \
    ${@:2}
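
Usage note (editorial, not part of the commit): with the defaults above, a single-node run only needs the config path, e.g. bash dist_train.sh train_configs/bindgpt4_stage1_pretrain.yaml with the stage-1 config added later in this commit; anything after the config path is forwarded to train.py via ${@:2}. For multi-node runs, NNODES, NODE_RANK, MASTER_ADDR and MASTER_PORT must be exported first, which is what the commented exports at the end of arnold_before.sh would do on the Arnold cluster.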

@@ -33,7 +33,7 @@ from imagebind.models.multimodal_preprocessors import (
     TextPreprocessor,
     ThermalPreprocessor,
 )
-from imagebind.models.multimodal_projectors import create_projectors, create_pre_projector
+from imagebind.models.multimodal_projectors import create_projectors

 from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

@@ -78,6 +78,7 @@ class ImageBindJoiner(nn.Module):
     ):
         vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
                                                 freeze_qformer=vision_qformer_frozen,
+                                                encoder_width=1280,  # TODO: fix hard-coding
                                                 q_former_model=vision_qformer_model)
         modality_qformers = {
             ModalityType.VISION: vision_qformer
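
The hard-coded encoder_width=1280 presumably matches the hidden width of ImageBind's ViT-H vision trunk, whose features the vision Q-Former cross-attends to; the TODO indicates the value should eventually be read from the encoder configuration rather than fixed here.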

minigpt4/configs/datasets/cc12m/defaults.yaml (new file, 5 lines)

@@ -0,0 +1,5 @@
datasets:
  cc12m:
    data_type: images
    build_info:
      storage: /mnt/bn/zhicheng-dev-v6/dataset/cc12m_web/{000000..002221}.tar
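
The {000000..002221} suffix in storage is webdataset-style brace notation for a list of tar shards rather than a literal path. A minimal sketch (assuming the braceexpand package that webdataset itself depends on; the shortened three-shard range is illustrative):

from braceexpand import braceexpand

# Shortened to three shards for illustration; the real pattern covers 000000..002221.
pattern = "/mnt/bn/zhicheng-dev-v6/dataset/cc12m_web/{000000..000002}.tar"
print(list(braceexpand(pattern)))
# ['.../000000.tar', '.../000001.tar', '.../000002.tar']
# webdataset expands the same notation internally when the pattern is handed to
# wds.ResampledShards(), as the CCDataset added further down in this commit does.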

@@ -2,4 +2,4 @@ datasets:
   cc_sbu_align:
     data_type: images
     build_info:
-      storage: /path/to/cc_sbu_align/
+      storage: /mnt/bn/bykang/chixma/data/fromMiniGPT4/cc_sbu_align

@@ -5,11 +5,12 @@ model:
   freeze_imagebind: True

   # Q-Former
-  freeze_qformer: False
+  freeze_qformer: True
+  q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
   num_query_token: 32

   # Vicuna
-  llama_model: "/path/to/vicuna/weights/"
+  llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-7b-v0/"

   # generation configs
   prompt: ""
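
Judging by its filename, the newly referenced q_former_model checkpoint is the BLIP-2 Q-Former pretrained alongside FlanT5-XXL, so the now-frozen Q-Former starts from those weights rather than random initialisation; the Vicuna placeholder path is likewise replaced with a concrete local 7B checkpoint.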

@@ -10,10 +10,11 @@ model:
   freeze_qformer: True

   # Q-Former
+  q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
   num_query_token: 32

   # Vicuna
-  llama_model: "/path/to/vicuna/weights/"
+  llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/"

   # generation configs
   prompt: ""

@@ -5,7 +5,8 @@ import warnings
 from minigpt4.common.registry import registry
 from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
 from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset
-from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset
+from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, \
+    CCSBUAlignDatasetImageImageCaptionDataset, CCDataset


 @registry.register_builder("cc_sbu")

@@ -103,3 +104,35 @@ class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
         )

         return datasets
+
+
+@registry.register_builder("cc12m")
+class CC12MBuilder(ImageBaseDatasetBuilder):
+    train_dataset_cls = CCDataset
+
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc12m/defaults.yaml"}
+
+    def _download_ann(self):
+        pass
+
+    def _download_vis(self):
+        pass
+
+    def build(self):
+        self.build_processors()
+
+        build_info = self.config.build_info
+
+        datasets = dict()
+        split = "train"
+
+        # create datasets
+        # [NOTE] return inner_datasets (wds.DataPipeline)
+        dataset_cls = self.train_dataset_cls
+        datasets[split] = dataset_cls(
+            vis_processor=self.vis_processors[split],
+            text_processor=self.text_processors[split],
+            location=build_info.storage,
+        ).inner_dataset
+
+        return datasets

@@ -45,3 +45,24 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
             "text_input": caption,
             "image_id": self.img_ids[ann["image_id"]],
         }
+
+
+class CCDataset(BaseDataset):
+    def __init__(self, vis_processor, text_processor, location):
+        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+        self.inner_dataset = wds.DataPipeline(
+            wds.ResampledShards(location),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(1000, handler=wds.warn_and_continue),
+            wds.decode("pilrgb", handler=wds.warn_and_continue),
+            wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
+            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map(self.to_dict, handler=wds.warn_and_continue),
+        )
+
+    def to_dict(self, sample):
+        return {
+            "image": sample[0],
+            "text_input": sample[1],
+        }
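
For readers unfamiliar with webdataset, here is a small self-contained sketch of the same pipeline pattern CCDataset sets up; the shard path and the torchvision transform are placeholders, not values from the repository.

import webdataset as wds
from torch.utils.data import DataLoader
from torchvision import transforms

# Placeholder image transform standing in for the configured vis_processor.
vis_processor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

pipeline = wds.DataPipeline(
    wds.ResampledShards("/data/cc12m/{000000..000009}.tar"),      # hypothetical shards
    wds.tarfile_to_samples(handler=wds.warn_and_continue),        # tar members -> samples
    wds.shuffle(1000, handler=wds.warn_and_continue),             # buffer-based shuffling
    wds.decode("pilrgb", handler=wds.warn_and_continue),          # bytes -> PIL RGB images
    wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),    # keep (image, caption)
    wds.map_tuple(vis_processor, handler=wds.warn_and_continue),  # transform the image only
    wds.map(lambda s: {"image": s[0], "text_input": s[1]}),       # dict form the model expects
)

# DataPipeline is an IterableDataset with no __len__, so the epoch length has to be
# capped elsewhere (the stage-1 config below does this with iters_per_epoch).
loader = DataLoader(pipeline, batch_size=4, num_workers=2)
batch = next(iter(loader))
print(batch["image"].shape, len(batch["text_input"]))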

@@ -2,6 +2,7 @@ import random
 from typing import Dict, Tuple

 import torch
+import torch.nn as nn
 from torch import Tensor
 from transformers import LlamaTokenizer

@@ -65,6 +66,9 @@ class BindGPT4(BaseModel):
             param.requires_grad = False
         print('Loading LLAMA Done')

+        # TODO: remove hard-coding
+        self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
+
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym

@@ -82,6 +86,8 @@ class BindGPT4(BaseModel):
     def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
         imagebind_outputs = self.multimodal_encoder(inputs)
         llama_inputs = self.multimodal_joiner(imagebind_outputs)
+        # NOTE: only accept image here
+        llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
         return llama_inputs

     def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:

@@ -109,9 +115,13 @@ class BindGPT4(BaseModel):
         Only accept image inputs here.
         Other modalities will conflict with the pre-defined prompt and wrapping strategy.
         """
-        embeds = self.encode_inputs(inputs)
+        bind_inputs = {ModalityType.VISION: inputs['image']}
+        embeds = self.encode_inputs(bind_inputs)
         # assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
-        prompt = random.choice(self.prompt_list)
+        if self.prompt_list:
+            prompt = random.choice(self.prompt_list)
+        else:
+            prompt = None
         img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)

         # NOTE: No modifications from the next line to the end. Except for the autocast part.
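
A rough, standalone sketch (not code from the repository) of the vision path these hunks describe: ImageBind features are compressed by the Q-Former into query tokens of width 768, and self.llama_proj maps them into the LLaMA embedding space. Shapes are illustrative; 4096 is assumed as the hidden size of a 7B Vicuna/LLaMA model.

import torch
import torch.nn as nn

batch, num_query_token, qformer_dim, llama_hidden = 2, 32, 768, 4096

# Stand-in for the output of multimodal_encoder + multimodal_joiner:
# one sequence of Q-Former query embeddings per image.
vision_queries = torch.randn(batch, num_query_token, qformer_dim)

llama_proj = nn.Linear(qformer_dim, llama_hidden)   # mirrors self.llama_proj above
img_embeds = llama_proj(vision_queries)             # shape (2, 32, 4096)

# prompt_wrap then splices these projected tokens into the prompt's text
# embeddings before they are fed to the frozen LLaMA model.
print(img_embeds.shape)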

@@ -28,7 +28,8 @@ from transformers import BertTokenizer
 class Blip2Base(BaseModel):
     @classmethod
     def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # NOTE: network issue
+        tokenizer = BertTokenizer.from_pretrained("/mnt/bn/bykang/chixma/data/pretrained_models/bert-base-uncased")
         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
         return tokenizer

@@ -44,7 +45,8 @@ class Blip2Base(BaseModel):

     @classmethod
     def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        # encoder_config = BertConfig.from_pretrained("bert-base-uncased")  # NOTE: network issue
+        encoder_config = BertConfig.from_pretrained("/mnt/bn/bykang/chixma/data/pretrained_models/bert-base-uncased")
         encoder_config.encoder_width = vision_width
         # insert cross-attention layer every other block
         encoder_config.add_cross_attention = True
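
Both hunks swap the Hugging Face Hub identifier for a local directory so that training nodes without outbound network access can still load BERT. A minimal sketch of preparing and reusing such a local copy, assuming standard transformers behaviour (the one-off download step is an editorial assumption, not part of the commit):

from transformers import BertConfig, BertTokenizer

LOCAL_DIR = "/mnt/bn/bykang/chixma/data/pretrained_models/bert-base-uncased"  # path used in the diff

# One-off, on a machine that does have network access:
#   BertTokenizer.from_pretrained("bert-base-uncased").save_pretrained(LOCAL_DIR)
#   BertConfig.from_pretrained("bert-base-uncased").save_pretrained(LOCAL_DIR)

# Afterwards, loading resolves entirely from disk with no Hub traffic.
tokenizer = BertTokenizer.from_pretrained(LOCAL_DIR)
config = BertConfig.from_pretrained(LOCAL_DIR)
tokenizer.add_special_tokens({"bos_token": "[DEC]"})  # same special token the code adds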

@@ -6,11 +6,11 @@
 """

 from minigpt4.processors.base_processor import BaseProcessor
-# from minigpt4.processors.blip_processors import (
-#     Blip2ImageTrainProcessor,
-#     Blip2ImageEvalProcessor,
-#     BlipCaptionProcessor,
-# )
+from minigpt4.processors.blip_processors import (
+    Blip2ImageTrainProcessor,
+    Blip2ImageEvalProcessor,
+    BlipCaptionProcessor,
+)
 from minigpt4.processors.imagebind_processor import (
     ImageBindCaptionProcessor,
     ImageBindVisionTrainProcessor,

@@ -21,9 +21,9 @@ from minigpt4.common.registry import registry

 __all__ = [
     "BaseProcessor",
-    # "Blip2ImageTrainProcessor",
-    # "Blip2ImageEvalProcessor",
-    # "BlipCaptionProcessor",
+    "Blip2ImageTrainProcessor",
+    "Blip2ImageEvalProcessor",
+    "BlipCaptionProcessor",
     "ImageBindCaptionProcessor",
     "ImageBindVisionTrainProcessor",
     "ImageBindVisionEvalProcessor"

@@ -88,7 +88,7 @@ class RunnerBase:
         if self.use_distributed:
             if self._wrapped_model is None:
                 self._wrapped_model = DDP(
-                    self._model, device_ids=[self.config.run_cfg.gpu]
+                    self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
                 )
             else:
                 self._wrapped_model = self._model
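
find_unused_parameters=True tells DDP's gradient reducer to tolerate parameters that require gradients but do not take part in producing the loss on a given step (for example, branches of this multimodal model that a batch never exercises), at the cost of an extra traversal of the autograd graph per iteration. A minimal standalone sketch, with the process-group setup assumed rather than taken from the repository:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model: torch.nn.Module, gpu: int) -> DDP:
    # Mirrors the change above: allow parameters that receive no gradient in a
    # given forward/backward pass without triggering a reducer error.
    assert dist.is_initialized(), "call dist.init_process_group() before wrapping"
    return DDP(model.to(gpu), device_ids=[gpu], find_unused_parameters=True)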

requirements.txt (new file, 68 lines)

@@ -0,0 +1,68 @@
-i https://bytedpypi.byted.org/simple/

accelerate==0.16.0
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
bitsandbytes==0.37.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
regex==2022.10.31
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.28.0
timm==0.6.13
spacy==3.5.1
webdataset==0.2.48
scikit-learn==1.2.2
scipy==1.10.1
yarl==1.8.2
zipp==3.14.0
omegaconf==2.3.0
opencv-python==4.7.0.72
iopath==0.1.10
decord==0.6.0
tenacity==8.2.2
peft
pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb
ipdb
tensorflow-cpu
tensorboardX
# mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html

bs4==0.0.1  # Needed for text cleaning
bson==0.5.10
byted-dataloader==0.3.6
diffusers[torch]==0.16.1
einops==0.6.0
ftfy==6.1.1  # Needed for text cleaning
lpips==0.1.4
sentencepiece==0.1.99  # Needed for T5 tokenizer
tensorboard==2.11.2
tensorflow==2.11.0  # Needed for tensorboard hdfs support
tensorflow-io==0.30.0  # Needed for tensorboard hdfs support
tqdm==4.64.1

train_configs/bindgpt4_stage1_pretrain.yaml (new file, 57 lines)

@@ -0,0 +1,57 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: False


datasets:
  cc12m:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 115
  # cc_sbu:
  #   vis_processor:
  #     train:
  #       name: "blip2_image_train"
  #       image_size: 224
  #   text_processor:
  #     train:
  #       name: "blip_caption"
  #   sample_ratio: 14


run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/minigpt4_stage1_pretrain"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

@@ -15,15 +15,15 @@ datasets:
       train:
         name: "blip_caption"
     sample_ratio: 115
-  cc_sbu:
-    vis_processor:
-      train:
-        name: "blip2_image_train"
-        image_size: 224
-    text_processor:
-      train:
-        name: "blip_caption"
-    sample_ratio: 14
+  # cc_sbu:
+  #   vis_processor:
+  #     train:
+  #       name: "blip2_image_train"
+  #       image_size: 224
+  #   text_processor:
+  #     train:
+  #       name: "blip_caption"
+  #   sample_ratio: 14


 run: