import torch
from .char_tokenizer import CharTokenizer
from transformers import AutoModelForTokenClassification


class AccentModel:
    """Mark the stressed character of a word using a token-classification model.

    Usage: construct, call :meth:`load` with a model directory/identifier, then
    call :meth:`put_accent`. The ``model`` and ``tokenizer`` attributes are only
    created by :meth:`load`; calling :meth:`put_accent` before it fails.
    """

    def __init__(self, allow_cuda: bool = True) -> None:
        """Choose the inference device.

        Args:
            allow_cuda: When True and CUDA is available, run on ``'cuda'``;
                otherwise run on ``'cpu'``.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'
        )

    def load(self, path) -> None:
        """Load the token-classification model and the character tokenizer.

        Both are loaded from the same ``path`` via their ``from_pretrained``
        constructors; the model is moved to ``self.device``.
        """
        self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
        self.tokenizer = CharTokenizer.from_pretrained(path)

    def render_stress(self, word: str, token_classes: list) -> str:
        """Insert ``'+'`` before the character whose token was labelled STRESS.

        Args:
            word: The original word.
            token_classes: One predicted label per token position. Position
                ``i`` is treated as character ``i - 1`` of ``word`` (presumably
                position 0 is a leading special token — confirm against
                CharTokenizer).

        Returns:
            ``word`` with ``'+'`` prepended to the first STRESS-labelled
            character, or ``word`` unchanged when no label is ``'STRESS'`` or
            the first STRESS label falls on a position with no corresponding
            character.
        """
        if 'STRESS' not in token_classes:
            return word
        char_index = token_classes.index('STRESS') - 1
        # Guard both ends: index 0 would otherwise wrap to the last character
        # via negative indexing, and positions past the word would raise.
        if not 0 <= char_index < len(word):
            return word
        chars = list(word)
        chars[char_index] = '+' + chars[char_index]
        return ''.join(chars)

    def put_accent(self, word: str) -> str:
        """Predict the stress position for ``word`` and return it marked with ``'+'``.

        Requires :meth:`load` to have been called first.
        """
        inputs = self.tokenizer(word, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Argmax over dim 2 (presumably the label axis of a
        # (batch, seq_len, num_labels) tensor); only batch item 0 is used.
        predictions = torch.argmax(logits, dim=2)
        predicted_token_class = [
            self.model.config.id2label[t.item()] for t in predictions[0]
        ]
        return self.render_stress(word, predicted_token_class)