diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
index fd1d636..d70ca18 100644
--- a/minigpt4/models/base_model.py
+++ b/minigpt4/models/base_model.py
@@ -13,17 +13,17 @@ from omegaconf import OmegaConf
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer, LlamaTokenizer
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import LlamaTokenizer
 
 from peft import (
     LoraConfig,
     get_peft_model,
     prepare_model_for_int8_training,
 )
 
-from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from minigpt4.common.dist_utils import download_cached_file
 from minigpt4.common.utils import get_abs_path, is_url
 from minigpt4.models.eva_vit import create_eva_vit_g
+from minigpt4.models.modeling_llama import LlamaForCausalLM
 
 
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index c24c579..58edb1a 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -236,7 +236,7 @@ class MiniGPTBase(BaseModel):
         else:
             instruction = None
 
-        if self.chat_template:
+        if hasattr(self, 'chat_template') and self.chat_template:
             instruction = [self.prompt_template.format(instruct) for instruct in instruction]
 
         if 'length' in samples:
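
The minigpt_base.py change swaps a direct attribute read for a guarded `hasattr` check. Below is a minimal sketch (not part of the patch) of why that guard matters, assuming the failure mode is a model instance whose config never set `chat_template`; the `OldModel` class and `format_instructions` helper here are hypothetical, for illustration only.

# A bare `self.chat_template` raises AttributeError on instances that never
# defined the attribute; `hasattr(...) and ...` degrades to a no-op instead.

class OldModel:
    prompt_template = "###Human: {} ###Assistant:"
    # note: no `chat_template` attribute is ever defined

def format_instructions(model, instruction):
    # Guarded access mirrors the patched condition in MiniGPTBase.
    if hasattr(model, 'chat_template') and model.chat_template:
        instruction = [model.prompt_template.format(i) for i in instruction]
    return instruction

print(format_instructions(OldModel(), ["describe the image"]))
# -> ['describe the image']  (unchanged, no AttributeError)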