Update stop criteria

This commit is contained in:
ZhuDeyao 2023-10-31 13:59:22 +03:00 committed by GitHub
parent 9d41619968
commit ca2314c206
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops: for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop):])).item(): if torch.all(input_ids[:, -len(stop):] == stop).item():
return True return True
return False return False