"""Gradio app that suggests categories for incoming messages and logs replies.

Incoming messages are scored against the category names stored in
``radmap_frog12.message_categorys`` with a zero-shot classifier. The user's
original email (if any), the incoming message and the user's reply are then
saved to ``radmap_frog12.messages``, and the incoming message is linked to
the chosen categories.
"""

import logging
import os
from contextlib import closing

import gradio as gr
import mysql.connector

# Use a pipeline as a high-level helper
from transformers import pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

classifier_model = pipeline(
    "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
)

# get db info from env vars
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASS")
db_name = os.environ.get("DB_NAME")

ORG_ID = 731
APP_ID = 345678
# Written as both the ``subject`` and the ``message_type`` of every saved message.
MESSAGE_SUBJECT = "Email Classification and Response Tracking"
# Every label scoring above this is suggested; if none does, the single
# best-scoring label is suggested instead.
CLASSIFICATION_THRESHOLD = 0.95

# Dropdown display name -> person_id stored on the user's messages.
# Insertion order is the order shown in the "Current User" dropdown.
USER_PERSON_IDS = {
    "Sheryl Springer": 11021,
    "Ann E. Belyea": 11025,
    "Marcelo Mejia": 11027,
    "Rishi Vasudeva": 11029,
    "Diane Taylor": 11023,
}

_INSERT_MESSAGE_SQL = (
    "INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, "
    "communication_method_id, status_id, subject, body, send_date, "
    "message_type, previous_message_id) "
    "VALUES (%s, %s, %s, 1, 1, %s, %s, NOW(), %s, %s)"
)

potential_labels = []


def _connect():
    """Open a new MySQL connection from the DB_* environment variables."""
    return mysql.connector.connect(
        host=db_host,
        user=db_user,
        password=db_pass,
        database=db_name,
    )


def get_potential_labels():
    """Return every category name in ``message_categorys``.

    Also refreshes the module-level ``potential_labels`` list. A fresh
    connection is opened and closed on each call instead of reusing one
    long-lived global connection, so concurrent Gradio requests never share
    a cursor and an idle-closed connection cannot break later calls.
    """
    global potential_labels
    with closing(_connect()) as connection, closing(connection.cursor()) as cursor:
        cursor.execute(
            "SELECT message_category_name FROM radmap_frog12.message_categorys"
        )
        potential_labels = [row[0] for row in cursor.fetchall()]
    return potential_labels


potential_labels = get_potential_labels()


# Function to handle the classification
def classify_email(constituent_email):
    """Suggest categories for ``constituent_email``.

    Returns a ", "-joined string of every category scoring above
    ``CLASSIFICATION_THRESHOLD``. If none does, returns the single
    best-scoring category. Returns "" when the message is blank or no
    categories exist, because the pipeline rejects empty inputs.
    """
    labels = get_potential_labels()
    if not labels or not constituent_email or not constituent_email.strip():
        return ""

    logger.info("classifying email")
    model_out = classifier_model(constituent_email, labels, multi_label=True)
    logger.info("classification complete")

    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > CLASSIFICATION_THRESHOLD]
    if not top_labels:
        # max() keeps the first label among equal scores.
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)


def remove_spaces_after_comma(s):
    """Strip surrounding whitespace from every comma-separated part of ``s``.

    >>> remove_spaces_after_comma(" a , b,c ")
    'a,b,c'
    """
    return ",".join(part.strip() for part in s.split(","))


def _parse_labels(labels):
    """Split a comma-separated label string into unique, non-empty names (order kept)."""
    names = remove_spaces_after_comma(labels or "").split(",")
    return [name for name in dict.fromkeys(names) if name]


def _insert_message(cursor, person_id, body, previous_message_id):
    """Insert one row into ``messages`` and return its auto-generated id."""
    cursor.execute(
        _INSERT_MESSAGE_SQL,
        (
            APP_ID,
            ORG_ID,
            person_id,
            MESSAGE_SUBJECT,
            body,
            MESSAGE_SUBJECT,
            previous_message_id,
        ),
    )
    return cursor.lastrowid


def _save_exchange(
    cursor, orig_user_email, constituent_email, labels, user_response, person_id
):
    """Insert the message rows and category links for one exchange (no commit)."""
    if orig_user_email:
        # NOTE(review): this row's id is not passed on as the incoming
        # message's previous_message_id; confirm whether the thread should
        # be linked from the original email.
        _insert_message(cursor, person_id, orig_user_email, 0)

    # The incoming message is stored with person_id 0.
    constituent_id = _insert_message(cursor, 0, constituent_email, 0)
    _insert_message(cursor, person_id, user_response, constituent_id)

    # Link the incoming message (by its own id, not by a body lookup that
    # could match an older duplicate) to each label that exists in the DB.
    for label in _parse_labels(labels):
        cursor.execute(
            "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
            (label,),
        )
        matches = cursor.fetchall()
        if matches:
            # NOTE(review): assumes the first column of message_categorys is
            # its id — confirm against the schema.
            cursor.execute(
                "INSERT INTO radmap_frog12.message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                (constituent_id, matches[0][0]),
            )


# Function to handle saving data
def save_data(orig_user_email, constituent_email, labels, user_response, current_user):
    """Save one message exchange and its categories to the database.

    Rows written to ``radmap_frog12.messages``:
    * the user's original email, only when non-empty (person_id = the user);
    * the incoming constituent message (person_id 0);
    * the user's response (person_id = the user, previous_message_id = the
      incoming message's id).
    The incoming message is then linked to every label in ``labels`` found in
    ``message_categorys``. All writes are committed together or rolled back.

    Args:
        orig_user_email: Email originally sent by the user; "" if none.
        constituent_email: The incoming message.
        labels: Comma-separated category names.
        user_response: The user's reply.
        current_user: Name selected in the "Current User" dropdown.

    Returns:
        A status string displayed in the UI.
    """
    person_id = USER_PERSON_IDS.get(current_user)
    if person_id is None:
        return "Please select the current user before saving"

    try:
        with closing(_connect()) as connection, closing(connection.cursor()) as cursor:
            try:
                _save_exchange(
                    cursor,
                    orig_user_email,
                    constituent_email,
                    labels,
                    user_response,
                    person_id,
                )
                connection.commit()
            except Exception:
                connection.rollback()
                raise
    except Exception:
        # UI boundary: log the full traceback, show a short status to the user.
        logger.exception("Error saving data to database")
        return "Error saving data to database"
    return "Response successfully saved to database"


# read auth from env vars
auth_username = os.environ.get("AUTH_USERNAME")
auth_password = os.environ.get("AUTH_PASSWORD")
# Define your username and password pairs
auth = [(auth_username, auth_password)]

# Build the Gradio interface with two columns
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        gr.Markdown("## Campaign Messaging Assistant")
    with gr.Row():
        with gr.Column():
            current_user = gr.Dropdown(
                label="Current User",
                choices=list(USER_PERSON_IDS),
            )
            email_labels_input = gr.Markdown(
                "## Message Category Library\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email (if any)",
            )
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                placeholder="Enter the incoming message",
                label="Incoming Message (may be a response to original email)",
                lines=15,
            )
            classify_button = gr.Button("Process Message", variant="primary")
        with gr.Column():
            classification_output = gr.TextArea(
                label="Suggested Message Categories (modify as needed). Separate categories with commas",
                lines=1,
                interactive=True,
            )
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Suggested Response (modify as needed)",
                lines=25,
            )
            save_button = gr.Button("Save Response", variant="primary")
            save_output = gr.Label(label="Backend Response")

    # Define button actions
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
            current_user,
        ],
        outputs=save_output,
    )

# Launch the app
if __name__ == "__main__":
    app.launch(auth=auth, debug=True)