Update README.md
Browse files
README.md
CHANGED
@@ -20,24 +20,74 @@ This model has been trained on the following datasets:
|
|
20 |
|
21 |
# Use
|
22 |
|
23 |
-
*
|
24 |
```python
|
25 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
26 |
-
from transformers import pipeline
|
27 |
model_id = "gauneg/roberta-base-absa-ate-sentiment"
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
29 |
model = AutoModelForTokenClassification.from_pretrained(model_id)
|
30 |
|
31 |
-
ate_sent_pipeline = pipeline(task='ner',
|
32 |
-
aggregation_strategy='simple',
|
33 |
-
tokenizer=tokenizer,
|
34 |
-
model=model)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
|
|
37 |
|
38 |
```
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
text_input = "Been here a few times and food has always been good but service really suffers when it gets crowded."
|
42 |
ate_sent_pipeline(text_input)
|
43 |
```
|
|
|
20 |
|
21 |
# Use
|
22 |
|
23 |
+
* Making token level inferences with Auto classes
|
24 |
```python
|
25 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
|
26 |
model_id = "gauneg/roberta-base-absa-ate-sentiment"
|
27 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
28 |
model = AutoModelForTokenClassification.from_pretrained(model_id)
|
29 |
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
# the sequence of labels used during training
|
32 |
+
label = {"B-neu": 1, "I-neu": 2, "O": 0, "B-neg": 4, "B-con": 5, "I-pos": 6, "B-pos": 7, "I-con": 8, "I-neg": 9, "X": -100}
|
33 |
+
id2lab = {idx: lab for lab, idx in labels.items()}
|
34 |
+
lab2id = {lab: idx for lab, idx in labels.items()}
|
35 |
+
|
36 |
+
|
37 |
+
# making one prediction at a time (should be padded/batched and truncated for efficiency)
|
38 |
+
text_input = "Been here a few times and food has always been good but service really suffers when it gets crowded."
|
39 |
+
tok_inputs = tokenizer(text_input, return_tensors="pt")
|
40 |
+
|
41 |
+
|
42 |
+
y_pred = model(**tok_inputs) # predicting the logits
|
43 |
+
|
44 |
+
y_pred_fin = y_pred.logits.argmax(dim=-1)[0] # selecting the most favoured labels for each token from the logits
|
45 |
+
|
46 |
+
decoded_pred = [id2lab[logx.item()] for logx in y_pred_fin]
|
47 |
+
|
48 |
+
|
49 |
+
## displaying the input tokens with predictions and skipping <s> and </s> tokens at the beginning and the end respectively
|
50 |
|
51 |
+
tok_levl_pred = list(zip(tokenizer.convert_ids_to_tokens(tok_inputs['input_ids'][0]), decoded_pred))[1:-1]
|
52 |
|
53 |
```
|
54 |
+
|
55 |
+
* results in `tok_level_pred` variable
|
56 |
+
|
57 |
+
```bash
|
58 |
+
[('Be', 'O'),
|
59 |
+
('en', 'O'),
|
60 |
+
('Ġhere', 'O'),
|
61 |
+
('Ġa', 'O'),
|
62 |
+
('Ġfew', 'O'),
|
63 |
+
('Ġtimes', 'O'),
|
64 |
+
('Ġand', 'O'),
|
65 |
+
('Ġfood', 'B-pos'),
|
66 |
+
('Ġhas', 'O'),
|
67 |
+
('Ġalways', 'O'),
|
68 |
+
('Ġbeen', 'O'),
|
69 |
+
('Ġgood', 'O'),
|
70 |
+
('Ġbut', 'O'),
|
71 |
+
('Ġservice', 'B-neg'),
|
72 |
+
('Ġreally', 'O'),
|
73 |
+
('Ġsuffers', 'O'),
|
74 |
+
('Ġwhen', 'O'),
|
75 |
+
('Ġit', 'O'),
|
76 |
+
('Ġgets', 'O'),
|
77 |
+
('Ġcrowded', 'O'),
|
78 |
+
('.', 'O')]
|
79 |
+
```
|
80 |
+
|
81 |
+
# OR
|
82 |
+
|
83 |
+
* Using the pipeline directly for end-to-end inference:
|
84 |
```python
|
85 |
+
from transformers import pipeline
|
86 |
+
|
87 |
+
ate_sent_pipeline = pipeline(task='ner',
|
88 |
+
aggregation_strategy='simple',
|
89 |
+
model="gauneg/roberta-base-absa-ate-sentiment")
|
90 |
+
|
91 |
text_input = "Been here a few times and food has always been good but service really suffers when it gets crowded."
|
92 |
ate_sent_pipeline(text_input)
|
93 |
```
|