diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py index bc4c084..74626be 100644 --- a/minigpt4/models/__init__.py +++ b/minigpt4/models/__init__.py @@ -11,7 +11,7 @@ from omegaconf import OmegaConf from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel -from minigpt4.models.mini_gpt4 import MiniGPT4 +from minigpt4.models.minigpt_4 import MiniGPT4 from minigpt4.processors.base_processor import BaseProcessor diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/minigpt_4.py similarity index 95% rename from minigpt4/models/mini_gpt4.py rename to minigpt4/models/minigpt_4.py index 2fdf7da..f73a800 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/minigpt_4.py @@ -6,19 +6,10 @@ from torch.cuda.amp import autocast as autocast import torch.nn as nn from minigpt4.common.registry import registry -from minigpt4.models.base_model import BaseModel, disabled_train +from minigpt4.models.base_model import disabled_train from minigpt4.models.minigpt_base import MiniGPTBase from minigpt4.models.Qformer import BertConfig, BertLMHeadModel -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers import LlamaTokenizer -from peft import ( - LoraConfig, - get_peft_model, - get_peft_model_state_dict, - prepare_model_for_int8_training, - set_peft_model_state_dict, -) @registry.register_model("mini_gpt4") diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py index 0a8c848..aa91bf3 100644 --- a/minigpt4/models/minigpt_base.py +++ b/minigpt4/models/minigpt_base.py @@ -6,15 +6,8 @@ from torch.cuda.amp import autocast as autocast import torch.nn as nn from minigpt4.common.registry import registry -from minigpt4.models.base_model import BaseModel, disabled_train -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers import LlamaTokenizer +from minigpt4.models.base_model import BaseModel -from peft import ( - LoraConfig, - get_peft_model, - prepare_model_for_int8_training, -) class MiniGPTBase(BaseModel):