Testys committed on
Commit
b8cf6ae
1 Parent(s): 73b342f

Committing the changes made to config.json and main.py

Browse files
Files changed (3) hide show
  1. main.py +83 -1
  2. my_model/config.json +7 -1
  3. sent_model/config.json +1 -1
main.py CHANGED
@@ -1 +1,83 @@
1
- print('Hello, Lightning World!')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import torch
4
+ from transformers import AutoTokenizer
5
+ from modelling_cnn import CNNForNER, SentimentCNNModel
6
+
7
+
8
# Locations of the Yoruba NER model artifacts.
ner_model_name = "./my_model/pytorch_model.bin"
model_ner = "Testys/cnn_yor_ner"

# Locations of the Yoruba sentiment analysis model artifacts.
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
model_sent = "Testys/cnn_sent_yor"


def _read_config(path):
    """Return the JSON object stored in the file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _restore_weights(model, checkpoint_path):
    """Load the state dict at *checkpoint_path* into *model* on CPU and switch it to eval mode."""
    # NOTE: torch.load unpickles the checkpoint; only point this at trusted files.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Streamlit re-executes this script on every widget interaction. Caching the
# loaders keeps the tokenizers and weights from being rebuilt on each rerun.
@st.cache_resource
def _load_ner_resources():
    """Build the NER tokenizer, config dict and model; returns them as a tuple."""
    tokenizer = AutoTokenizer.from_pretrained(model_ner)
    config = _read_config("./my_model/config.json")
    model = CNNForNER(
        pretrained_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    return tokenizer, config, _restore_weights(model, ner_model_name)


@st.cache_resource
def _load_sentiment_resources():
    """Build the sentiment tokenizer, config dict and model; returns them as a tuple."""
    tokenizer = AutoTokenizer.from_pretrained(model_sent)
    config = _read_config("./sent_model/config.json")
    model = SentimentCNNModel(
        transformer_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    return tokenizer, config, _restore_weights(model, sentiment_model_name)


# Module-level names used by the rest of the script.
ner_tokenizer, ner_config, ner_model = _load_ner_resources()
sentiment_tokenizer, sentiment_config, sentiment_model = _load_sentiment_resources()
37
+
38
+
39
def analyze_text(text):
    """Run named-entity recognition and sentiment analysis on *text*.

    Args:
        text: The input string passed to both tokenizers.

    Returns:
        A ``(ner_labels, sentiment_scores)`` tuple.

        ``ner_labels`` is a list of label strings, one per position of the
        first sequence in the NER prediction. Each label is looked up in the
        label table of ``ner_config``; if no entry exists the class index is
        returned as a string.

        ``sentiment_scores`` is the softmax of the sentiment logits converted
        with ``tolist()`` (a nested list: one row of class probabilities per
        input sequence).
    """
    # Tokenize input text for NER
    ner_inputs = ner_tokenizer(text, return_tensors="pt")

    # Perform Named Entity Recognition
    with torch.no_grad():
        ner_outputs = ner_model(**ner_inputs)

    ner_predictions = torch.argmax(ner_outputs.logits, dim=-1)

    # The argmax yields class indices, not vocabulary ids, so they are mapped
    # through the label table from the model config instead of being decoded
    # by the tokenizer. Both key spellings are accepted; JSON keys are strings.
    id_to_label = ner_config.get("id2labels") or ner_config.get("id2label") or {}
    ner_labels = [
        id_to_label.get(str(class_idx), str(class_idx))
        for class_idx in ner_predictions[0].tolist()
    ]

    # Tokenize input text for sentiment analysis
    sentiment_inputs = sentiment_tokenizer(text, return_tensors="pt")

    # Perform sentiment analysis
    with torch.no_grad():
        sentiment_outputs = sentiment_model(**sentiment_inputs)

    sentiment_probabilities = torch.softmax(sentiment_outputs.logits, dim=1)
    sentiment_scores = sentiment_probabilities.tolist()

    return ner_labels, sentiment_scores
60
+
61
def main():
    """Render the Streamlit page: read text, run the analysis, show the results.

    Nothing is analysed until the "Analyze" button is pressed with a
    non-empty text area.
    """
    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    # Input text
    text = st.text_area("Enter Yoruba text", "")

    # Guard clauses: only continue on a button press with some text entered.
    if not st.button("Analyze"):
        return
    if not text:
        return

    ner_labels, sentiment_scores = analyze_text(text)

    # Display Named Entities
    st.subheader("Named Entities")
    for label in ner_labels:
        st.write(f"- {label}")

    # Display Sentiment Analysis
    st.subheader("Sentiment Analysis")

    # analyze_text returns one row of probabilities per input sequence, so
    # unwrap the single row before formatting individual scores.
    scores = sentiment_scores
    if scores and isinstance(scores[0], (list, tuple)):
        scores = scores[0]

    # Name each score from the config's label table instead of hard-coded
    # indices: the number of classes is defined by the config, so a fixed
    # Positive/Negative/Neutral layout can index past the end.
    id_to_label = sentiment_config.get("id2label") or {}
    for class_idx, score in enumerate(scores):
        name = id_to_label.get(str(class_idx), f"class {class_idx}")
        st.write(f"{name.capitalize()}: {score:.2f}")


if __name__ == "__main__":
    main()
my_model/config.json CHANGED
@@ -1 +1,7 @@
1
- {"model_type": "CNNForYorubaNER", "num_classes": 9, "max_length": 128, "pretrained_model_name": "masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0"}
 
 
 
 
 
 
 
1
+ {"model_type": "CNNForYorubaNER",
2
+ "num_classes": 9,
3
+ "max_length": 128,
4
+ "pretrained_model_name": "masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0",
5
+ "id2labels": {"0": "B-DATE", "1": "B-LOC", "2": "B-ORG", "3": "B-PER", "4": "I-DATE", "5": "I-LOC", "6": "I-ORG", "7": "I-PER", "8": "O"},
6
+ "label2id": {"B-DATE": 0, "B-LOC": 1, "B-ORG": 2, "B-PER": 3, "I-DATE": 4, "I-LOC": 5, "I-ORG": 6, "I-PER": 7, "O": 8}
7
+ }
sent_model/config.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "CNNForSentimentAnalysis", "num_classes": 2, "max_length": 128, "pretrained_model_name": "Davlan/naija-twitter-sentiment-afriberta-large"}
 
1
+ {"model_type": "CNNForSentimentAnalysis", "num_classes": 2, "max_length": 128, "pretrained_model_name": "Davlan/naija-twitter-sentiment-afriberta-large", "id2label": {"0": "negative", "1": "positive"}, "label2id": {"negative": 0, "positive": 1}}