pit / utils.py
mu123567's picture
Upload 9 files
71e7434
import torch
from transformers import StoppingCriteria
class StopWordsCriteria(StoppingCriteria):
    """Stop generation once the first sequence ends with a given run of token ids.

    Only ``input_ids[0]`` is inspected, so batched generation is not supported.
    """

    def __init__(self, stop_indices: list):
        """
        Args:
            stop_indices: Token ids that, when they appear consecutively at the
                very end of the generated sequence, should halt generation.
        """
        self.stop_indices = stop_indices

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` ends with ``self.stop_indices``.

        Args:
            input_ids: Token ids generated so far; only row 0 is examined.
            scores: Unused; part of the ``StoppingCriteria`` call signature.
            **kwargs: Unused.

        Returns:
            True when the tail of the first sequence equals the stop ids.
            False otherwise, including when the stop list is empty or is
            longer than the sequence generated so far.
        """
        # Batch inference is not supported: only the first sequence is checked.
        num_stop = len(self.stop_indices)
        sequence = input_ids[0]
        # An empty stop list must never trigger a stop. A sequence shorter than
        # the stop list cannot end with it; indexing it directly would raise
        # IndexError.
        if num_stop == 0 or len(sequence) < num_stop:
            return False
        # Compare the trailing window in one shot instead of element by element.
        return sequence[-num_stop:].tolist() == list(self.stop_indices)