diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index d70ca18..e47d38e 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -17,7 +17,7 @@ from transformers import LlamaTokenizer from peft import ( LoraConfig, get_peft_model, - prepare_model_for_int8_training, + prepare_model_for_kbit_training, ) from minigpt4.common.dist_utils import download_cached_file @@ -188,7 +188,7 @@ class BaseModel(nn.Module): ) if lora_r > 0: - llama_model = prepare_model_for_int8_training(llama_model) + llama_model = prepare_model_for_kbit_training(llama_model) loraconfig = LoraConfig( r=lora_r, bias="none",