Update README.md
Browse files
README.md
CHANGED
@@ -25,7 +25,7 @@ The model can be used for zero-shot text classification such sentiment analysis
|
|
25 |
The number of labels should be 2 ~ 20.
|
26 |
|
27 |
### How to use
|
28 |
-
You can try the model with the
|
29 |
|
30 |
```python
|
31 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
@@ -35,12 +35,12 @@ tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuni
|
|
35 |
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
|
36 |
|
37 |
text = "I love this place! The food is always so fresh and delicious."
|
38 |
-
list_label = ["negative","
|
39 |
|
40 |
list_ABC = [x for x in string.ascii_uppercase]
|
41 |
-
def add_prefix(text, list_label,
|
42 |
list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
|
43 |
-
list_label_new = list_label + [tokenizer.pad_token]* (
|
44 |
if shuffle:
|
45 |
random.shuffle(list_label_new)
|
46 |
s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
|
@@ -50,9 +50,12 @@ text_new, list_label_new = add_prefix(text,list_label,shuffle=False)
|
|
50 |
|
51 |
encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
|
52 |
with torch.no_grad():
|
53 |
-
logits = model(**
|
54 |
probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
|
55 |
predictions = torch.argmax(logits, dim=-1)
|
|
|
|
|
|
|
56 |
```
|
57 |
|
58 |
|
|
|
25 |
The number of labels should be 2 ~ 20.
|
26 |
|
27 |
### How to use
|
28 |
+
You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).
|
29 |
|
30 |
```python
|
31 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
35 |
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
|
36 |
|
37 |
text = "I love this place! The food is always so fresh and delicious."
|
38 |
+
list_label = ["negative", "positive"]
|
39 |
|
40 |
list_ABC = [x for x in string.ascii_uppercase]
|
41 |
+
def add_prefix(text, list_label, shuffle = False):
|
42 |
list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
|
43 |
+
list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
|
44 |
if shuffle:
|
45 |
random.shuffle(list_label_new)
|
46 |
s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
|
|
|
50 |
|
51 |
encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
|
52 |
with torch.no_grad():
|
53 |
+
logits = model(**encoding).logits
|
54 |
probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
|
55 |
predictions = torch.argmax(logits, dim=-1)
|
56 |
+
|
57 |
+
print(probs)
|
58 |
+
print(predictions)
|
59 |
```
|
60 |
|
61 |
|