oliverguhr commited on
Commit
70e394e
1 Parent(s): 2fcb9be

fixed missing attention mask code

Browse files
Files changed (1) hide show
  1. README.md +5 -8
README.md CHANGED
@@ -69,16 +69,13 @@ class SentimentModel():
69
  def predict_sentiment(self, texts: List[str])-> List[str]:
70
  texts = [self.clean_text(text) for text in texts]
71
  # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
72
- input_ids = self.tokenizer(texts, padding=True, truncation=True, add_special_tokens=True)
73
- input_ids = torch.tensor(input_ids["input_ids"])
74
-
75
  with torch.no_grad():
76
- logits = self.model(input_ids)
77
-
78
  label_ids = torch.argmax(logits[0], axis=1)
79
-
80
- labels = [self.model.config.id2label[label_id] for label_id in label_ids.tolist()]
81
- return labels
82
 
83
  def replace_numbers(self,text: str) -> str:
84
  return text.replace("0"," null").replace("1"," eins").replace("2"," zwei").replace("3"," drei").replace("4"," vier").replace("5"," fünf").replace("6"," sechs").replace("7"," sieben").replace("8"," acht").replace("9"," neun")
 
69
  def predict_sentiment(self, texts: List[str])-> List[str]:
70
  texts = [self.clean_text(text) for text in texts]
71
  # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
72
+ encoded = self.tokenizer.batch_encode_plus(texts,padding=True, add_special_tokens=True,truncation=True, return_tensors="pt")
73
+ encoded = encoded.to(self.device)
 
74
  with torch.no_grad():
75
+ logits = self.model(**encoded)
76
+
77
  label_ids = torch.argmax(logits[0], axis=1)
78
+ return [self.model.config.id2label[label_id.item()] for label_id in label_ids]
 
 
79
 
80
  def replace_numbers(self,text: str) -> str:
81
  return text.replace("0"," null").replace("1"," eins").replace("2"," zwei").replace("3"," drei").replace("4"," vier").replace("5"," fünf").replace("6"," sechs").replace("7"," sieben").replace("8"," acht").replace("9"," neun")