|
import streamlit as st |
|
import json |
|
import torch |
|
from transformers import AutoTokenizer |
|
from modelling_cnn import CNNForNER, SentimentCNNModel |
|
|
|
|
|
|
|
ner_model_name = "./my_model/pytorch_model.bin"

model_ner = "Testys/cnn_yor_ner"


@st.cache_resource
def _load_ner_resources():
    """Load the NER tokenizer, config and CNN model once per server process.

    Streamlit re-runs this whole script on every widget interaction. Without
    caching, the tokenizer and checkpoint would be reloaded on every click of
    "Analyze". ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        tuple: ``(tokenizer, config, model)``. ``config`` is the parsed
        ``./my_model/config.json`` dict. ``model`` is a ``CNNForNER`` whose
        weights were loaded with ``map_location="cpu"`` and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ner)

    with open("./my_model/config.json", "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    model = CNNForNER(
        pretrained_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    # NOTE: torch.load unpickles the checkpoint; only point it at trusted files.
    model.load_state_dict(
        torch.load(ner_model_name, map_location=torch.device("cpu"))
    )
    model.eval()
    return tokenizer, config, model


ner_tokenizer, ner_config, ner_model = _load_ner_resources()
|
|
|
|
|
sentiment_model_name = "./sent_model/sent_pytorch_model.bin"

model_sent = "Testys/cnn_sent_yor"


@st.cache_resource
def _load_sentiment_resources():
    """Load the sentiment tokenizer, config and CNN model once per server process.

    Streamlit re-runs this whole script on every widget interaction. Without
    caching, the tokenizer and checkpoint would be reloaded on every click of
    "Analyze". ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        tuple: ``(tokenizer, config, model)``. ``config`` is the parsed
        ``./sent_model/config.json`` dict. ``model`` is a
        ``SentimentCNNModel`` whose weights were loaded with
        ``map_location="cpu"`` and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_sent)

    with open("./sent_model/config.json", "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    model = SentimentCNNModel(
        transformer_model_name=config["pretrained_model_name"],
        num_classes=config["num_classes"],
    )
    # NOTE: torch.load unpickles the checkpoint; only point it at trusted files.
    model.load_state_dict(
        torch.load(sentiment_model_name, map_location=torch.device("cpu"))
    )
    model.eval()
    return tokenizer, config, model


sentiment_tokenizer, sentiment_config, sentiment_model = _load_sentiment_resources()
|
|
|
|
|
def _ner_label_name(label_id):
    """Map a predicted NER class index to a display name (numeric id if unknown)."""
    # Assumes config.json *may* carry an HF-style ``id2label`` mapping
    # (dict with str or int keys, or a list) — TODO confirm against the config.
    id2label = ner_config.get("id2label")
    if isinstance(id2label, dict):
        return str(id2label.get(str(label_id), id2label.get(label_id, label_id)))
    if isinstance(id2label, list) and 0 <= label_id < len(id2label):
        return str(id2label[label_id])
    return str(label_id)


def _predict_entities(text):
    """Return one ``"token: label"`` string per non-special input token of *text*."""
    inputs = ner_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = ner_model(**inputs)

    # argmax over the last axis gives one predicted *class index* per token
    # position. These are label ids, not vocabulary ids, so they must not be
    # passed to tokenizer.decode.
    predicted_ids = torch.argmax(outputs.logits, dim=-1)[0].tolist()
    input_ids = inputs["input_ids"][0].tolist()
    tokens = ner_tokenizer.convert_ids_to_tokens(input_ids)
    special_ids = set(ner_tokenizer.all_special_ids)

    entities = []
    # zip stops at the shorter sequence, in case the model's output length
    # differs from the input length (not verifiable from here).
    for token_id, token, label_id in zip(input_ids, tokens, predicted_ids):
        if token_id in special_ids:
            continue  # skip [CLS]/[SEP]/padding-style tokens
        entities.append(f"{token}: {_ner_label_name(label_id)}")
    return entities


def _predict_sentiment(text):
    """Return the per-class sentiment probabilities for *text* as a flat list."""
    inputs = sentiment_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
    probabilities = torch.softmax(outputs.logits, dim=1)
    # Drop the batch axis: a single input yields shape (1, num_classes), and
    # callers index the result directly as scores[class_index].
    return probabilities[0].tolist()


def analyze_text(text):
    """Run the NER and sentiment models on *text*.

    Args:
        text: Input string to analyze.

    Returns:
        tuple: ``(ner_labels, sentiment_scores)``.
        ``ner_labels`` is a list of ``"token: label"`` strings, one per
        non-special token.
        ``sentiment_scores`` is a flat list of softmax probabilities, one per
        sentiment class.
    """
    return _predict_entities(text), _predict_sentiment(text)
|
|
|
def main():
    """Draw the app: a Yoruba text box, an Analyze button and the model results.

    Nothing below the button is rendered until the button has been clicked
    with non-empty text.
    """
    st.title("YorubaCNN Models for NER and Sentiment Analysis")

    user_text = st.text_area("Enter Yoruba text", "")
    clicked = st.button("Analyze")
    if not (clicked and user_text):
        return

    entity_rows, class_scores = analyze_text(user_text)

    st.subheader("Named Entities")
    for row in entity_rows:
        st.write(f"- {row}")

    st.subheader("Sentiment Analysis")
    # Class order assumed to be 0=negative, 1=neutral, 2=positive;
    # confirm against the sentiment model's training labels.
    st.write(f"Positive: {class_scores[2]:.2f}")
    st.write(f"Negative: {class_scores[0]:.2f}")
    st.write(f"Neutral: {class_scores[1]:.2f}")


# Render the UI when this file is executed as the main script.
if __name__ == "__main__":
    main()