"""Gradio app for classifying constituent emails and recording responses.

A constituent's email is labelled with a zero-shot classifier against the
message categories stored in MySQL. The original email, the constituent's
reply, the user's response and the chosen labels can then be saved back to
the database.
"""

import os
from contextlib import closing

import gradio as gr
import mysql.connector

# Use a pipeline as a high-level helper
from transformers import pipeline

classifier_model = pipeline(
    "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
)

# With multi_label=True each label is scored independently; every label
# scoring above this threshold is reported.
CLASSIFICATION_THRESHOLD = 0.95

# get db info from env vars
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASS")
db_name = os.environ.get("DB_NAME")


def _connect():
    """Open a new MySQL connection from the DB_* environment settings."""
    return mysql.connector.connect(
        host=db_host,
        user=db_user,
        password=db_pass,
        database=db_name,
    )


def get_potential_labels():
    """Return every ``message_category_name`` from ``message_categorys``.

    Opens a short-lived connection and closes it again, so no connection is
    held open for the lifetime of the app.
    """
    with closing(_connect()) as connection, closing(connection.cursor()) as cursor:
        cursor.execute(
            "SELECT message_category_name FROM radmap_frog12.message_categorys"
        )
        return [row[0] for row in cursor.fetchall()]


potential_labels = get_potential_labels()


# Function to handle the classification
def classify_email(constituent_email):
    """Classify ``constituent_email`` against ``potential_labels``.

    Returns the labels scoring above CLASSIFICATION_THRESHOLD joined with
    ", ". If none qualify, returns the single highest-scoring label.
    """
    print("classifying email")
    model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
    print("classification complete")
    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [
        label for label, score in scored if score > CLASSIFICATION_THRESHOLD
    ]
    if not top_labels:
        # Fall back to the best label (first one wins on ties, as before).
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)


def _parse_labels(labels):
    """Split the comma-separated label text into non-empty, stripped names.

    The labels textbox is user-editable, so tolerate missing spaces, extra
    whitespace, trailing commas and an empty/None value.
    """
    if not labels:
        return []
    return [label.strip() for label in labels.split(",") if label.strip()]


def _insert_message(cursor, body, volley):
    """Insert one tracking row into ``messages`` and return its new id.

    app_id, org_id and person_id are 0; subject and message_type are the
    fixed tracking title. The returned id is ``cursor.lastrowid`` —
    NOTE(review): assumes ``messages.id`` is AUTO_INCREMENT (the INSERT does
    not supply it).
    """
    cursor.execute(
        "INSERT INTO messages (app_id, org_id, person_id, communication_method_id, "
        "status_id, subject, body, send_date, message_type, volley) "
        "VALUES (0, 0, 0, 1, 1, 'Email Classification and Response Tracking', "
        "%s, NOW(), 'Email Classification and Response Tracking', %s)",
        (body, volley),
    )
    return cursor.lastrowid


# Function to handle saving data
def save_data(orig_user_email, constituent_email, labels, user_response):
    """Persist one email exchange and its labels in a single transaction.

    Rows are written to ``messages`` with volley 0 (the user's original
    email), 1 (the constituent's email) and 2 (the user's response). Each
    label that exists in ``message_categorys`` is linked to the
    constituent's message in ``message_category_associations``.

    Returns a status string for the UI. Any failure rolls the transaction
    back and reports an error instead of raising.
    """
    db_connection = None
    try:
        db_connection = _connect()
        with closing(db_connection.cursor()) as db_cursor:
            print("saving first email")
            _insert_message(db_cursor, orig_user_email, 0)
            print("saving constituent email")
            constituent_message_id = _insert_message(db_cursor, constituent_email, 1)
            print("saving user response")
            _insert_message(db_cursor, user_response, 2)

            for label in _parse_labels(labels):
                print("saving label: " + label)
                db_cursor.execute(
                    "SELECT * FROM radmap_frog12.message_categorys "
                    "WHERE message_category_name = %s",
                    (label,),
                )
                matches = db_cursor.fetchall()
                if matches:
                    print("label exists")
                    # NOTE(review): assumes the first column of
                    # message_categorys is its id — confirm against schema.
                    db_cursor.execute(
                        "INSERT INTO message_category_associations "
                        "(message_id, message_category_id) VALUES (%s, %s)",
                        (constituent_message_id, matches[0][0]),
                    )
                    print("label saved")
        db_connection.commit()
        return "Data successfully saved to database"
    except Exception as e:  # UI boundary: report to the user rather than crash
        print(e)
        if db_connection is not None:
            db_connection.rollback()
        return "Error saving data to database"
    finally:
        if db_connection is not None:
            db_connection.close()


# read auth from env vars
auth_username = os.environ.get("AUTH_USERNAME")
auth_password = os.environ.get("AUTH_PASSWORD")

# Define your username and password pairs
# NOTE(review): if either variable is unset this pair contains None — confirm
# the deployment always sets both.
auth = [(auth_username, auth_password)]

# Start building the Gradio interface with two columns
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## Email Classification and Response Tracking")
    with gr.Row():
        with gr.Column():
            email_labels_input = gr.Markdown(
                "## Valid Email Labels\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email",
            )
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                placeholder="Enter the constituent's response",
                label="Constituent's Response",
                lines=15,
            )
            classify_button = gr.Button("Classify Email")
        with gr.Column():
            # Editable so the user can correct labels before saving.
            classification_output = gr.TextArea(
                label="Current Email Labels",
                lines=1,
                interactive=True,
            )
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Your Response",
                lines=25,
            )
            save_button = gr.Button("Save Data")
            save_output = gr.Label(label="Backend Response")

    # Define button actions
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
        ],
        outputs=save_output,
    )

# Launch the app
app.launch(auth=auth, debug=True)