Commit 390fd8d by jishnunair (1 parent: ad13f83)
Update README.md

README.md CHANGED
@@ -36,8 +36,42 @@ It achieves the following results on the evaluation set:
The training data consists of the 4 most widely available ner_tags from the Finer-139 dataset. The training and test data were curated from this source accordingly.

-##
+## Prediction procedure
+```
+from transformers import AutoTokenizer
+from optimum.onnxruntime import ORTModelForTokenClassification
+import torch

+def onnx_inference(checkpoint, test_data, export=False):
+    test_text = " ".join(test_data['tokens'])
+    print("Test Text: " + test_text)
+
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+    model = ORTModelForTokenClassification.from_pretrained(checkpoint, export=export)
+
+    inputs = tokenizer(test_text, return_tensors="pt")
+    outputs = model(**inputs).logits
+
+    predictions = torch.argmax(outputs, dim=2)
+
+    # Map class ids to tag names; label_list (id -> tag name) must be defined beforehand
+    predicted_token_class = [label_list[int(t)] for t in predictions[0]]
+    ner_tags = [label_list[int(t)] for t in test_data['ner_tags']]
+
+    print("Original Tags: ")
+    print(ner_tags)
+    print("Predicted Tags: ")
+    print(predicted_token_class)
+
+onnx_model_path = ""  # add the path
+
+onnx_inference(onnx_model_path, test_data)
+
+"""
+Here, test_data should contain "tokens" and "ner_tags". It can be of type Dataset.
+"""
+
+```
### Training hyperparameters

The following hyperparameters were used during training:
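
The prediction snippet in this commit uses a `label_list` that it never defines, and it prints per-word `ner_tags` next to per-sub-token predictions, so the two lists will generally differ in length. Below is a minimal sketch of how both gaps could be closed. It assumes the checkpoint's config carries an `id2label` mapping and that word indices come from a fast tokenizer's `word_ids()`. The helper names `build_label_list` and `first_subtoken_predictions` are hypothetical and not part of the commit, and the tag names are made up for illustration.

```python
def build_label_list(id2label):
    """Order tag names by class id (e.g. from model.config.id2label)."""
    # Keys may be ints or numeric strings depending on how the config was loaded.
    by_id = {int(k): v for k, v in id2label.items()}
    return [by_id[i] for i in sorted(by_id)]


def first_subtoken_predictions(word_ids, token_predictions, label_list):
    """Keep one predicted tag per word: the tag of its first sub-token.

    word_ids: one entry per sub-token; None marks special tokens.
    token_predictions: one class id per sub-token (argmax of the logits).
    """
    tags = []
    previous = None
    for word_id, class_id in zip(word_ids, token_predictions):
        # Skip special tokens and continuation pieces of the same word.
        if word_id is not None and word_id != previous:
            tags.append(label_list[int(class_id)])
        previous = word_id
    return tags


# Toy values for illustration only; the tag names are made up.
label_list = build_label_list({"0": "O", "1": "B-EXAMPLE", "2": "I-EXAMPLE"})
word_ids = [None, 0, 1, 1, 2, None]  # special, word 0, word 1 (two pieces), word 2, special
token_predictions = [0, 0, 1, 2, 0, 0]
print(first_subtoken_predictions(word_ids, token_predictions, label_list))
```

In the commit's function, `word_ids` could be obtained by tokenizing with `tokenizer(test_data['tokens'], is_split_into_words=True, return_tensors="pt")` and calling `.word_ids()` on the result (this needs a fast tokenizer), with `token_predictions` taken from `predictions[0].tolist()`. The output then has one tag per entry in `test_data['ner_tags']`.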