mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
Update stop criteria
This commit is contained in:
parent
9d41619968
commit
ca2314c206
@ -101,7 +101,7 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||||
for stop in self.stops:
|
for stop in self.stops:
|
||||||
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
if torch.all(input_ids[:, -len(stop):] == stop).item():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
Loading…
Reference in New Issue
Block a user