lukecq commited on
Commit
20fe8d9
1 Parent(s): 9ab5792

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -5
README.md CHANGED
@@ -25,7 +25,7 @@ The model can be used for zero-shot text classification such sentiment analysis
25
  The number of labels should be 2 ~ 20.
26
 
27
  ### How to use
28
- You can try the model with the colab [notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).
29
 
30
  ```python
31
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -35,12 +35,12 @@ tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuni
35
  model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
36
 
37
  text = "I love this place! The food is always so fresh and delicious."
38
- list_label = ["negative","positve"]
39
 
40
  list_ABC = [x for x in string.ascii_uppercase]
41
- def add_prefix(text, list_label, label_num = 20, shuffle = False):
42
  list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
43
- list_label_new = list_label + [tokenizer.pad_token]* (label_num - len(list_label))
44
  if shuffle:
45
  random.shuffle(list_label_new)
46
  s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
@@ -50,9 +50,12 @@ text_new, list_label_new = add_prefix(text,list_label,shuffle=False)
50
 
51
  encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
52
  with torch.no_grad():
53
- logits = model(**item).logits
54
  probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
55
  predictions = torch.argmax(logits, dim=-1)
 
 
 
56
  ```
57
 
58
 
 
25
  The number of labels should be 2 ~ 20.
26
 
27
  ### How to use
28
+ You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).
29
 
30
  ```python
31
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
35
  model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
36
 
37
  text = "I love this place! The food is always so fresh and delicious."
38
+ list_label = ["negative", "positive"]
39
 
40
  list_ABC = [x for x in string.ascii_uppercase]
41
+ def add_prefix(text, list_label, shuffle = False):
42
  list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
43
+ list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
44
  if shuffle:
45
  random.shuffle(list_label_new)
46
  s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
 
50
 
51
  encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
52
  with torch.no_grad():
53
+ logits = model(**encoding).logits
54
  probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
55
  predictions = torch.argmax(logits, dim=-1)
56
+
57
+ print(probs)
58
+ print(predictions)
59
  ```
60
 
61