fix: update peft imports

replace a deprecated function
"prepare_model_for_kbit_training"
This commit is contained in:
3luka 2024-08-16 12:23:28 +03:00
parent 71df7647b6
commit 5de8fce3bb

View File

@ -17,7 +17,7 @@ from transformers import LlamaTokenizer
from peft import ( from peft import (
LoraConfig, LoraConfig,
get_peft_model, get_peft_model,
prepare_model_for_int8_training, prepare_model_for_kbit_training,
) )
from minigpt4.common.dist_utils import download_cached_file from minigpt4.common.dist_utils import download_cached_file
@ -188,7 +188,7 @@ class BaseModel(nn.Module):
) )
if lora_r > 0: if lora_r > 0:
llama_model = prepare_model_for_int8_training(llama_model) llama_model = prepare_model_for_kbit_training(llama_model)
loraconfig = LoraConfig( loraconfig = LoraConfig(
r=lora_r, r=lora_r,
bias="none", bias="none",