Update ruaccent/omograph_model.py
Browse files
ruaccent/omograph_model.py
CHANGED
@@ -2,8 +2,8 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
2 |
import torch
|
3 |
|
4 |
class OmographModel:
|
5 |
-
def __init__(self) -> None:
|
6 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
|
8 |
def load(self, path):
|
9 |
self.nli_model = AutoModelForSequenceClassification.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)
|
|
|
2 |
import torch
|
3 |
|
4 |
class OmographModel:
|
5 |
+
def __init__(self, allow_cuda=True) -> None:
|
6 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() and allow_cuda else 'cpu')
|
7 |
|
8 |
def load(self, path):
|
9 |
self.nli_model = AutoModelForSequenceClassification.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)
|