Commiting for the Changes made for config.json and main.py
Browse files- main.py +83 -1
- my_model/config.json +7 -1
- sent_model/config.json +1 -1
main.py
CHANGED
@@ -1 +1,83 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
from modelling_cnn import CNNForNER, SentimentCNNModel
|
6 |
+
|
7 |
+
|
8 |
+
# Load the Yoruba NER model
|
9 |
+
ner_model_name = "./my_model/pytorch_model.bin"
|
10 |
+
model_ner = "Testys/cnn_yor_ner"
|
11 |
+
ner_tokenizer = AutoTokenizer.from_pretrained(model_ner)
|
12 |
+
with open("./my_model/config.json", "r") as f:
|
13 |
+
ner_config = json.load(f)
|
14 |
+
|
15 |
+
ner_model = CNNForNER(
|
16 |
+
pretrained_model_name=ner_config["pretrained_model_name"],
|
17 |
+
num_classes=ner_config["num_classes"]
|
18 |
+
)
|
19 |
+
ner_model.load_state_dict(torch.load(ner_model_name, map_location=torch.device('cpu')))
|
20 |
+
ner_model.eval()
|
21 |
+
|
22 |
+
# Load the Yoruba sentiment analysis model
|
23 |
+
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"
|
24 |
+
model_sent = "Testys/cnn_sent_yor"
|
25 |
+
sentiment_tokenizer = AutoTokenizer.from_pretrained(model_sent)
|
26 |
+
|
27 |
+
with open("./sent_model/config.json", "r") as f:
|
28 |
+
sentiment_config = json.load(f)
|
29 |
+
|
30 |
+
sentiment_model = SentimentCNNModel(
|
31 |
+
transformer_model_name=sentiment_config["pretrained_model_name"],
|
32 |
+
num_classes=sentiment_config["num_classes"]
|
33 |
+
)
|
34 |
+
|
35 |
+
sentiment_model.load_state_dict(torch.load(sentiment_model_name, map_location=torch.device('cpu')))
|
36 |
+
sentiment_model.eval()
|
37 |
+
|
38 |
+
|
39 |
+
def analyze_text(text):
|
40 |
+
# Tokenize input text for NER
|
41 |
+
ner_inputs = ner_tokenizer(text, return_tensors="pt")
|
42 |
+
|
43 |
+
# Perform Named Entity Recognition
|
44 |
+
with torch.no_grad():
|
45 |
+
ner_outputs = ner_model(**ner_inputs)
|
46 |
+
|
47 |
+
ner_predictions = torch.argmax(ner_outputs.logits, dim=-1)
|
48 |
+
ner_labels = [ner_tokenizer.decode(token) for token in ner_predictions[0]]
|
49 |
+
|
50 |
+
# Tokenize input text for sentiment analysis
|
51 |
+
sentiment_inputs = sentiment_tokenizer.encode_plus(text, return_tensors="pt")
|
52 |
+
|
53 |
+
# Perform sentiment analysis
|
54 |
+
with torch.no_grad():
|
55 |
+
sentiment_outputs = sentiment_model(**sentiment_inputs)
|
56 |
+
sentiment_probabilities = torch.softmax(sentiment_outputs.logits, dim=1)
|
57 |
+
sentiment_scores = sentiment_probabilities.tolist()
|
58 |
+
|
59 |
+
return ner_labels, sentiment_scores
|
60 |
+
|
61 |
+
def main():
|
62 |
+
st.title("YorubaCNN Models for NER and Sentiment Analysis")
|
63 |
+
|
64 |
+
# Input text
|
65 |
+
text = st.text_area("Enter Yoruba text", "")
|
66 |
+
|
67 |
+
if st.button("Analyze"):
|
68 |
+
if text:
|
69 |
+
ner_labels, sentiment_scores = analyze_text(text)
|
70 |
+
|
71 |
+
# Display Named Entities
|
72 |
+
st.subheader("Named Entities")
|
73 |
+
for label in ner_labels:
|
74 |
+
st.write(f"- {label}")
|
75 |
+
|
76 |
+
# Display Sentiment Analysis
|
77 |
+
st.subheader("Sentiment Analysis")
|
78 |
+
st.write(f"Positive: {sentiment_scores[2]:.2f}")
|
79 |
+
st.write(f"Negative: {sentiment_scores[0]:.2f}")
|
80 |
+
st.write(f"Neutral: {sentiment_scores[1]:.2f}")
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
main()
|
my_model/config.json
CHANGED
@@ -1 +1,7 @@
|
|
1 |
-
{"model_type": "CNNForYorubaNER",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"model_type": "CNNForYorubaNER",
|
2 |
+
"num_classes": 9,
|
3 |
+
"max_length": 128,
|
4 |
+
"pretrained_model_name": "masakhane/afroxlmr-large-ner-masakhaner-1.0_2.0",
|
5 |
+
"id2labels": {"0": "B-DATE", "1": "B-LOC", "2": "B-ORG", "3": "B-PER", "4": "I-DATE", "5": "I-LOC", "6": "I-ORG", "7": "I-PER", "8": "O"},
|
6 |
+
"label2id": {"B-DATE": 0, "B-LOC": 1, "B-ORG": 2, "B-PER": 3, "I-DATE": 4, "I-LOC": 5, "I-ORG": 6, "I-PER": 7, "O": 8}
|
7 |
+
}
|
sent_model/config.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"model_type": "CNNForSentimentAnalysis", "num_classes": 2, "max_length": 128, "pretrained_model_name": "Davlan/naija-twitter-sentiment-afriberta-large"}
|
|
|
1 |
+
{"model_type": "CNNForSentimentAnalysis", "num_classes": 2, "max_length": 128, "pretrained_model_name": "Davlan/naija-twitter-sentiment-afriberta-large", "id2label": {"0": "negative", "1": "positive"}, "label2id": {"negative": 0, "positive": 1}}
|