File size: 1,722 Bytes
f3d8f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import time
import streamlit as st
from transformers import pipeline

# Model checkpoint to load and the tokenizer paired with it
MODEL_TO_LOAD = "swastik-kapture/offenseval-xlmr"
TOKENIZER = "xlm-roberta-base"


@st.cache_resource
def _load_classifier():
    """Build the text-classification pipeline and cache it.

    Streamlit re-executes the script on every user interaction; without
    caching, the pipeline would be constructed again on each rerun.
    ``st.cache_resource`` keeps a single instance and hands it back on
    subsequent calls.

    Returns:
        The ``transformers`` pipeline built from ``MODEL_TO_LOAD`` and
        ``TOKENIZER``.
    """
    return pipeline("text-classification", model=MODEL_TO_LOAD, tokenizer=TOKENIZER)


# create classification pipeline (cached, so reruns reuse the same object)
trained_model = _load_classifier()

# Streamlit App
def main():
    """Render the chat UI and classify each submitted message.

    Keeps the conversation in ``st.session_state.conversation_history`` as
    ``(sender, message, timestamp)`` tuples. When the user submits a
    message it is stored, passed to ``trained_model``, and the predicted
    label and score are stored as a colored HTML snippet from the
    assistant. The whole history is then drawn.
    """
    # create a session state for conversation history
    if 'conversation_history' not in st.session_state:
        st.session_state.conversation_history = []
    history = st.session_state.conversation_history

    # streamlit title
    st.title("OffensEval: Profanity Detection")

    # raw model label -> (background color, human-readable label)
    label_styles = {
        "LABEL_0": ("green", "No Offense"),
        "LABEL_1": ("red", "Offensive"),
    }

    # user message
    user_message = st.chat_input("Say something")
    # if user input is present try to predict the outcome
    if user_message:
        # append user message to history
        history.append(('user', user_message, time.time()))
        # get predicted output; truncate so that a very long message is cut
        # to the model's maximum length instead of failing inside the model
        output = trained_model(user_message, truncation=True)
        # get predicted label and score
        raw_label = output[0]['label']
        score = output[0]['score']
        # unknown labels fall back to a white box showing the raw label
        color, label = label_styles.get(raw_label, ("white", raw_label))
        history.append(('assistant', f"<div style='background-color: {color}; width: auto; height: 50px;'>Label: {label}; Score: {score:.2f}</div>", time.time()))

    # Display chat history
    for sender, message, _timestamp in history:
        with st.chat_message(sender):
            # Only the assistant's own markup is rendered as raw HTML; user
            # text is untrusted and must not be injected into the page.
            st.write(message, unsafe_allow_html=(sender == 'assistant'))

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()