---
license: mit
tags:
- Zero-Shot Classification
pipeline_tag: zero-shot-classification
---
# Zero-shot text classification (base-sized model) trained with self-supervised tuning
Zero-shot text classification model trained with self-supervised tuning (SSTuning).
It was introduced in the paper [Zero-Shot Text Classification via Self-Supervised Tuning](https://arxiv.org/abs/2305.11442) by
Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, Lidong Bing
and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning).
The model backbone is RoBERTa-base.
## Model description
The model is tuned with unlabeled data using a learning objective called first sentence prediction (FSP).
The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks.
The training and validation sets are constructed from the unlabeled corpus using FSP.
During tuning, BERT-like pre-trained masked language
models such as RoBERTa and ALBERT are employed as the backbone, and an output layer for classification is added.
The learning objective for FSP is to predict the index of the correct label.
A cross-entropy loss is used for tuning the model.
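As a rough illustration of the FSP setup (not the authors' exact data-construction code), a single training example can be thought of as follows: the first sentence of an unlabeled paragraph is the correct option, first sentences drawn from other paragraphs act as distractors, and the model is trained with cross-entropy loss to predict the index of the correct option. The helper name `build_fsp_example`, the option count, and the hard-coded `</s>` separator below are illustrative assumptions.

```python
import random
import string

def build_fsp_example(paragraphs, num_options=20):
    # `paragraphs`: list of paragraphs, each given as a list of sentences.
    # Assumes at least `num_options` paragraphs are available for distractors.
    target = random.randrange(len(paragraphs))
    correct = paragraphs[target][0]             # first sentence = correct option
    rest = ' '.join(paragraphs[target][1:])     # remaining text = input to classify
    distractors = [p[0] for i, p in enumerate(paragraphs) if i != target]
    options = random.sample(distractors, num_options - 1) + [correct]
    random.shuffle(options)
    label_index = options.index(correct)        # target index for the cross-entropy loss
    s_option = ' '.join(f'({string.ascii_uppercase[i]}) {options[i]}'
                        for i in range(num_options))
    # '</s>' stands in for the backbone tokenizer's sep_token here.
    return f'{s_option} </s> {rest}', label_index
```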
## Model variations
Three versions of the model have been released. The details are:
| Model | Backbone | #Params | Accuracy | Speed | #Training data |
|------------|-----------|----------|-------|-------|----|
| [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base) | [roberta-base](https://huggingface.co/roberta-base) | 125M | Low | High | 20.48M |
| [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large) | [roberta-large](https://huggingface.co/roberta-large) | 355M | Medium | Medium | 5.12M |
| [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) | [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2) | 235M | High | Low | 5.12M |
Please note that zero-shot-classify-SSTuning-base is trained on more data (20.48M examples) than reported in the paper, which increases its accuracy.
## Intended uses & limitations
The model can be used for zero-shot text classification tasks such as sentiment analysis and topic classification. No further fine-tuning is needed.
The number of labels should be between 2 and 20.
### How to use
You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random
tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]
list_ABC = [x for x in string.ascii_uppercase]
def add_prefix(text, list_label, shuffle=False):
    # Append a period to each label and pad the option list to 20 entries.
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Format the options as "(A) label1. (B) label2. ..." and prepend them to the text.
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    return f'{s_option} {tokenizer.sep_token} {text}', list_label_new

text_new, list_label_new = add_prefix(text, list_label, shuffle=False)

encoding = tokenizer([text_new], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
with torch.no_grad():
    logits = model(**encoding).logits
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1)
print(probs)
print(predictions)
```
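To read off the predicted label name rather than the option index, the argmax can be mapped back into `list_label_new` from the snippet above (a small continuation of that snippet; `predicted_label` is just an illustrative variable name):

```python
# Continuing the snippet above: map the predicted option index back to its label.
predicted_label = list_label_new[predictions[0].item()]
print(predicted_label)  # e.g. "positive."
```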
### BibTeX entry and citation info
```bibtex
@inproceedings{acl23/SSTuning,
author = {Chaoqun Liu and
Wenxuan Zhang and
Guizhen Chen and
Xiaobao Wu and
Anh Tuan Luu and
Chip Hong Chang and
Lidong Bing},
title = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
year = {2023},
url = {},
}
```