# Hugging Face Spaces page header captured with this source (Space status: Sleeping).
import gradio as gr | |
import mysql.connector | |
import os | |
# Use a pipeline as a high-level helper | |
from transformers import pipeline | |
# Zero-shot text classifier used to suggest message categories. It is built
# once, when this module is imported, and shared by every request.
classifier_model = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1",
)
# Database settings, read from environment variables.
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASS")
db_name = os.environ.get("DB_NAME")

# Module-level connection and cursor, opened at import time.
# autocommit=True: Connector/Python defaults to autocommit off, so every
# SELECT issued through this long-lived connection would run inside one
# never-committed transaction and, under InnoDB's default REPEATABLE READ
# isolation, keep returning the snapshot taken by the first read (categories
# added later would never show up). This connection is only read from;
# save_data() opens its own connection for its inserts.
db_connection = mysql.connector.connect(
    host=db_host,
    user=db_user,
    password=db_pass,
    database=db_name,
    autocommit=True,
)
db_cursor = db_connection.cursor()

# Organisation id stored on every message row that save_data() inserts.
ORG_ID = 731
# Category names; (re)filled by get_potential_labels().
potential_labels = []
def get_potential_labels(cursor=None):
    """Load all category names from ``radmap_frog12.message_categorys``.

    The result is also stored in the module-level ``potential_labels``.

    Args:
        cursor: optional DB-API cursor to run the query with. When omitted,
            a new MySQL connection is opened from the DB_* settings and closed
            again before returning (the same approach save_data() uses).
            A fresh connection per call avoids depending on one long-lived
            connection, which the server may close after it sits idle and
            which, without autocommit, keeps reading from the snapshot of its
            first transaction. It also avoids sharing one connection between
            concurrent requests.

    Returns:
        list of category names, in the order the query returns them
        (the query has no ORDER BY).
    """
    global potential_labels
    query = "SELECT message_category_name FROM radmap_frog12.message_categorys"
    if cursor is not None:
        cursor.execute(query)
        rows = cursor.fetchall()
    else:
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
        try:
            own_cursor = connection.cursor()
            own_cursor.execute(query)
            rows = own_cursor.fetchall()
        finally:
            connection.close()
    potential_labels = [row[0] for row in rows]
    return potential_labels
# Load the category names once at start-up; the "Message Category Library"
# text in the UI is built from this list when the interface is created.
potential_labels = get_potential_labels()
# Function to handle the classification
def classify_email(constituent_email, threshold=0.95, candidate_labels=None, classifier=None):
    """Suggest message categories for an incoming message.

    Runs the zero-shot classifier in multi-label mode over the candidate
    category names and returns every category whose score is strictly above
    ``threshold``, joined with ", ". If none clears the threshold, the single
    highest-scoring category is returned instead, so the user always gets a
    suggestion to edit.

    Args:
        constituent_email: text of the incoming message.
        threshold: minimum score (exclusive) for a category to be suggested.
        candidate_labels: category names to choose from; defaults to a fresh
            read via get_potential_labels().
        classifier: zero-shot pipeline callable taking
            ``(text, labels, multi_label=...)`` and returning a dict with
            parallel "labels" and "scores" lists; defaults to classifier_model.

    Returns:
        Comma-separated category names, or "" when the message is blank or
        there are no categories to choose from (the pipeline rejects an
        empty label list).
    """
    if not constituent_email or not constituent_email.strip():
        return ""
    if candidate_labels is None:
        candidate_labels = get_potential_labels()
    if not candidate_labels:
        return ""
    if classifier is None:
        classifier = classifier_model

    print("classifying email")
    model_out = classifier(constituent_email, candidate_labels, multi_label=True)
    print("classification complete")

    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > threshold]
    if not top_labels:
        # max() keeps the first of several equal maxima, like list.index(max(...)).
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)
def remove_spaces_after_comma(s):
    """Trim surrounding whitespace from every comma-separated piece of *s*.

    Despite the name, whitespace on both sides of each piece is removed,
    e.g. " a , b,c " -> "a,b,c". The number of pieces is unchanged.
    """
    return ",".join(piece.strip() for piece in s.split(","))
# Function to handle saving data

# person_id recorded for each name offered in the "Current User" dropdown.
_PERSON_IDS = {
    "Sheryl Springer": 11021,
    "Diane Taylor": 11023,
    "Ann E. Belyea": 11025,
    "Marcelo Mejia": 11027,
    "Rishi Vasudeva": 11029,
}

# Shared INSERT for every message row; parameters are
# (org_id, person_id, body, volley).
_INSERT_MESSAGE_SQL = (
    "INSERT INTO radmap_frog12.messages (app_id, org_id, person_id, "
    "communication_method_id, status_id, subject, body, send_date, "
    "message_type, volley) VALUES (345678, %s, %s, 1, 1, "
    "'Email Classification and Response Tracking', %s, NOW(), "
    "'Email Classification and Response Tracking', %s)"
)


def _split_labels(labels):
    """Return the distinct, non-empty, whitespace-stripped names in a comma-separated string."""
    stripped = (part.strip() for part in (labels or "").split(","))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(part for part in stripped if part))


def save_data(
    orig_user_email,
    constituent_email,
    labels,
    user_response,
    current_user,
    connect=None,
):
    """Save one classified exchange and its message categories.

    In a single transaction, inserts into ``radmap_frog12.messages``:

    * ``orig_user_email`` as volley 1 for the current user (skipped when empty);
    * ``constituent_email`` as volley 2 with person_id 0;
    * ``user_response`` as volley 3 for the current user.

    Every row uses app_id 345678, org_id ``ORG_ID`` and the subject/message
    type "Email Classification and Response Tracking". Then, for each distinct
    category name in ``labels`` that exists in
    ``radmap_frog12.message_categorys``, a row linking it to the constituent
    message is added to ``radmap_frog12.message_category_associations``.

    NOTE(review): an earlier comment here described volleys 0/1/2 and
    app_id/org_id/person_id 0; the values above are what the INSERT
    statements write -- confirm which is intended before changing either.

    Args:
        orig_user_email: the user's original outgoing email, or "".
        constituent_email: the incoming message that was classified.
        labels: comma-separated category names (as edited in the UI).
        user_response: the user's reply to the incoming message.
        current_user: display name chosen in the "Current User" dropdown.
        connect: optional zero-argument callable returning a DB-API
            connection; by default a new MySQL connection is opened from the
            DB_* settings. The connection is always closed before returning.

    Returns:
        A status message for the "Backend Response" label.
    """
    person_id = _PERSON_IDS.get(current_user)
    if person_id is None:
        # Fail early with a clear message instead of inserting a partial
        # exchange and tripping over an unset person_id.
        return "Error saving data to database: please select a current user"

    if connect is None:
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
    else:
        connection = connect()
    try:
        cursor = connection.cursor()
        volley = 1
        if orig_user_email:
            cursor.execute(_INSERT_MESSAGE_SQL, (ORG_ID, person_id, orig_user_email, volley))
        volley = 2
        # The incoming message is stored with person_id 0.
        cursor.execute(_INSERT_MESSAGE_SQL, (ORG_ID, 0, constituent_email, volley))
        # Id generated by the insert just above (assumes messages.id is
        # AUTO_INCREMENT). Looking the row up by body would pick an
        # arbitrary row whenever identical text had been saved before.
        constituent_message_id = cursor.lastrowid
        cursor.execute(_INSERT_MESSAGE_SQL, (ORG_ID, person_id, user_response, volley + 1))

        for label in _split_labels(labels):
            cursor.execute(
                "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
                (label,),
            )
            matches = cursor.fetchall()
            if matches:
                # The first column of the matching row is used as the
                # category id (presumably the primary key -- TODO confirm).
                cursor.execute(
                    "INSERT INTO radmap_frog12.message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                    (constituent_message_id, matches[0][0]),
                )
        connection.commit()
        return "Response successfully saved to database"
    except Exception as e:
        # UI boundary: log, undo any partial inserts, and report failure.
        print(e)
        connection.rollback()
        return "Error saving data to database"
    finally:
        connection.close()
# Login credentials for the Gradio app, taken from the environment.
auth_username, auth_password = (
    os.environ.get(var_name) for var_name in ("AUTH_USERNAME", "AUTH_PASSWORD")
)
# Gradio takes a list of (username, password) pairs; one pair is configured.
auth = [(auth_username, auth_password)]
# Build the Gradio interface: a title row, then a row with two columns --
# inputs and the "Process Message" button on the left, editable outputs and
# the "Save Response" button on the right. Components are laid out in the
# order they are created inside each `with` context, so statement order here
# defines the layout.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        gr.Markdown("## Campaign Messaging Assistant")
    with gr.Row():
        # Left column: current user, category list, and the message inputs.
        with gr.Column():
            # These names must match the ones save_data() maps to a person_id.
            current_user = gr.Dropdown(
                label="Current User",
                choices=[
                    "Sheryl Springer",
                    "Ann E. Belyea",
                    "Marcelo Mejia",
                    "Rishi Vasudeva",
                    "Diane Taylor",
                ],
            )
            # Read-only category list, rendered once from the potential_labels
            # loaded at start-up; no event in this file updates it afterwards.
            email_labels_input = gr.Markdown(
                "## Message Category Library\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email (if any)",
            )
            # Hidden component; presumably kept only as a layout spacer.
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                placeholder="Enter the incoming message",
                label="Incoming Message (may be a response to original email)",
                lines=15,
            )
            classify_button = gr.Button("Process Message", variant="primary")
        # Right column: editable suggestions, the save button and its status.
        with gr.Column():
            # Filled by classify_email(); the user may edit the comma-separated
            # list before it is passed to save_data() as `labels`.
            classification_output = gr.TextArea(
                label="Suggested Message Categories (modify as needed). Separate categories with commas",
                lines=1,
                interactive=True,
            )
            # Hidden component; presumably kept only as a layout spacer.
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Suggested Response (modify as needed)",
                lines=25,
            )
            save_button = gr.Button("Save Response", variant="primary")
            save_output = gr.Label(label="Backend Response")
    # Define button actions
    # "Process Message": classify the incoming message text and show the
    # suggested categories.
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    # "Save Response": the inputs are passed positionally, in this order, as
    # save_data(orig_user_email, constituent_email, labels, user_response,
    # current_user); its status string is shown in save_output.
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
            current_user,
        ],
        outputs=save_output,
    )
# Start the web server; every visitor must log in with the `auth` credentials.
app.launch(debug=True, auth=auth)