"""Stopping criteria and logits warpers for constrained JSON-style
generation with Hugging Face ``transformers`` models."""

from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)

# Lazily install the heavy dependencies on first import so the module
# does not require torch/transformers to be preinstalled.
try:
    import torch
except ImportError:
    auto_check_and_download_package(
        "torch", package_manager="pip", upgrade=True
    )
    import torch

try:
    import transformers
except ImportError:
    auto_check_and_download_package(
        "transformers", package_manager="pip", upgrade=True
    )
    import transformers


class StringStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once the model closes a JSON string value,
    i.e. as soon as a generated token contains a double quote."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,  # type: ignore
        prompt_length: int,
    ):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        _,
    ) -> bool:
        # Nothing has been generated beyond the prompt yet.
        if len(input_ids[0]) <= self.prompt_length:
            return False

        last_token_id = input_ids[0][-1]
        last_token = self.tokenizer.decode(
            last_token_id, skip_special_tokens=True
        )

        # A closing double quote ends the string value.
        return '"' in last_token
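

# Illustrative usage sketch, not part of the original module: how
# StringStoppingCriteria plugs into ``generate``. The "gpt2" checkpoint,
# prompt, and token budget below are assumptions for the example only.
def _demo_string_stop() -> str:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer('{"name": "', return_tensors="pt")
    prompt_length = inputs.input_ids.shape[1]
    output_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=24,
        stopping_criteria=transformers.StoppingCriteriaList(
            [StringStoppingCriteria(tokenizer, prompt_length)]
        ),
    )
    # Generation halts on the first token containing '"', so the decoded
    # tail is the string value (with its closing quote).
    return tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )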


class NumberStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once a complete number has been produced: a second
    decimal point appears, the fractional part exceeds ``precision``
    digits, or whitespace follows at least one digit."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,  # type: ignore
        prompt_length: int,
        precision: int = 3,
    ):
        self.tokenizer = tokenizer
        self.precision = precision
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        scores: torch.FloatTensor,  # type: ignore
    ) -> bool:
        decoded = self.tokenizer.decode(
            input_ids[0][self.prompt_length :],
            skip_special_tokens=True,
        )

        # A second decimal point can never be part of a valid number.
        if decoded.count(".") > 1:
            return True

        # Cap the number of digits after the decimal point.
        if (
            decoded.count(".") == 1
            and len(decoded.strip().split(".")[1]) > self.precision
        ):
            return True

        # Whitespace after at least one digit terminates the number.
        if (
            len(decoded) > 1
            and any(c.isdigit() for c in decoded)
            and decoded[-1] in [" ", "\n"]
        ):
            return True

        return False
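

# Illustrative usage sketch, not part of the original module. In practice
# this criterion is paired with OutputNumbersTokens (defined below) so the
# model can only emit numeric tokens; the "gpt2" checkpoint, prompt, and
# precision here are assumptions for the example only.
def _demo_number_stop() -> str:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer('{"temperature": ', return_tensors="pt")
    prompt_length = inputs.input_ids.shape[1]
    output_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=12,
        stopping_criteria=transformers.StoppingCriteriaList(
            [NumberStoppingCriteria(tokenizer, prompt_length, precision=2)]
        ),
    )
    return tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )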


class OutputNumbersTokens(transformers.LogitsWarper):
    """Logits warper that only allows whitespace tokens and numeric
    tokens containing at most one decimal point."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,  # type: ignore
        prompt: str,
    ):
        self.tokenizer = tokenizer
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
        vocab_size = len(tokenizer)
        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)

        # Precompute a boolean vocabulary mask: tokens that strip to ""
        # (pure whitespace) or consist of digits with at most one "."
        # stay allowed; everything else will be suppressed.
        for _, token_id in tokenizer.get_vocab().items():
            token_str = tokenizer.decode(token_id).strip()

            if token_str == "" or (
                all(c.isdigit() or c == "." for c in token_str)
                and token_str.count(".") <= 1
            ):
                self.allowed_mask[token_id] = True

    def __call__(self, _, scores):
        # Broadcast the (vocab,) mask over the (batch, vocab) scores and
        # push disallowed tokens to -inf.
        mask = self.allowed_mask.expand_as(scores)
        scores[~mask] = -float("inf")

        return scores
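

# Illustrative end-to-end sketch, not part of the original module: the
# warper masks non-numeric tokens while the stopping criterion bounds the
# number's length and precision. The "gpt2" checkpoint, prompt, and token
# budget are assumptions for the example only.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    prompt = '{"price": '
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs.input_ids.shape[1]
    output_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=8,
        logits_processor=transformers.LogitsProcessorList(
            [OutputNumbersTokens(tokenizer, prompt)]
        ),
        stopping_criteria=transformers.StoppingCriteriaList(
            [NumberStoppingCriteria(tokenizer, prompt_length)]
        ),
    )
    print(
        tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
    )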