|
import gradio as gr |
|
import mysql.connector |
|
import os |
|
|
|
|
|
from transformers import pipeline |
|
|
|
# Zero-shot NLI classifier used to suggest message categories for incoming mail.
classifier_model = pipeline(
    "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
)

# Database credentials are read from the environment; unset variables are
# passed through as None.
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASS")
db_name = os.environ.get("DB_NAME")

# Long-lived module-level connection. With autocommit off (the connector's
# default) it never ends its first transaction, so every later SELECT would
# keep reading the consistent snapshot taken by the first query and never see
# categories added after start-up. autocommit=True makes each statement its
# own transaction.
db_connection = mysql.connector.connect(
    host=db_host,
    user=db_user,
    password=db_pass,
    database=db_name,
    autocommit=True,
)
db_cursor = db_connection.cursor()

# Organisation id written into the org_id column of every message row inserted.
ORG_ID = 731

# Cached category names; filled from the database when the module loads.
potential_labels = []
|
|
|
|
|
def get_potential_labels():
    """Fetch the current message-category names from the database.

    A fresh connection is opened and closed on every call instead of reusing
    a long-lived module-level cursor. This avoids three problems:
    - a never-committed connection keeps reading the same InnoDB snapshot and
      misses newly added categories;
    - an idle connection can be dropped by the server;
    - concurrent UI handlers would otherwise share one cursor.

    The module-level ``potential_labels`` cache is updated as a side effect.

    Returns:
        list[str]: the ``message_category_name`` column of
        ``radmap_frog12.message_categorys``. No ORDER BY is used, so the
        order is whatever the server returns.
    """
    global potential_labels

    connection = mysql.connector.connect(
        host=db_host,
        user=db_user,
        password=db_pass,
        database=db_name,
    )
    try:
        cursor = connection.cursor()
        cursor.execute(
            "SELECT message_category_name FROM radmap_frog12.message_categorys"
        )
        potential_labels = [row[0] for row in cursor.fetchall()]
    finally:
        connection.close()

    return potential_labels


# Initial snapshot, used to render the category library when the UI is built.
potential_labels = get_potential_labels()
|
|
|
|
|
|
|
def classify_email(constituent_email, threshold=0.95):
    """Suggest message categories for an incoming message.

    Runs the zero-shot classifier in multi-label mode against the category
    names currently stored in the database.

    Args:
        constituent_email: Text of the incoming message.
        threshold: A category is suggested only if its score is strictly
            greater than this. Defaults to 0.95, the previously hard-coded
            value.

    Returns:
        str: Every category scoring above ``threshold``, joined with ", ".
        If none clears it, the single highest-scoring category. Returns ""
        when the message is blank or no categories are defined, since there
        is nothing to classify against.
    """
    if not constituent_email or not constituent_email.strip():
        return ""

    labels = get_potential_labels()
    if not labels:
        # The zero-shot pipeline raises on an empty candidate-label list.
        return ""

    print("classifying email")
    model_out = classifier_model(constituent_email, labels, multi_label=True)
    print("classification complete")

    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > threshold]
    if top_labels:
        return ", ".join(top_labels)

    # Nothing cleared the threshold: fall back to the best single label
    # (max() keeps the first one on ties, as the original .index() lookup did).
    best_label, _ = max(scored, key=lambda pair: pair[1])
    return best_label
|
|
|
|
|
def remove_spaces_after_comma(s):
    """Trim surrounding whitespace from every item of a comma-separated string.

    The string is split on ",", each item has leading and trailing whitespace
    stripped (on both sides, not only after the comma), and the items are
    re-joined with a bare ",". Empty items are preserved, so "a,,b" stays
    "a,,b".
    """
    return ",".join(map(str.strip, s.split(",")))
|
|
|
|
|
|
|
def save_data(orig_user_email, constituent_email, labels, user_response, current_user): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
db_connection = mysql.connector.connect( |
|
host=db_host, |
|
user=db_user, |
|
password=db_pass, |
|
database=db_name, |
|
) |
|
|
|
db_cursor = db_connection.cursor() |
|
|
|
if current_user == "Sheryl Springer": |
|
person_id = 11021 |
|
elif current_user == "Diane Taylor": |
|
person_id = 11023 |
|
elif current_user == "Ann E. Belyea": |
|
person_id = 11025 |
|
elif current_user == "Marcelo Mejia": |
|
person_id = 11027 |
|
elif current_user == "Rishi Vasudeva": |
|
person_id = 11029 |
|
|
|
try: |
|
message_id = 0 |
|
if orig_user_email != "": |
|
db_cursor.execute( |
|
"INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, communication_method_id, status_id, subject, body, send_date, message_type, previous_message_id) VALUES (345678, %s, %s, 1, 1, 'Email Classification and Response Tracking', %s, NOW(), 'Email Classification and Response Tracking', %s)", |
|
(ORG_ID, person_id, orig_user_email, message_id), |
|
) |
|
|
|
|
|
message_id = db_cursor.lastrowid |
|
|
|
db_cursor.execute( |
|
"INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, communication_method_id, status_id, subject, body, send_date, message_type, previous_message_id) VALUES (345678, %s, 0, 1, 1, 'Email Classification and Response Tracking', %s, NOW(), 'Email Classification and Response Tracking', %s)", |
|
(ORG_ID, constituent_email, message_id), |
|
) |
|
|
|
message_id = db_cursor.lastrowid |
|
|
|
db_cursor.execute( |
|
"INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, communication_method_id, status_id, subject, body, send_date, message_type, previous_message_id) VALUES (345678, %s, %s, 1, 1, 'Email Classification and Response Tracking', %s, NOW(), 'Email Classification and Response Tracking', %s)", |
|
(ORG_ID, person_id, user_response, message_id), |
|
) |
|
|
|
|
|
|
|
|
|
labels = remove_spaces_after_comma(labels) |
|
labels = labels.split(",") |
|
for label in labels: |
|
label_exists = db_cursor.execute( |
|
"SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s", |
|
(label,), |
|
) |
|
label_exists = db_cursor.fetchall() |
|
if label_exists: |
|
message_id = db_cursor.execute( |
|
"SELECT id FROM radmap_frog12.messages WHERE body = %s", |
|
(constituent_email,), |
|
) |
|
message_id = db_cursor.fetchall() |
|
|
|
db_cursor.execute( |
|
"INSERT INTO radmap_frog12.message_category_associations (message_id, message_category_id) VALUES (%s, %s)", |
|
(message_id[0][0], label_exists[0][0]), |
|
) |
|
|
|
db_connection.commit() |
|
|
|
return "Response successfully saved to database" |
|
|
|
except Exception as e: |
|
print(e) |
|
db_connection.rollback() |
|
return "Error saving data to database" |
|
|
|
|
|
|
|
auth_username = os.environ.get("AUTH_USERNAME")
auth_password = os.environ.get("AUTH_PASSWORD")

# Single username/password pair accepted by the Gradio login page.
auth = [(auth_username, auth_password)]


def _category_library_markdown(labels):
    """Render category names as the Markdown shown in the left column."""
    return "## Message Category Library\n ### " + ", ".join(labels)


def _refresh_category_library():
    """Re-read categories from the database and re-render the library Markdown."""
    return _category_library_markdown(get_potential_labels())


with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        gr.Markdown("## Campaign Messaging Assistant")

    with gr.Row():
        with gr.Column():
            current_user = gr.Dropdown(
                label="Current User",
                choices=[
                    "Sheryl Springer",
                    "Ann E. Belyea",
                    "Marcelo Mejia",
                    "Rishi Vasudeva",
                    "Diane Taylor",
                ],
            )

            # Initial render uses the start-up snapshot; refreshed on page load below.
            email_labels_input = gr.Markdown(
                _category_library_markdown(potential_labels),
            )

            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email (if any)",
            )

            # Hidden component used only as vertical spacing.
            spacer1 = gr.Label(visible=False)

            constituent_response_input = gr.TextArea(
                placeholder="Enter the incoming message",
                label="Incoming Message (may be a response to original email)",
                lines=15,
            )

            classify_button = gr.Button("Process Message", variant="primary")

        with gr.Column():
            classification_output = gr.TextArea(
                label="Suggested Message Categories (modify as needed). Separate categories with commas",
                lines=1,
                interactive=True,
            )

            # Hidden component used only as vertical spacing.
            spacer2 = gr.Label(visible=False)

            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Suggested Response (modify as needed)",
                lines=25,
            )

            save_button = gr.Button("Save Response", variant="primary")
            save_output = gr.Label(label="Backend Response")

    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )

    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
            current_user,
        ],
        outputs=save_output,
    )

    # Without this, the library shown on the page stays frozen at the list
    # read when the UI was built, even though classify_email uses fresh labels.
    app.load(fn=_refresh_category_library, inputs=None, outputs=email_labels_input)


if __name__ == "__main__":
    # Fail fast rather than launching behind a login built from None values.
    if not auth_username or not auth_password:
        raise RuntimeError("AUTH_USERNAME and AUTH_PASSWORD must both be set to launch the app")
    app.launch(auth=auth, debug=True)
|
|