"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import torch
from omegaconf import OmegaConf

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from minigpt4.models.blip2 import Blip2Base
from minigpt4.models.mini_gpt4 import MiniGPT4
from minigpt4.processors.base_processor import BaseProcessor


__all__ = [
    "load_model",
    "BaseModel",
    "Blip2Base",
    "MiniGPT4",
]


def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
    """
    Load supported models.

    To list all available models and types in the registry:
    >>> from minigpt4.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether to put the model in eval mode. Default: False.
        device (str): device to use. Default: "cpu".
        checkpoint (str): path to the checkpoint. Default: None.
            Note that the checkpoint is expected to have the same state_dict keys as the model.

    Returns:
        model (torch.nn.Module): the loaded model.
    """

    model = registry.get_model_class(name).from_pretrained(model_type=model_type)

    if checkpoint is not None:
        model.load_checkpoint(checkpoint)

    if is_eval:
        model.eval()

    if device == "cpu":
        model = model.float()

    return model.to(device)
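

# Example usage (a minimal sketch; the architecture and type names below are
# assumptions -- run `print(model_zoo)` to list the names actually registered
# in this checkout):
#
#   >>> from minigpt4.models import load_model
#   >>> model = load_model("mini_gpt4", "pretrain_vicuna0", is_eval=True, device="cuda")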


def load_preprocess(config):
    """
    Load preprocessor configs and construct preprocessors.

    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.

    Args:
        config (dict): preprocessor configs.

    Returns:
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.

        Keys are "train" and "eval", for the processors used in training and evaluation respectively.
    """

    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else BaseProcessor()
        )

    vis_processors = dict()
    txt_processors = dict()

    vis_proc_cfg = config.get("vis_processor")
    txt_proc_cfg = config.get("text_processor")

    if vis_proc_cfg is not None:
        vis_train_cfg = vis_proc_cfg.get("train")
        vis_eval_cfg = vis_proc_cfg.get("eval")
    else:
        vis_train_cfg = None
        vis_eval_cfg = None

    vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
    vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)

    if txt_proc_cfg is not None:
        txt_train_cfg = txt_proc_cfg.get("train")
        txt_eval_cfg = txt_proc_cfg.get("eval")
    else:
        txt_train_cfg = None
        txt_eval_cfg = None

    txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
    txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)

    return vis_processors, txt_processors
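

# Sketch of the config structure load_preprocess expects: an OmegaConf node with
# optional "vis_processor" / "text_processor" entries, each holding "train" and
# "eval" sub-configs whose "name" field selects a registered processor. The
# processor names and extra fields below are assumptions and must match what is
# registered via registry.register_processor in this checkout:
#
#   >>> from omegaconf import OmegaConf
#   >>> preprocess_cfg = OmegaConf.create({
#   ...     "vis_processor": {
#   ...         "train": {"name": "blip2_image_train", "image_size": 224},
#   ...         "eval": {"name": "blip2_image_eval", "image_size": 224},
#   ...     },
#   ...     "text_processor": {
#   ...         "train": {"name": "blip_caption"},
#   ...         "eval": {"name": "blip_caption"},
#   ...     },
#   ... })
#   >>> vis_processors, txt_processors = load_preprocess(preprocess_cfg)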


def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
    """
    Load a model and its related preprocessors.

    To list all available models and types in the registry:
    >>> from minigpt4.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether to put the model in eval mode. Default: False.
        device (str): device to use. Default: "cpu".

    Returns:
        model (torch.nn.Module): the loaded model.
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.
    """
    model_cls = registry.get_model_class(name)

    # load model
    model = model_cls.from_pretrained(model_type=model_type)

    if is_eval:
        model.eval()

    # load preprocess
    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is not None:
        preprocess_cfg = cfg.preprocess

        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {name} ({model_type}).
            This can happen if the model is not finetuned on downstream datasets,
            or it is not intended for direct use without finetuning.
            """
        )

    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors
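

# Example usage (a minimal sketch; "mini_gpt4" / "pretrain_vicuna0" are assumed
# registry names and `raw_image` is assumed to be a PIL.Image -- check
# `print(model_zoo)` for the names registered in this checkout):
#
#   >>> model, vis_processors, txt_processors = load_model_and_preprocess(
#   ...     "mini_gpt4", "pretrain_vicuna0", is_eval=True, device="cuda"
#   ... )
#   >>> image = vis_processors["eval"](raw_image).unsqueeze(0).to("cuda")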


class ModelZoo:
    """
    A utility class that builds a string representation of the available model architectures and types.

    >>> from minigpt4.models import model_zoo
    >>> # list all available models
    >>> print(model_zoo)
    >>> # show total number of models
    >>> print(len(model_zoo))
    """

    def __init__(self) -> None:
        self.model_zoo = {
            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
            for k, v in registry.mapping["model_name_mapping"].items()
        }

    def __str__(self) -> str:
        return (
            "=" * 50
            + "\n"
            + f"{'Architectures':<30} {'Types'}\n"
            + "=" * 50
            + "\n"
            + "\n".join(
                [
                    f"{name:<30} {', '.join(types)}"
                    for name, types in self.model_zoo.items()
                ]
            )
        )

    def __iter__(self):
        return iter(self.model_zoo.items())

    def __len__(self):
        return sum([len(v) for v in self.model_zoo.values()])


model_zoo = ModelZoo()
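
# The module-level singleton above is iterable as well as printable, so the
# available (architecture, types) pairs can also be listed programmatically:
#
#   >>> from minigpt4.models import model_zoo
#   >>> for arch, types in model_zoo:
#   ...     print(arch, types)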