From b59616dfaecade392478237d5fa9fdc9debab99d Mon Sep 17 00:00:00 2001
From: bingyikang <bingyikang@bytedance.com>
Date: Mon, 22 May 2023 18:38:05 +0800
Subject: [PATCH] add training config for bindgpt4

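Add a stage-1 pretraining config for BindGPT4 (train_configs/bindgpt4_stage1_pretrain.yaml)
and a CC12M webdataset config, plus helper scripts for installing dependencies
(arnold_before.sh) and launching distributed training (dist_train.sh), along with
small fixes to the model, dataset builders, and processors.

Typical usage (a sketch only; node/GPU counts and the config path are example
values, adjust them to your setup):

    bash arnold_before.sh
    NNODES=1 NODE_RANK=0 GPUS=8 bash dist_train.sh \
        train_configs/bindgpt4_stage1_pretrain.yaml

dist_train.sh passes the first argument to train.py as --cfg-path and forwards
any remaining arguments unchanged.
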
---
 .gitignore                                    |  5 +-
 arnold_before.sh                              | 37 ++++++++++
 dist_train.sh                                 | 35 ++++++++++
 imagebind/models/image_bind.py                |  3 +-
 minigpt4/configs/datasets/cc12m/defaults.yaml |  5 ++
 minigpt4/configs/datasets/cc_sbu/align.yaml   |  2 +-
 minigpt4/configs/models/bindgpt4.yaml         |  3 +-
 minigpt4/configs/models/minigpt4.yaml         |  3 +-
 .../builders/image_text_pair_builder.py       | 33 ++++++++-
 minigpt4/datasets/datasets/cc_sbu_dataset.py  | 23 ++++++-
 minigpt4/models/bind_gpt4.py                  | 14 +++-
 minigpt4/models/blip2.py                      |  6 +-
 minigpt4/processors/__init__.py               | 16 ++---
 minigpt4/runners/runner_base.py               |  2 +-
 requirements.txt                              | 68 +++++++++++++++++++
 train_configs/bindgpt4_stage1_pretrain.yaml   | 57 ++++++++++++++++
 train_configs/minigpt4_stage1_pretrain.yaml   | 18 ++---
 17 files changed, 301 insertions(+), 29 deletions(-)
 create mode 100644 arnold_before.sh
 create mode 100644 dist_train.sh
 create mode 100644 minigpt4/configs/datasets/cc12m/defaults.yaml
 create mode 100644 requirements.txt
 create mode 100644 train_configs/bindgpt4_stage1_pretrain.yaml

diff --git a/.gitignore b/.gitignore
index b0b6f3a..b049b6b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,4 +157,7 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
\ No newline at end of file
+.idea/
+
+.checkpoints/
+minigpt4/output/
diff --git a/arnold_before.sh b/arnold_before.sh
new file mode 100644
index 0000000..844651a
--- /dev/null
+++ b/arnold_before.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# This script 1) installs dependencies and 2) aligns the internal cluster setup with standard practice
+
+# Pip install 
+# export http_proxy=10.20.47.147:3128  https_proxy=10.20.47.147:3128 no_proxy=code.byted.org
+pip3 install --upgrade pip
+pip3 install -r requirements.txt
+pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
+pip3 install mmengine==0.7.3
+# unset http_proxy && unset https_proxy && unset no_proxy
+
+# # ----------------------------------------------------------------------------------------
+# # setup environment variables
+# # disable TF verbose logging
+# TF_CPP_MIN_LOG_LEVEL=2
+# # fix known issues for pytorch-1.5.1 according to
+# # https://blog.exxactcorp.com/pytorch-1-5-1-bug-fix-release/
+# MKL_THREADING_LAYER=GNU
+# # set NCCL envs for distributed communication
+# NCCL_IB_GID_INDEX=3
+# NCCL_IB_DISABLE=0
+# NCCL_DEBUG=INFO
+# ARNOLD_FRAMEWORK=pytorch
+# # get distributed training parameters 
+# METIS_WORKER_0_HOST=${METIS_WORKER_0_HOST:-"127.0.0.1"}
+# NV_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+# ARNOLD_WORKER_GPU=${ARNOLD_WORKER_GPU:-$NV_GPUS}
+# ARNOLD_WORKER_NUM=${ARNOLD_WORKER_NUM:-1}
+# ARNOLD_ID=${ARNOLD_ID:-0}
+# ARNOLD_PORT=${METIS_WORKER_0_PORT:-3343}
+
+
+# export NNODES=$ARNOLD_WORKER_NUM
+# export NODE_RANK=$ARNOLD_ID
+# export MASTER_ADDR=$METIS_WORKER_0_HOST
+# export MASTER_PORT=$ARNOLD_PORT
+# export GPUS=$ARNOLD_WORKER_GPU
diff --git a/dist_train.sh b/dist_train.sh
new file mode 100644
index 0000000..c9deccc
--- /dev/null
+++ b/dist_train.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+set -x
+
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+GPUS=${GPUS:-${ARNOLD_WORKER_GPU}}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+MASTER_PORT=${MASTER_PORT:-9909}
+# ARNOLD_WORKER_0_PORT
+# ARNOLD_WORKER_0_ADDR
+
+# settings for torch log
+export BYTED_TORCH_FX=O1
+export BYTED_TORCH_BYTECCL=O1
+export TOKENIZERS_PARALLELISM=false
+export HADOOP_ROOT_LOGGER=error,console
+
+# settings for DDP multi-node for lab.pytorch image >= 1.13
+export OMP_NUM_THREADS=8
+export NCCL_IB_DISABLE=0
+export NCCL_IB_GID_INDEX=3
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_SHM_DISABLE=0
+
+# start training
+CONFIG=$1
+torchrun --nnodes=$NNODES \
+         --node_rank=$NODE_RANK \
+         --nproc_per_node=$GPUS \
+         --master_addr=$MASTER_ADDR \
+         --master_port=$MASTER_PORT \
+         train.py \
+         --cfg-path \
+         $CONFIG \
+         ${@:2}
diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py
index cbe262e..c93fcd0 100644
--- a/imagebind/models/image_bind.py
+++ b/imagebind/models/image_bind.py
@@ -33,7 +33,7 @@ from imagebind.models.multimodal_preprocessors import (
     TextPreprocessor,
     ThermalPreprocessor,
 )
-from imagebind.models.multimodal_projectors import create_projectors, create_pre_projector
+from imagebind.models.multimodal_projectors import create_projectors
 
 from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
 
@@ -78,6 +78,7 @@ class ImageBindJoiner(nn.Module):
                                   ):
         vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
                                                 freeze_qformer=vision_qformer_frozen,
+                                                encoder_width=1280,  # TODO: fix hard-coding
                                                 q_former_model=vision_qformer_model)
         modality_qformers = {
             ModalityType.VISION: vision_qformer
diff --git a/minigpt4/configs/datasets/cc12m/defaults.yaml b/minigpt4/configs/datasets/cc12m/defaults.yaml
new file mode 100644
index 0000000..2f59614
--- /dev/null
+++ b/minigpt4/configs/datasets/cc12m/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+  cc12m:
+    data_type: images
+    build_info:
+      storage: /mnt/bn/zhicheng-dev-v6/dataset/cc12m_web/{000000..002221}.tar
diff --git a/minigpt4/configs/datasets/cc_sbu/align.yaml b/minigpt4/configs/datasets/cc_sbu/align.yaml
index 5710834..f30de07 100644
--- a/minigpt4/configs/datasets/cc_sbu/align.yaml
+++ b/minigpt4/configs/datasets/cc_sbu/align.yaml
@@ -2,4 +2,4 @@ datasets:
   cc_sbu_align:
     data_type: images
     build_info:
-      storage: /path/to/cc_sbu_align/
+      storage: /mnt/bn/bykang/chixma/data/fromMiniGPT4/cc_sbu_align
diff --git a/minigpt4/configs/models/bindgpt4.yaml b/minigpt4/configs/models/bindgpt4.yaml
index 3d436ec..8cdfe40 100644
--- a/minigpt4/configs/models/bindgpt4.yaml
+++ b/minigpt4/configs/models/bindgpt4.yaml
@@ -6,10 +6,11 @@ model:
 
   # Q-Former
   freeze_qformer: True
+  q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
   num_query_token: 32
 
   # Vicuna
-  llama_model: "/path/to/vicuna/weights/"
+  llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-7b-v0/"
 
   # generation configs
   prompt: ""
diff --git a/minigpt4/configs/models/minigpt4.yaml b/minigpt4/configs/models/minigpt4.yaml
index 87af448..9e7f8e9 100644
--- a/minigpt4/configs/models/minigpt4.yaml
+++ b/minigpt4/configs/models/minigpt4.yaml
@@ -10,10 +10,11 @@ model:
   freeze_qformer: True
 
   # Q-Former
+  q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
   num_query_token: 32
 
   # Vicuna
-  llama_model: "/path/to/vicuna/weights/"
+  llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/"
 
   # generation configs
   prompt: ""
diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py
index e5d66b8..a6b3253 100644
--- a/minigpt4/datasets/builders/image_text_pair_builder.py
+++ b/minigpt4/datasets/builders/image_text_pair_builder.py
@@ -5,7 +5,7 @@ import warnings
 from minigpt4.common.registry import registry
 from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
 from minigpt4.datasets.datasets.laion_dataset import LaionDataset
-from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset, CCDataset
 
 
 @registry.register_builder("cc_sbu")
@@ -103,3 +103,34 @@ class CCSBUAlignBuilder(BaseDatasetBuilder):
         )
 
         return datasets
+
+@registry.register_builder("cc12m")
+class CC12MBuilder(BaseDatasetBuilder):
+    train_dataset_cls = CCDataset
+
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc12m/defaults.yaml"}
+
+    def _download_ann(self):
+        pass
+
+    def _download_vis(self):
+        pass
+
+    def build(self):
+        self.build_processors()
+
+        build_info = self.config.build_info
+
+        datasets = dict()
+        split = "train"
+
+        # create datasets
+        # [NOTE] return the inner_dataset (a wds.DataPipeline)
+        dataset_cls = self.train_dataset_cls
+        datasets[split] = dataset_cls(
+            vis_processor=self.vis_processors[split],
+            text_processor=self.text_processors[split],
+            location=build_info.storage,
+        ).inner_dataset
+
+        return datasets
diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/cc_sbu_dataset.py
index f42bbce..43bcb55 100644
--- a/minigpt4/datasets/datasets/cc_sbu_dataset.py
+++ b/minigpt4/datasets/datasets/cc_sbu_dataset.py
@@ -44,4 +44,25 @@ class CCSBUAlignDataset(CaptionDataset):
             "image": image,
             "text_input": caption,
             "image_id": self.img_ids[ann["image_id"]],
-        }
\ No newline at end of file
+        }
+
+
+class CCDataset(BaseDataset):
+    def __init__(self, vis_processor, text_processor, location):
+        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+
+        self.inner_dataset = wds.DataPipeline(
+            wds.ResampledShards(location),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(1000, handler=wds.warn_and_continue),
+            wds.decode("pilrgb", handler=wds.warn_and_continue),
+            wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
+            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map(self.to_dict, handler=wds.warn_and_continue),
+        )
+
+    def to_dict(self, sample):
+        return {
+            "image": sample[0],
+            "text_input": sample[1],
+        }
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
index 6234c73..5ec79d7 100644
--- a/minigpt4/models/bind_gpt4.py
+++ b/minigpt4/models/bind_gpt4.py
@@ -2,6 +2,7 @@ import random
 from typing import Dict, Tuple
 
 import torch
+import torch.nn as nn
 from torch import Tensor
 from transformers import LlamaTokenizer
 
@@ -65,6 +66,9 @@ class BindGPT4(BaseModel):
             param.requires_grad = False
         print('Loading LLAMA Done')
 
+        # TODO: remove hard-coding
+        self.llama_proj = nn.Linear(768, self.llama_model.config.hidden_size)
+
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
@@ -82,6 +86,8 @@ class BindGPT4(BaseModel):
     def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
         imagebind_outputs = self.multimodal_encoder(inputs)
         llama_inputs = self.multimodal_joiner(imagebind_outputs)
+        # NOTE: only image (vision) inputs are accepted here
+        llama_inputs[ModalityType.VISION] = self.llama_proj(llama_inputs[ModalityType.VISION])
         return llama_inputs
 
     def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
@@ -109,9 +115,13 @@ class BindGPT4(BaseModel):
             Only accept image inputs here.
             Other modalities will conflict with the pre-defined prompt and wrapping strategy.
         """
-        embeds = self.encode_inputs(inputs)
+        bind_inputs = {ModalityType.VISION: inputs['image']}
+        embeds = self.encode_inputs(bind_inputs)
         # assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
-        prompt = random.choice(self.prompt_list)
+        if self.prompt_list:
+            prompt = random.choice(self.prompt_list)
+        else:
+            prompt = None
         img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
 
         # NOTE: No modifications from the next line to the end. Except for the autocast part.
diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py
index ee4a9dc..e33bf82 100644
--- a/minigpt4/models/blip2.py
+++ b/minigpt4/models/blip2.py
@@ -28,7 +28,8 @@ from transformers import BertTokenizer
 class Blip2Base(BaseModel):
     @classmethod
     def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # NOTE: network issue
+        tokenizer = BertTokenizer.from_pretrained("/mnt/bn/bykang/chixma/data/pretrained_models/bert-base-uncased")
         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
         return tokenizer
 
@@ -44,7 +45,8 @@ class Blip2Base(BaseModel):
 
     @classmethod
     def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        # encoder_config = BertConfig.from_pretrained("bert-base-uncased")  # NOTE: network issue
+        encoder_config = BertConfig.from_pretrained("/mnt/bn/bykang/chixma/data/pretrained_models/bert-base-uncased")
         encoder_config.encoder_width = vision_width
         # insert cross-attention layer every other block
         encoder_config.add_cross_attention = True
diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py
index 3d12f3b..0ce174e 100644
--- a/minigpt4/processors/__init__.py
+++ b/minigpt4/processors/__init__.py
@@ -6,11 +6,11 @@
 """
 
 from minigpt4.processors.base_processor import BaseProcessor
-# from minigpt4.processors.blip_processors import (
-#     Blip2ImageTrainProcessor,
-#     Blip2ImageEvalProcessor,
-#     BlipCaptionProcessor,
-# )
+from minigpt4.processors.blip_processors import (
+    Blip2ImageTrainProcessor,
+    Blip2ImageEvalProcessor,
+    BlipCaptionProcessor,
+)
 from minigpt4.processors.imagebind_processor import (
     ImageBindCaptionProcessor,
     ImageBindVisionTrainProcessor,
@@ -21,9 +21,9 @@ from minigpt4.common.registry import registry
 
 __all__ = [
     "BaseProcessor",
-    # "Blip2ImageTrainProcessor",
-    # "Blip2ImageEvalProcessor",
-    # "BlipCaptionProcessor",
+    "Blip2ImageTrainProcessor",
+    "Blip2ImageEvalProcessor",
+    "BlipCaptionProcessor",
     "ImageBindCaptionProcessor",
     "ImageBindVisionTrainProcessor",
     "ImageBindVisionEvalProcessor"
diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py
index ccb5706..5b67307 100644
--- a/minigpt4/runners/runner_base.py
+++ b/minigpt4/runners/runner_base.py
@@ -88,7 +88,7 @@ class RunnerBase:
             if self.use_distributed:
                 if self._wrapped_model is None:
                     self._wrapped_model = DDP(
-                        self._model, device_ids=[self.config.run_cfg.gpu]
+                        self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
                     )
             else:
                 self._wrapped_model = self._model
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3125924
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,68 @@
+-i https://bytedpypi.byted.org/simple/
+
+accelerate==0.16.0
+aiohttp==3.8.4
+aiosignal==1.3.1
+async-timeout==4.0.2
+attrs==22.2.0
+bitsandbytes==0.37.0
+cchardet==2.1.7
+chardet==5.1.0
+contourpy==1.0.7
+cycler==0.11.0
+filelock==3.9.0
+fonttools==4.38.0
+frozenlist==1.3.3
+huggingface-hub==0.13.4
+importlib-resources==5.12.0
+kiwisolver==1.4.4
+matplotlib==3.7.0
+multidict==6.0.4
+openai==0.27.0
+packaging==23.0
+psutil==5.9.4
+pycocotools==2.0.6
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pyyaml==6.0
+regex==2022.10.31
+tokenizers==0.13.2
+tqdm==4.64.1
+transformers==4.28.0
+timm==0.6.13
+spacy==3.5.1
+webdataset==0.2.48
+scikit-learn==1.2.2
+scipy==1.10.1
+yarl==1.8.2
+zipp==3.14.0
+omegaconf==2.3.0
+opencv-python==4.7.0.72
+iopath==0.1.10
+decord==0.6.0
+tenacity==8.2.2
+peft
+pycocoevalcap
+sentence-transformers
+umap-learn
+notebook
+gradio==3.24.1
+gradio-client==0.0.8
+wandb
+ipdb
+tensorflow-cpu
+tensorboardX
+# mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
+
+bs4==0.0.1              # Needed for text cleaning
+bson==0.5.10
+byted-dataloader==0.3.6
+diffusers[torch]==0.16.1
+einops==0.6.0
+ftfy==6.1.1             # Needed for text cleaning
+lpips==0.1.4
+sentencepiece==0.1.99   # Needed for T5 tokenizer
+tensorboard==2.11.2
+tensorflow==2.11.0      # Needed for tensorboard hdfs support
+tensorflow-io==0.30.0   # Needed for tensorboard hdfs support
+tqdm==4.64.1
\ No newline at end of file
diff --git a/train_configs/bindgpt4_stage1_pretrain.yaml b/train_configs/bindgpt4_stage1_pretrain.yaml
new file mode 100644
index 0000000..5f8d0f6
--- /dev/null
+++ b/train_configs/bindgpt4_stage1_pretrain.yaml
@@ -0,0 +1,57 @@
+model:
+  arch: bind_gpt4
+  model_type: pretrain_vicuna
+  freeze_vit: True
+  freeze_qformer: False
+
+
+datasets:
+  cc12m:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 115
+  # cc_sbu:
+  #   vis_processor:
+  #       train:
+  #         name: "blip2_image_train"
+  #         image_size: 224
+  #   text_processor:
+  #       train:
+  #         name: "blip_caption"
+  #   sample_ratio: 14
+
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/minigpt4_stage1_pretrain"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False 
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index 044246c..d2224ca 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -15,15 +15,15 @@ datasets:
       train:
         name: "blip_caption"
     sample_ratio: 115
-  cc_sbu:
-    vis_processor:
-        train:
-          name: "blip2_image_train"
-          image_size: 224
-    text_processor:
-        train:
-          name: "blip_caption"
-    sample_ratio: 14
+  # cc_sbu:
+  #   vis_processor:
+  #       train:
+  #         name: "blip2_image_train"
+  #         image_size: 224
+  #   text_processor:
+  #       train:
+  #         name: "blip_caption"
+  #   sample_ratio: 14
 
 
 run: