import torch | |
from transformers import StoppingCriteria | |
class StopWordsCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with a given run of token ids.

    Only the first row of ``input_ids`` is examined, so batched generation
    is not supported.
    """

    def __init__(self, stop_indices: list):
        """Store the stop sequence.

        Args:
            stop_indices: Token ids that, when they appear in this order at
                the very end of the sequence, signal that generation should
                stop.
        """
        super().__init__()
        self.stop_indices = stop_indices

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` ends with ``stop_indices``.

        Args:
            input_ids: Token ids produced so far; only ``input_ids[0]`` is
                inspected (batch inference is not supported).
            scores: Unused; accepted for interface compatibility.
            **kwargs: Unused.

        Returns:
            True if the last ``len(stop_indices)`` entries of
            ``input_ids[0]`` equal ``stop_indices`` element by element,
            False otherwise. An empty ``stop_indices`` matches vacuously
            and therefore returns True.
        """
        sequence = input_ids[0]
        stop_len = len(self.stop_indices)

        # A sequence shorter than the stop sequence cannot end with it.
        # Without this guard the negative indexing below raises IndexError.
        if len(sequence) < stop_len:
            return False

        # Walk backwards from the newest token: in the common case the last
        # token already differs, so the check exits after one comparison.
        for offset in range(1, stop_len + 1):
            if self.stop_indices[-offset] != sequence[-offset]:
                return False
        return True