---
inference: false
license: mit
tags:
- Zero-Shot Classification
pipeline_tag: zero-shot-classification
---
# Zero-shot text classification (base-sized model) trained with self-supervised tuning

Zero-shot text classification model trained with self-supervised tuning (SSTuning). 
It was introduced in the paper [Zero-Shot Text Classification via Self-Supervised Tuning](https://arxiv.org/abs/2305.11442) by 
Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, Lidong Bing
and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning).

The model backbone is RoBERTa-base.

## Model description


The model is tuned with unlabeled data using a learning objective called first sentence prediction (FSP). 
The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks. 

The training and validation sets are constructed from the unlabeled corpus using FSP. 

During tuning, a BERT-like pre-trained masked language 
model such as RoBERTa or ALBERT is used as the backbone, and an output layer for classification is added. 
The FSP learning objective is to predict the index of the correct label, 
and the model is tuned with a cross-entropy loss.
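
For intuition, below is a minimal sketch of how one FSP training instance could be assembled from an unlabeled corpus, assuming the first sentence of a paragraph serves as the correct option and first sentences sampled from other paragraphs serve as distractors. The function name, sampling scheme, and separator token are illustrative assumptions, not the authors' exact implementation; see the paper and the linked repository for the actual construction.

```python
import random

# Illustrative sketch only (hypothetical helper, not the authors' code).
# Assumes `corpus` is a list of paragraphs, each given as a list of sentences.
def build_fsp_example(corpus, doc_idx, num_options=20):
    paragraph = corpus[doc_idx]
    correct = paragraph[0]            # first sentence acts as the "correct label"
    rest = ' '.join(paragraph[1:])    # remaining sentences act as the input text

    # Distractor options: first sentences drawn from other paragraphs.
    others = [corpus[i][0] for i in range(len(corpus)) if i != doc_idx]
    distractors = random.sample(others, num_options - 1)

    options = distractors + [correct]
    random.shuffle(options)
    label_index = options.index(correct)   # classification target: index of the correct option

    # "(A) ... (B) ..." option string; '</s>' stands in for RoBERTa's separator token.
    option_str = ' '.join(f'({chr(65 + i)}) {opt}' for i, opt in enumerate(options))
    return f'{option_str} </s> {rest}', label_index
```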

## Model variations
Four versions of the model have been released. The details are: 

| Model | Backbone | #params | Language | Accuracy | Speed | #Train samples |
|-------|----------|---------|----------|----------|-------|----------------|
|   [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base)    |  [roberta-base](https://huggingface.co/roberta-base)      |  125M    | En | Low    |  High    | 20.48M |  
|   [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large)    |    [roberta-large](https://huggingface.co/roberta-large)      | 355M     | En |   Medium   | Medium | 5.12M |
|   [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT)   |  [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2)      |  235M   | En |  High  | Low| 5.12M |
|   [zero-shot-classify-SSTuning-XLM-R](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R)    |  [xlm-roberta-base](https://huggingface.co/xlm-roberta-base)      |  278M    | Multi | -   |  -    | 20.48M |  

Please note that zero-shot-classify-SSTuning-XLM-R is trained with 20.48M English samples only. However, it can also be used for any other language that xlm-roberta supports.
Please check [this repository](https://github.com/DAMO-NLP-SG/SSTuning) for the performance of each model.


## Intended uses & limitations
The model can be used for zero-shot text classification such as sentiment analysis and topic classification. No further finetuning is needed.

The number of labels should be between 2 and 20. 

### How to use
You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]

def check_text(model, text, list_label, shuffle=False):
    # Ensure every label ends with a period, then pad the label list to 20 slots
    # (the model's fixed number of output classes).
    list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token]*(20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Build the option string "(A) label1. (B) label2. ..." and prepend it to the text.
    s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    logits = model(**item).logits

    # Without shuffling, only the first len(list_label) logits correspond to real labels.
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]

    print(f'prediction:    {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability:   {round(probabilities[predictions]*100, 2)}%')

check_text(model, text, list_label)
# prediction:    1 => (B) positive.
# probability:   99.92%
```
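
The same helper also works for topic-style label sets. The snippet below is an illustrative call (the text and labels are made up); the printed prediction and probability depend on the model's output.

```python
text = "The new smartphone features a faster chip and a brighter display."
list_label = ["politics", "sports", "business", "science and technology"]

# check_text prints the predicted option letter and its probability,
# in the same format as the sentiment example above.
check_text(model, text, list_label)
```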


### BibTeX entry and citation info
```bibtex
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and 
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}
```