From ca2314c206e6a5bcfc17e97ea5ecc3167583afa4 Mon Sep 17 00:00:00 2001 From: ZhuDeyao Date: Tue, 31 Oct 2023 13:59:22 +0300 Subject: [PATCH] Update stop criteria --- minigpt4/conversation/conversation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 73a995e..83c8bee 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: - if torch.all((stop == input_ids[0][-len(stop):])).item(): + if torch.all(input_ids[:, -len(stop):] == stop).item(): return True return False