Update ruaccent/accent_model.py
Browse files- ruaccent/accent_model.py +2 -2
ruaccent/accent_model.py
CHANGED
@@ -3,8 +3,8 @@ from .char_tokenizer import CharTokenizer
|
|
3 |
from transformers import AutoModelForTokenClassification
|
4 |
|
5 |
class AccentModel:
|
6 |
-
def __init__(self) -> None:
|
7 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
8 |
def load(self, path):
|
9 |
self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
|
10 |
self.tokenizer = CharTokenizer.from_pretrained(path)
|
|
|
3 |
from transformers import AutoModelForTokenClassification
|
4 |
|
5 |
class AccentModel:
|
6 |
+
def __init__(self, allow_cuda=True) -> None:
|
7 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() and allow_cuda else 'cpu')
|
8 |
def load(self, path):
|
9 |
self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
|
10 |
self.tokenizer = CharTokenizer.from_pretrained(path)
|