Merge pull request #389 from TsuTikgiau/main

Fix the chat template check and the loss reduction error in v1 training
commit 2933ff4fee
Author: ZhuDeyao
Date: 2023-10-21 16:36:09 +03:00 (committed by GitHub)
2 changed files with 4 additions and 4 deletions

File 1 of 2:

@@ -13,17 +13,17 @@ from omegaconf import OmegaConf
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer, LlamaTokenizer
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import LlamaTokenizer
 from peft import (
     LoraConfig,
     get_peft_model,
     prepare_model_for_int8_training,
 )
-from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from minigpt4.common.dist_utils import download_cached_file
 from minigpt4.common.utils import get_abs_path, is_url
 from minigpt4.models.eva_vit import create_eva_vit_g
+from minigpt4.models.modeling_llama import LlamaForCausalLM
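The import swap is the half of the fix aimed at the "reduction error" in the commit title: it replaces the stock transformers LlamaForCausalLM with the repo's vendored minigpt4.models.modeling_llama version. The stock forward() has no reduction parameter, so a training step that forwards one (e.g. reduction="none" to get per-sample losses) fails; a vendored subclass can accept it. The sketch below is a hypothetical simplification of that idea, not the repo's exact code, and the alias HFLlamaForCausalLM plus the whole forward body are illustrative assumptions:

    from typing import Optional

    import torch
    from torch.nn import CrossEntropyLoss
    from transformers.models.llama.modeling_llama import LlamaForCausalLM as HFLlamaForCausalLM


    class LlamaForCausalLM(HFLlamaForCausalLM):
        """Sketch: a causal-LM wrapper whose loss reduction is configurable."""

        def forward(self, labels: Optional[torch.LongTensor] = None,
                    reduction: str = "mean", **kwargs):
            # Let the stock forward produce logits; compute the loss ourselves.
            outputs = super().forward(labels=None, **kwargs)
            if labels is not None:
                # Shift so tokens < n predict token n (standard causal-LM loss).
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss_fct = CrossEntropyLoss(reduction=reduction)  # "mean"/"sum"/"none"
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1).to(shift_logits.device))
                if reduction == "none":
                    # Keep one loss value per sample instead of a single scalar
                    # (simplified: the average here ignores the -100 padding mask).
                    loss = loss.view(shift_logits.size(0), -1).mean(dim=1)
                outputs.loss = loss
            return outputs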

File 2 of 2:

@@ -236,7 +236,7 @@ class MiniGPTBase(BaseModel):
         else:
             instruction = None
 
-        if self.chat_template:
+        if hasattr(self, 'chat_template') and self.chat_template:
             instruction = [self.prompt_template.format(instruct) for instruct in instruction]
 
         if 'length' in samples:
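This is the chat-template half of the fix: v1-style training builds models that never define a chat_template attribute, so the bare if self.chat_template: raised an AttributeError, while guarding with hasattr() short-circuits safely on both model variants. A minimal, self-contained illustration of the pattern (ModelV1/ModelV2 are hypothetical names, not classes from the repo):

    class ModelV1:
        # Older v1-style model: `chat_template` is never defined.
        prompt_template = "###Human: {} ###Assistant: "


    class ModelV2(ModelV1):
        # Newer variant that opts in to the chat template.
        chat_template = True


    for model in (ModelV1(), ModelV2()):
        # Old check: `if model.chat_template:` -> AttributeError for ModelV1.
        # New check skips the attribute access when it is missing.
        if hasattr(model, 'chat_template') and model.chat_template:
            print(type(model).__name__, "wraps instructions with the prompt template")
        else:
            print(type(model).__name__, "leaves instructions unchanged")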