diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 73a995e..83c8bee 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: - if torch.all((stop == input_ids[0][-len(stop):])).item(): + if torch.all(input_ids[:, -len(stop):] == stop).item(): return True return False