# swarms/tools/logits_processor.py
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)

# Install torch and transformers on demand if they are not already present.
try:
    import torch
except ImportError:
    auto_check_and_download_package(
        "torch", package_manager="pip", upgrade=True
    )
    import torch

try:
    import transformers
except ImportError:
    auto_check_and_download_package(
        "transformers", package_manager="pip", upgrade=True
    )
    import transformers


class StringStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once the model emits a closing double quote,
    signalling that a string value is complete."""

    def __init__(
        self, tokenizer: transformers.PreTrainedTokenizer, prompt_length: int  # type: ignore
    ):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        _,
    ) -> bool:
        # Never stop before anything has been generated beyond the prompt.
        if len(input_ids[0]) <= self.prompt_length:
            return False

        last_token_id = input_ids[0][-1]
        last_token = self.tokenizer.decode(
            last_token_id, skip_special_tokens=True
        )

        result = '"' in last_token

        return result


class NumberStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once a complete number has been produced: a second
    decimal point, more fractional digits than ``precision``, or trailing
    whitespace after at least one digit."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,  # type: ignore
        prompt_length: int,
        precision: int = 3,
    ):
        self.tokenizer = tokenizer
        self.precision = precision
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        scores: torch.FloatTensor,  # type: ignore
    ) -> bool:
        decoded = self.tokenizer.decode(
            input_ids[0][self.prompt_length :],
            skip_special_tokens=True,
        )

        # A second decimal point can never be part of a valid number.
        if decoded.count(".") > 1:
            return True

        # Stop once the fractional part exceeds the requested precision.
        if (
            decoded.count(".") == 1
            and len(decoded.strip().split(".")[1]) > self.precision
        ):
            return True

        # Stop on trailing whitespace after at least one digit has appeared.
        if (
            len(decoded) > 1
            and any(c.isdigit() for c in decoded)
            and decoded[-1] in [" ", "\n"]
        ):
            return True

        return False


class OutputNumbersTokens(transformers.LogitsWarper):
    """Logits warper that masks the vocabulary so only numeric tokens
    (digits and at most one decimal point) can be sampled."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, prompt: str):  # type: ignore
        self.tokenizer = tokenizer
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
        vocab_size = len(tokenizer)
        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)

        # Precompute which vocabulary entries decode to digits and/or a
        # single decimal point (empty strings are also allowed through).
        for _, token_id in tokenizer.get_vocab().items():
            token_str = tokenizer.decode(token_id).strip()

            if token_str == "" or (
                all(c.isdigit() or c == "." for c in token_str)
                and token_str.count(".") <= 1
            ):
                self.allowed_mask[token_id] = True

    def __call__(self, _, scores):
        # Push the logits of all disallowed tokens to -inf so they can
        # never be selected during decoding.
        mask = self.allowed_mask.expand_as(scores)
        scores[~mask] = -float("inf")

        return scores
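

# Minimal usage sketch (illustrative only, not part of the module above).
# It shows how these processors could plug into `transformers` generation,
# assuming a small causal LM such as "gpt2"; the model name, prompt, and
# generation settings are placeholder assumptions, not values used by Swarms.
if __name__ == "__main__":
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    prompt = '{"temperature": '
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]

    output_ids = model.generate(
        **inputs,
        max_new_tokens=8,
        # Constrain decoding to numeric tokens only.
        logits_processor=transformers.LogitsProcessorList(
            [OutputNumbersTokens(tokenizer, prompt)]
        ),
        # Stop as soon as the number is complete.
        stopping_criteria=transformers.StoppingCriteriaList(
            [NumberStoppingCriteria(tokenizer, prompt_length)]
        ),
    )
    print(
        tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
    )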