TeraSpace commited on
Commit
ed20cbd
1 Parent(s): 328b51c

Update ruaccent/accent_model.py

Browse files
Files changed (1) hide show
  1. ruaccent/accent_model.py +2 -2
ruaccent/accent_model.py CHANGED
@@ -3,8 +3,8 @@ from .char_tokenizer import CharTokenizer
3
  from transformers import AutoModelForTokenClassification
4
 
5
  class AccentModel:
6
- def __init__(self) -> None:
7
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  def load(self, path):
9
  self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
10
  self.tokenizer = CharTokenizer.from_pretrained(path)
 
3
  from transformers import AutoModelForTokenClassification
4
 
5
  class AccentModel:
6
+ def __init__(self, allow_cuda=True) -> None:
7
+ self.device = torch.device('cuda' if torch.cuda.is_available() and allow_cuda else 'cpu')
8
  def load(self, path):
9
  self.model = AutoModelForTokenClassification.from_pretrained(path).to(self.device)
10
  self.tokenizer = CharTokenizer.from_pretrained(path)