Merge pull request #1 from TsuTikgiau/main

update the conv template
This commit is contained in:
Jun Chen 2023-10-22 21:08:37 -07:00 committed by GitHub
commit 102cc53d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -172,12 +172,12 @@ class MiniGPTBase(BaseModel):
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
questions = [self.llama_tokenizer(q,
questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
answers = [self.llama_tokenizer(q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in answers]
answers = [self.llama_tokenizer(a + self.end_sym,
return_tensors="pt",
add_special_tokens=False).to(self.device) for a in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):