# File-view header: committed by AMead10, last commit ce294e4 ("update col name"),
# file size 8.15 kB.
import gradio as gr
import mysql.connector
import os
# Use a pipeline as a high-level helper
from transformers import pipeline
# Zero-shot classifier used to score every message category against a message.
classifier_model = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1",
)

# Database credentials come from the environment; an unset variable yields None.
db_host = os.getenv("DB_HOST")
db_user = os.getenv("DB_USER")
db_pass = os.getenv("DB_PASS")
db_name = os.getenv("DB_NAME")

# Shared connection and cursor, opened once when the module is imported.
db_connection = mysql.connector.connect(
    database=db_name,
    host=db_host,
    user=db_user,
    password=db_pass,
)
db_cursor = db_connection.cursor()

# Organisation id written to the org_id column of every saved message.
ORG_ID = 731
# Category names shown in the UI; filled in by get_potential_labels().
potential_labels = []
def get_potential_labels():
    """Fetch every message category name from the database.

    The result is also stored in the module-level ``potential_labels`` list.

    Returns:
        list[str]: The ``message_category_name`` values of
        ``radmap_frog12.message_categorys``, in the order the database
        returns them.
    """
    global potential_labels
    # Use a short-lived connection (as save_data does) rather than the shared
    # module-level cursor. That connection is opened without autocommit, so
    # under MySQL's default REPEATABLE READ isolation its first SELECT pins a
    # snapshot and later calls would never see newly added categories. A
    # long-idle shared connection may also be dropped by the server.
    connection = mysql.connector.connect(
        host=db_host,
        user=db_user,
        password=db_pass,
        database=db_name,
    )
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(
                "SELECT message_category_name FROM radmap_frog12.message_categorys"
            )
            rows = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        connection.close()
    potential_labels = [row[0] for row in rows]
    return potential_labels


# Load the category library once at import time; the UI markdown built below
# renders this list.
potential_labels = get_potential_labels()
# Function to handle the classification
def classify_email(constituent_email, threshold=0.95):
    """Suggest message categories for an incoming message.

    The category library is re-read from the database on every call, and the
    zero-shot model scores each category independently (``multi_label=True``).

    Args:
        constituent_email: Text of the incoming message.
        threshold: A category is suggested when its score is strictly greater
            than this value. Defaults to 0.95, the previously hard-coded value.

    Returns:
        str: The names of all categories scoring above ``threshold``, joined
        with ", " in the order the model returns them. If none qualify, the
        single highest-scoring category. An empty string when the message is
        empty/whitespace or there are no categories to choose from.
    """
    # The zero-shot pipeline raises ValueError for an empty sequence or an
    # empty label list; return "no suggestion" instead of an error in the UI.
    if not (constituent_email or "").strip():
        return ""
    potential_labels = get_potential_labels()
    if not potential_labels:
        return ""
    print("classifying email")
    model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
    print("classification complete")
    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > threshold]
    if not top_labels:
        # Nothing cleared the threshold: fall back to the single best match
        # (max() keeps the first of any tied scores).
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)
def remove_spaces_after_comma(s):
    """Trim whitespace around every item of a comma-separated string.

    Despite the name, whitespace on *both* sides of each item is removed, so
    the whole string's leading/trailing whitespace disappears too. Empty items
    are kept, e.g. ``"a, ,b"`` -> ``"a,,b"``.
    """
    return ",".join(map(str.strip, s.split(",")))
# Function to handle saving data

# Maps each name offered in the "Current User" dropdown to its person_id.
_PERSON_IDS_BY_USER = {
    "Sheryl Springer": 11021,
    "Diane Taylor": 11023,
    "Ann E. Belyea": 11025,
    "Marcelo Mejia": 11027,
    "Rishi Vasudeva": 11029,
}


def _insert_message(cursor, person_id, body, previous_message_id):
    """Insert one tracking row into radmap_frog12.messages and return its id.

    Every row uses app_id 345678, org_id ORG_ID, communication_method_id 1,
    status_id 1, send_date NOW(), and 'Email Classification and Response
    Tracking' as both subject and message_type.
    """
    cursor.execute(
        "INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, communication_method_id, status_id, subject, body, send_date, message_type, previous_message_id) VALUES (345678, %s, %s, 1, 1, 'Email Classification and Response Tracking', %s, NOW(), 'Email Classification and Response Tracking', %s)",
        (ORG_ID, person_id, body, previous_message_id),
    )
    return cursor.lastrowid


def save_data(orig_user_email, constituent_email, labels, user_response, current_user):
    """Save one message exchange and its categories to radmap_frog12.

    Up to three linked rows are inserted into ``messages``:
    1. the user's original email, only if it is non-empty;
    2. the incoming constituent message (person_id 0);
    3. the user's response.
    Each row's ``previous_message_id`` points at the row before it (0 for the
    first row). Each comma-separated category name that exists in
    ``message_categorys`` is then linked to the constituent message in
    ``message_category_associations``. All inserts share one transaction that
    is rolled back on any error.

    Args:
        orig_user_email: Email the user originally sent ("" if none).
        constituent_email: The incoming message being answered.
        labels: Comma-separated category names, e.g. "Housing, Taxes".
        user_response: The user's reply to the constituent.
        current_user: Name selected in the "Current User" dropdown.

    Returns:
        str: A status message for the UI.
    """
    person_id = _PERSON_IDS_BY_USER.get(current_user)
    if person_id is None:
        # Previously an unknown user left person_id unbound and surfaced only
        # as a generic database error.
        return "Please select the current user before saving"

    try:
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
    except mysql.connector.Error as e:
        print(e)
        return "Error saving data to database"

    try:
        cursor = connection.cursor()
        previous_message_id = 0
        if orig_user_email:
            previous_message_id = _insert_message(
                cursor, person_id, orig_user_email, previous_message_id
            )
        constituent_message_id = _insert_message(
            cursor, 0, constituent_email, previous_message_id
        )
        _insert_message(cursor, person_id, user_response, constituent_message_id)

        # Split the user-edited category string into distinct, non-empty
        # names (order kept). Duplicates would otherwise insert the same
        # association twice.
        label_names = [
            name for name in remove_spaces_after_comma(labels or "").split(",") if name
        ]
        for label in dict.fromkeys(label_names):
            cursor.execute(
                "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
                (label,),
            )
            matches = cursor.fetchall()
            if matches:
                # Link to the constituent row inserted above. Looking the
                # message up by body (as before) could match an older message
                # with identical text.
                # NOTE(review): assumes the first column of message_categorys
                # is its id — confirm against the schema.
                cursor.execute(
                    "INSERT INTO radmap_frog12.message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                    (constituent_message_id, matches[0][0]),
                )
        connection.commit()
        return "Response successfully saved to database"
    except Exception as e:
        print(e)
        connection.rollback()
        return "Error saving data to database"
    finally:
        # Closing per call avoids leaking one connection per save.
        connection.close()
# Login credentials for the Gradio UI, read from environment variables.
auth_username = os.environ.get("AUTH_USERNAME")
auth_password = os.environ.get("AUTH_PASSWORD")
# Fail fast when a credential is missing. Otherwise Gradio would receive a
# (None, None) pair that, presumably, no login form submission can match,
# leaving the app running but impossible to sign in to.
if not auth_username or not auth_password:
    raise RuntimeError(
        "AUTH_USERNAME and AUTH_PASSWORD environment variables must be set"
    )
# (username, password) pairs accepted by app.launch(auth=...).
auth = [(auth_username, auth_password)]
# Build the Gradio interface: a title row, then a row with two columns.
# Component creation order below determines the rendered layout.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        gr.Markdown("## Campaign Messaging Assistant")
    with gr.Row():
        # Left column: who is working, the category library, and the messages.
        with gr.Column():
            # The selected name is passed to save_data as current_user.
            current_user = gr.Dropdown(
                label="Current User",
                choices=[
                    "Sheryl Springer",
                    "Ann E. Belyea",
                    "Marcelo Mejia",
                    "Rishi Vasudeva",
                    "Diane Taylor",
                ],
            )
            # Rendered once from the potential_labels list as it stands when
            # the UI is built; the text does not refresh when categories change.
            email_labels_input = gr.Markdown(
                "## Message Category Library\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email (if any)",
            )
            # Hidden Label that is never read or written; presumably a
            # leftover layout spacer.
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                placeholder="Enter the incoming message",
                label="Incoming Message (may be a response to original email)",
                lines=15,
            )
            classify_button = gr.Button("Process Message", variant="primary")
        # Right column: suggested categories, the response, and saving.
        with gr.Column():
            # Filled by classify_email; editable, and its text is passed to
            # save_data as the labels argument.
            classification_output = gr.TextArea(
                label="Suggested Message Categories (modify as needed). Separate categories with commas",
                lines=1,
                interactive=True,
            )
            # Hidden Label that is never read or written; presumably a
            # leftover layout spacer.
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Suggested Response (modify as needed)",
                lines=25,
            )
            save_button = gr.Button("Save Response", variant="primary")
            # Shows the status string returned by save_data.
            save_output = gr.Label(label="Backend Response")
    # Define button actions (must be attached inside the Blocks context).
    # "Process Message" classifies the incoming message into suggested categories.
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    # "Save Response" stores the exchange; inputs are listed in save_data's
    # parameter order.
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
            current_user,
        ],
        outputs=save_output,
    )
# Launch the app behind the username/password login held in `auth`.
app.launch(auth=auth, debug=True)