Testys committed on
Commit
31cea2f
1 Parent(s): 2210a0e

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +30 -40
main.py CHANGED
@@ -36,50 +36,40 @@ sentiment_model = SentimentCNNModel(
36
  sentiment_model.load_state_dict(torch.load(sentiment_model_name, map_location=torch.device('cpu')))
37
  sentiment_model.eval()
38
 
39
def analyze_text(text, window_size=512, stride=256):
    """Run NER and sentiment classification over ``text`` using sliding windows.

    ``text`` is sliced into overlapping character windows
    (``text[start:start + window_size]``, advancing by ``stride``). Each
    window is tokenized and passed through ``ner_model`` and
    ``sentiment_model``. NER results of all windows are concatenated;
    the final sentiment is the majority vote over the per-window
    sentiments.

    Args:
        text: Input string to analyse.
        window_size: Number of characters per window; also forwarded to
            both tokenizers as ``max_length``.
        stride: Number of characters the window start advances each step.

    Returns:
        A ``(all_ner_labels, final_sentiment)`` tuple. ``all_ner_labels``
        is a list of ``"token: label"`` strings (tokens in overlapping
        regions appear once per window that contains them).
        ``final_sentiment`` is the most frequent per-window sentiment
        label, or ``None`` when ``text`` is empty.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """
    from collections import Counter

    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")

    # With no text there are no windows, hence nothing to vote on; return an
    # explicit empty result instead of failing on most_common(1)[0].
    if not text:
        return [], None

    all_ner_labels = []
    all_sentiments = []
    text_length = len(text)

    for start in range(0, text_length, stride):
        window = text[start:start + window_size]

        # --- Named Entity Recognition for this window ---
        ner_inputs = ner_tokenizer(window, return_tensors="pt", truncation=True, padding=True, max_length=window_size)

        # Index the single batch row explicitly: squeeze() would collapse a
        # one-token sequence to a 0-d tensor and break the iteration below.
        token_ids = ner_inputs['input_ids'][0].tolist()
        tokens = [ner_tokenizer.convert_ids_to_tokens(token_id) for token_id in token_ids]

        with torch.no_grad():
            ner_outputs = ner_model(**ner_inputs)

        # assumes ner_outputs is (batch, seq, num_labels) -- TODO confirm
        label_ids = torch.argmax(ner_outputs, dim=-1)[0].tolist()
        label_names = [ner_config["id2labels"][str(label_id)] for label_id in label_ids]
        all_ner_labels.extend(
            f"{token}: {label}" for token, label in zip(tokens, label_names)
        )

        # --- Sentiment for this window ---
        sentiment_inputs = sentiment_tokenizer(window, return_tensors="pt", truncation=True, padding=True, max_length=window_size)

        with torch.no_grad():
            sentiment_outputs = sentiment_model(**sentiment_inputs)

        # assumes sentiment_outputs is (batch, num_classes) -- TODO confirm
        sentiment_id = torch.argmax(sentiment_outputs, dim=1).tolist()[0]
        all_sentiments.append(sentiment_config["id2label"][str(sentiment_id)])

        # Once a window reaches the end of the text, every later window is a
        # strict suffix of it; processing those would only duplicate NER
        # labels and double-count their sentiment in the vote.
        if start + window_size >= text_length:
            break

    # Majority vote across windows; on a tie Counter keeps the label that
    # was encountered first.
    final_sentiment = Counter(all_sentiments).most_common(1)[0][0]

    return all_ner_labels, final_sentiment
83
 
84
  def main():
85
  st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
@@ -149,4 +139,4 @@ def main():
149
  """, unsafe_allow_html=True)
150
 
151
  if __name__ == "__main__":
152
- main()
 
36
  sentiment_model.load_state_dict(torch.load(sentiment_model_name, map_location=torch.device('cpu')))
37
  sentiment_model.eval()
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
def analyze_text(text):
    """Run NER and sentiment classification on ``text`` in a single pass.

    The string is tokenized once with ``ner_tokenizer`` and once with
    ``sentiment_tokenizer`` (no truncation or padding is requested), and
    each encoding is fed to its model under ``torch.no_grad()``.

    Args:
        text: Input string to analyse.

    Returns:
        A ``(ner_labels, sentiment)`` tuple. ``ner_labels`` is a list of
        ``"token: label"`` strings pairing each token with its predicted
        label from ``ner_config["id2labels"]``. ``sentiment`` is the
        ``sentiment_config["id2label"]`` entry of the highest-scoring class.
    """
    # --- Named Entity Recognition ---
    ner_inputs = ner_tokenizer(text, return_tensors="pt")

    # Index the single batch row explicitly: squeeze() would collapse a
    # one-token sequence to a 0-d tensor, making tolist() return a bare int
    # and the comprehension below fail.
    token_ids = ner_inputs['input_ids'][0].tolist()

    # Convert token ids back to their string tokens for display.
    tokens = [ner_tokenizer.convert_ids_to_tokens(token_id) for token_id in token_ids]

    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)

    # assumes ner_outputs is (batch, seq, num_labels) -- TODO confirm
    label_ids = torch.argmax(ner_outputs, dim=-1)[0].tolist()
    label_names = [ner_config["id2labels"][str(label_id)] for label_id in label_ids]

    # Pair every token with its predicted label.
    ner_labels = [f"{token}: {label}" for token, label in zip(tokens, label_names)]

    # --- Sentiment analysis ---
    sentiment_inputs = sentiment_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)

    # argmax yields class ids (not probabilities); take the first batch row.
    # assumes sentiment_outputs is (batch, num_classes) -- TODO confirm
    predicted_class_ids = torch.argmax(sentiment_outputs, dim=1).tolist()
    sentiment_id = predicted_class_ids[0]
    sentiment = sentiment_config["id2label"][str(sentiment_id)]

    return ner_labels, sentiment
73
 
74
  def main():
75
  st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
 
139
  """, unsafe_allow_html=True)
140
 
141
  if __name__ == "__main__":
142
+ main()