From 045a1d06029a6ba40d76d7dcb976f1fa7e090c4c Mon Sep 17 00:00:00 2001
From: Deyao Zhu <deyao.zhu@kaust.edu.sa>
Date: Thu, 12 Oct 2023 22:44:49 +0300
Subject: [PATCH] modularize minigpt4 code

---
 minigpt4/models/base_model.py   | 226 ++++++++++----------------------
 minigpt4/models/mini_gpt4.py    |  99 +++++++++++++-
 minigpt4/models/minigpt_base.py | 116 ++--------------
 3 files changed, 175 insertions(+), 266 deletions(-)

diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
index fa21962..ae0a3be 100644
--- a/minigpt4/models/base_model.py
+++ b/minigpt4/models/base_model.py
@@ -5,19 +5,26 @@
  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-import contextlib
-import logging
 import os
+import logging
+import contextlib
 
+from omegaconf import OmegaConf
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer
+from transformers import BertTokenizer, LlamaTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)
+
 from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
 from minigpt4.common.utils import get_abs_path, is_url
-from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from minigpt4.models.eva_vit import create_eva_vit_g
-from omegaconf import OmegaConf
+
 
 
 class BaseModel(nn.Module):
@@ -121,12 +128,6 @@ class BaseModel(nn.Module):
         else:
             return tot
 
-    @classmethod
-    def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
-        return tokenizer
-
     def maybe_autocast(self, dtype=torch.float16):
         # if on cpu, don't use autocast
         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
@@ -137,33 +138,74 @@ class BaseModel(nn.Module):
         else:
             return contextlib.nullcontext()
 
-    @classmethod
-    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
-        encoder_config.encoder_width = vision_width
-        # insert cross-attention layer every other block
-        encoder_config.add_cross_attention = True
-        encoder_config.cross_attention_freq = cross_attention_freq
-        encoder_config.query_length = num_query_token
-        Qformer = BertLMHeadModel(config=encoder_config)
-        query_tokens = nn.Parameter(
-            torch.zeros(1, num_query_token, encoder_config.hidden_size)
-        )
-        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
-        return Qformer, query_tokens
-
     @classmethod
     def init_vision_encoder(
-        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
     ):
+        logging.info('Loading VIT')
+
         assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+        if not freeze:
+            precision = "fp32"  # fp16 is not for training
+
         visual_encoder = create_eva_vit_g(
             img_size, drop_path_rate, use_grad_checkpoint, precision
         )
 
         ln_vision = LayerNorm(visual_encoder.num_features)
+
+        if freeze:
+            for name, param in visual_encoder.named_parameters():
+                param.requires_grad = False
+            visual_encoder = visual_encoder.eval()
+            visual_encoder.train = disabled_train
+            for name, param in ln_vision.named_parameters():
+                param.requires_grad = False
+            ln_vision = ln_vision.eval()
+            ln_vision.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        logging.info('Loading VIT Done')
         return visual_encoder, ln_vision
 
+    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
+                 **lora_kargs):
+        logging.info('Loading LLAMA')
+        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
+        llama_tokenizer.pad_token = "$$"
+
+        if low_resource:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map={'': low_res_device}
+            )
+        else:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+            )
+
+        if lora_r > 0:
+            llama_model = prepare_model_for_int8_training(llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                bias="none",
+                task_type="CAUSAL_LM",
+                **lora_kargs
+            )
+            llama_model = get_peft_model(llama_model, loraconfig)
+
+            llama_model.print_trainable_parameters()
+
+        else:
+            for name, param in llama_model.named_parameters():
+                param.requires_grad = False
+        logging.info('Loading LLAMA Done')
+        return llama_model, llama_tokenizer
+
+
     def load_from_pretrained(self, url_or_filename):
         if is_url(url_or_filename):
             cached_file = download_cached_file(
@@ -185,136 +227,6 @@ class BaseModel(nn.Module):
         return msg
 
 
-
-class BaseEncoder(nn.Module):
-    """
-    Base class for primitive encoders, such as ViT, TimeSformer, etc.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward_features(self, samples, **kwargs):
-        raise NotImplementedError
-
-    @property
-    def device(self):
-        return list(self.parameters())[0].device
-
-
-class SharedQueueMixin:
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.queue_ptr)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
-
-        if idxs is not None:
-            idxs = concat_all_gather(idxs)
-            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
-
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-        self.queue_ptr[0] = ptr
-
-
-class MomentumDistilationMixin:
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not update by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data = param_m.data * self.momentum + param.data * (
-                    1.0 - self.momentum
-                )
-
-
-class GatherLayer(torch.autograd.Function):
-    """
-    Gather tensors from all workers with support for backward propagation:
-    This implementation does not cut the gradients as torch.distributed.all_gather does.
-    """
-
-    @staticmethod
-    def forward(ctx, x):
-        output = [
-            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
-        ]
-        torch.distributed.all_gather(output, x)
-        return tuple(output)
-
-    @staticmethod
-    def backward(ctx, *grads):
-        all_gradients = torch.stack(grads)
-        torch.distributed.all_reduce(all_gradients)
-        return all_gradients[torch.distributed.get_rank()]
-
-
-def all_gather_with_grad(tensors):
-    """
-    Performs all_gather operation on the provided tensors.
-    Graph remains connected for backward grad computation.
-    """
-    # Queue the gathered tensors
-    world_size = torch.distributed.get_world_size()
-    # There is no need for reduction in the single-proc case
-    if world_size == 1:
-        return tensors
-
-    # tensor_all = GatherLayer.apply(tensors)
-    tensor_all = GatherLayer.apply(tensors)
-
-    return torch.cat(tensor_all, dim=0)
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    # if use distributed training
-    if not is_dist_avail_and_initialized():
-        return tensor
-
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def tile(x, dim, n_tile):
-    init_dim = x.size(dim)
-    repeat_idx = [1] * x.dim()
-    repeat_idx[dim] = n_tile
-    x = x.repeat(*(repeat_idx))
-    order_index = torch.LongTensor(
-        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
-    )
-    return torch.index_select(x, dim, order_index.to(x.device))
-
-
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index d1ff071..2fdf7da 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -8,6 +8,7 @@ import torch.nn as nn
 from minigpt4.common.registry import registry
 from minigpt4.models.base_model import BaseModel, disabled_train
 from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer
 
@@ -31,6 +32,99 @@ class MiniGPT4(MiniGPTBase):
         "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
     }
 
+    def __init__(
+            self,
+            vit_model="eva_clip_g",
+            q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+            img_size=224,
+            drop_path_rate=0,
+            use_grad_checkpoint=False,
+            vit_precision="fp16",
+            freeze_vit=True,
+            has_qformer=True,
+            freeze_qformer=True,
+            num_query_token=32,
+            llama_model="",
+            prompt_path="",
+            prompt_template="",
+            max_txt_len=32,
+            end_sym='\n',
+            low_resource=False,  # use 8 bit and put vit in cpu
+            device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
+    ):
+        super().__init__(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            llama_model=llama_model,
+            max_txt_len=max_txt_len,
+            end_sym=end_sym,
+            low_resource=low_resource,
+            device_8bit=device_8bit,
+        )
+
+        self.has_qformer = has_qformer
+        if self.has_qformer:
+            print('Loading Q-Former')
+            self.Qformer, self.query_tokens = self.init_Qformer(
+                num_query_token, self.visual_encoder.num_features, freeze_qformer
+            )
+            self.load_from_pretrained(url_or_filename=q_former_model)  # load q-former weights here
+
+            img_f_dim = self.Qformer.config.hidden_size
+            print('Loading Q-Former Done')
+        else:
+            img_f_dim = self.visual_encoder.num_features * 4
+            print('Do not use Q-Former here.')
+
+        self.llama_proj = nn.Linear(
+            img_f_dim, self.llama_model.config.hidden_size
+        )
+
+        if prompt_path:
+            with open(prompt_path, 'r') as f:
+                raw_prompts = f.read().splitlines()
+            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
+            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
+            print('Load {} training prompts'.format(len(self.prompt_list)))
+            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+        else:
+            self.prompt_list = []
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, freeze):
+        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = 2
+        encoder_config.query_length = num_query_token
+        Qformer = BertLMHeadModel(config=encoder_config)
+        query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_token, encoder_config.hidden_size)
+        )
+        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+
+        Qformer.cls = None
+        Qformer.bert.embeddings.word_embeddings = None
+        Qformer.bert.embeddings.position_embeddings = None
+        for layer in Qformer.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+
+        if freeze:
+            for name, param in Qformer.named_parameters():
+                param.requires_grad = False
+            Qformer = Qformer.eval()
+            Qformer.train = disabled_train
+            query_tokens.requires_grad = False
+            logging.info("freeze Qformer")
+
+        return Qformer, query_tokens
+
     def encode_img(self, image):
         device = image.device
 
@@ -82,9 +176,6 @@ class MiniGPT4(MiniGPTBase):
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')
 
-        lora_r = cfg.get("lora_r", 0)
-        lora_alpha = cfg.get("lora_alpha", 32)
-
         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,
@@ -103,8 +194,6 @@ class MiniGPT4(MiniGPTBase):
             end_sym=end_sym,
             low_resource=low_resource,
             device_8bit=device_8bit,
-            lora_r=lora_r,
-            lora_alpha=lora_alpha,
         )
 
         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index 7c4008d..0a8c848 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -13,9 +13,7 @@ from transformers import LlamaTokenizer
 from peft import (
     LoraConfig,
     get_peft_model,
-    get_peft_model_state_dict,
     prepare_model_for_int8_training,
-    set_peft_model_state_dict,
 )
 
 
@@ -27,131 +25,41 @@ class MiniGPTBase(BaseModel):
     def __init__(
         self,
         vit_model="eva_clip_g",
-        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
         img_size=224,
         drop_path_rate=0,
         use_grad_checkpoint=False,
         vit_precision="fp16",
         freeze_vit=True,
-        has_qformer=True,
-        freeze_qformer=True,
-        num_query_token=32,
         llama_model="",
-        prompt_path="",
-        prompt_template="",
         max_txt_len=32,
         end_sym='\n',
         low_resource=False,  # use 8 bit and put vit in cpu
         device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
-        lora_r=0,
+        lora_r=0,  # lora_r  means lora is not used
         lora_target_modules=["q_proj", "v_proj"],
         lora_alpha=16,
         lora_dropout=0.05,
     ):
         super().__init__()
 
-        self.tokenizer = self.init_tokenizer()
-        self.low_resource = low_resource
+        self.llama_model, self.llama_tokenizer = self.init_llm(
+            llama_model_path=llama_model,
+            low_resource=low_resource,
+            low_res_device=device_8bit,
+            lora_r=lora_r,
+            lora_target_modules=lora_target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
-        print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
         )
-        if freeze_vit:
-            for name, param in self.visual_encoder.named_parameters():
-                param.requires_grad = False
-            self.visual_encoder = self.visual_encoder.eval()
-            self.visual_encoder.train = disabled_train
-            for name, param in self.ln_vision.named_parameters():
-                param.requires_grad = False
-            self.ln_vision = self.ln_vision.eval()
-            self.ln_vision.train = disabled_train
-            logging.info("freeze vision encoder")
-        print('Loading VIT Done')
 
-        self.has_qformer = has_qformer
-        if self.has_qformer:
-            print('Loading Q-Former')
-            self.Qformer, self.query_tokens = self.init_Qformer(
-                num_query_token, self.visual_encoder.num_features
-            )
-            self.Qformer.cls = None
-            self.Qformer.bert.embeddings.word_embeddings = None
-            self.Qformer.bert.embeddings.position_embeddings = None
-            for layer in self.Qformer.bert.encoder.layer:
-                layer.output = None
-                layer.intermediate = None
-            self.load_from_pretrained(url_or_filename=q_former_model)
-
-            if freeze_qformer:
-                for name, param in self.Qformer.named_parameters():
-                    param.requires_grad = False
-                self.Qformer = self.Qformer.eval()
-                self.Qformer.train = disabled_train
-                self.query_tokens.requires_grad = False
-                logging.info("freeze Qformer")
-
-            img_f_dim = self.Qformer.config.hidden_size
-            print('Loading Q-Former Done')
-        else:
-            img_f_dim = self.visual_encoder.num_features * 4
-            print('Do not use Q-Former here.')
-
-        print('Loading LLAMA')
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
-        self.llama_tokenizer.pad_token = "$$"
-
-        if self.low_resource:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-                load_in_8bit=True,
-                device_map={'': device_8bit}
-            )
-        else:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-            )
-
-        if lora_r > 0:
-            self.llama_model = prepare_model_for_int8_training(self.llama_model)
-            loraconfig = LoraConfig(
-                r=lora_r,
-                lora_alpha=lora_alpha,
-                target_modules=lora_target_modules,
-                lora_dropout=lora_dropout,
-                bias="none",
-                task_type="CAUSAL_LM"
-            )
-            self.llama_model = get_peft_model(self.llama_model, loraconfig)
-
-            # if ckpt_path:
-            #     print('load the llm under lora')
-            #     ckpt = torch.load(ckpt_path)
-            #     set_peft_model_state_dict(self.llama_model,ckpt)
-            self.llama_model.print_trainable_parameters()
-
-        else:
-            for name, param in self.llama_model.named_parameters():
-                param.requires_grad = False
-        print('Loading LLAMA Done')
-
-        self.llama_proj = nn.Linear(
-            img_f_dim, self.llama_model.config.hidden_size
-        )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
-        if prompt_path:
-            with open(prompt_path, 'r') as f:
-                raw_prompts = f.read().splitlines()
-            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
-            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
-            print('Load {} training prompts'.format(len(self.prompt_list)))
-            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
-        else:
-            self.prompt_list = []
+        self.prompt_list = []
 
     def vit_to_cpu(self):
         self.ln_vision.to("cpu")