File size: 1,156 Bytes
dfc143a
 
 
 
 
ed20cbd
 
dfc143a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from .char_tokenizer import CharTokenizer
from transformers import AutoModelForTokenClassification

class AccentModel:
    """Place a stress mark ('+') in a word using a token-classification model.

    The model and tokenizer are only created by :meth:`load`, so it must be
    called before :meth:`put_accent`.
    """

    def __init__(self, allow_cuda=True) -> None:
        """Select the inference device.

        Args:
            allow_cuda: when falsy, always use the CPU even if CUDA is available.
        """
        # Test the flag first so CUDA is not probed at all when it is disabled.
        use_cuda = allow_cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def load(self, path):
        """Load the token-classification model and the character tokenizer.

        Args:
            path: passed unchanged to both ``from_pretrained`` calls.
        """
        self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
        self.tokenizer = CharTokenizer.from_pretrained(path)

    def render_stress(self, word: str, token_classes) -> str:
        """Return ``word`` with '+' inserted before the first stressed character.

        Args:
            word: the original word.
            token_classes: one predicted label per model token. Label ``i``
                describes character ``i - 1`` of ``word``. This offset comes
                from the original indexing and implies token 0 is a leading
                special token; confirm against CharTokenizer.

        Returns:
            ``word`` with '+' before the character of the first 'STRESS'
            label that maps onto a character. If there is no such label,
            ``word`` is returned unchanged.
        """
        # Only consider labels that correspond to real characters. A STRESS
        # label on token 0 previously produced index -1 and stressed the last
        # character. A STRESS label past the end of the word raised IndexError.
        char_labels = token_classes[1:len(word) + 1]
        for char_index, label in enumerate(char_labels):
            if label == 'STRESS':
                return word[:char_index] + '+' + word[char_index:]
        return word

    def put_accent(self, word):
        """Predict the stressed character of ``word`` and mark it with '+'.

        Requires :meth:`load` to have been called first.

        Args:
            word: a single word to accent.

        Returns:
            The word with '+' before the predicted stressed character, or the
            word unchanged when no usable stress is predicted.
        """
        inputs = self.tokenizer(word, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # A single word is encoded, so only the first batch row is relevant.
        predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
        id2label = self.model.config.id2label
        predicted_token_class = [id2label[token_id] for token_id in predicted_ids]
        return self.render_stress(word, predicted_token_class)