This commit is contained in:
junchen14 2023-10-23 07:08:42 +03:00
commit 5c3ec8bb73

View File

@ -172,12 +172,12 @@ class MiniGPTBase(BaseModel):
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
questions = [self.llama_tokenizer(q,
questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
answers = [self.llama_tokenizer(q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in answers]
answers = [self.llama_tokenizer(a + self.end_sym,
return_tensors="pt",
add_special_tokens=False).to(self.device) for a in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):