Spaces:
Running
Running
Chidam Gopal
committed on
Commit
•
2cba4b1
1
Parent(s):
7538db6
directly use the onnx quantized file
Browse files
- infer_intent.py +19 -5
- requirements.txt +3 -1
infer_intent.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
-
from transformers import
|
2 |
import torch
|
|
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
class IntentClassifier:
|
@@ -15,10 +19,20 @@ class IntentClassifier:
|
|
15 |
self.label2id = {label:id for id,label in self.id2label.items()}
|
16 |
|
17 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def find_intent(self, sequence, verbose=False):
|
24 |
inputs = self.tokenizer(sequence,
|
|
|
1 |
+
from transformers import AutoTokenizer
|
2 |
import torch
|
3 |
+
import onnxruntime as ort
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import os
|
7 |
|
8 |
|
9 |
class IntentClassifier:
|
|
|
19 |
self.label2id = {label:id for id,label in self.id2label.items()}
|
20 |
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
22 |
+
model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
|
23 |
+
model_dir_path = "models"
|
24 |
+
model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
|
25 |
+
if not os.path.exists(model_dir_path):
|
26 |
+
os.makedirs(model_dir_path)
|
27 |
+
if not os.path.exists(model_path):
|
28 |
+
print("Downloading ONNX model...")
|
29 |
+
response = requests.get(model_url)
|
30 |
+
with open(model_path, "wb") as f:
|
31 |
+
f.write(response.content)
|
32 |
+
print("ONNX model downloaded.")
|
33 |
+
|
34 |
+
# Load the ONNX model
|
35 |
+
self.ort_session = ort.InferenceSession(model_path)
|
36 |
|
37 |
def find_intent(self, sequence, verbose=False):
|
38 |
inputs = self.tokenizer(sequence,
|
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
transformers==4.45.1
|
2 |
torch==2.4.1
|
3 |
streamlit==1.38.0
|
4 |
-
matplotlib==3.9.2
|
|
|
|
|
|
1 |
transformers==4.45.1
|
2 |
torch==2.4.1
|
3 |
streamlit==1.38.0
|
4 |
+
matplotlib==3.9.2
|
5 |
+
## onnx
|
6 |
+
onnxruntime==1.19.2
|