Spaces:
Sleeping
Sleeping
import gradio as gr | |
import mysql.connector | |
import os | |
# Use a pipeline as a high-level helper | |
from transformers import pipeline | |
# Hugging Face zero-shot-classification pipeline used to suggest message
# categories for an email. Built once at import time and reused by every
# classification request.
classifier_model = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1",
)
# Database connection settings come from environment variables; an unset
# variable yields None.
db_host, db_user, db_pass, db_name = (
    os.environ.get(env_key) for env_key in ("DB_HOST", "DB_USER", "DB_PASS", "DB_NAME")
)
# Module-level connection and cursor opened at import time; the cursor is used
# to load the label list at startup.
db_connection = mysql.connector.connect(
    database=db_name,
    host=db_host,
    password=db_pass,
    user=db_user,
)
db_cursor = db_connection.cursor()
def get_potential_labels(cursor=None):
    """Return every category name stored in ``radmap_frog12.message_categorys``.

    Args:
        cursor: Optional DB-API cursor to run the query on. Defaults to the
            module-level ``db_cursor`` opened at import time.

    Returns:
        list[str]: The ``message_category_name`` values, in the order the
        database returned them.
    """
    if cursor is None:
        cursor = db_cursor
    cursor.execute(
        "SELECT message_category_name FROM radmap_frog12.message_categorys"
    )
    # Rows are read back with fetchall(); each row is a 1-tuple (name,).
    return [row[0] for row in cursor.fetchall()]
# Category names loaded once at startup; used as the classifier's candidate
# labels and listed in the UI.
potential_labels = get_potential_labels()
# Function to handle the classification
def classify_email(constituent_email, labels=None, threshold=0.95, classifier=None):
    """Suggest category labels for a constituent email.

    Every candidate label whose multi-label score is strictly greater than
    ``threshold`` is returned. If none clears it, the single best-scoring label
    is returned instead, so the UI always shows a suggestion.

    Args:
        constituent_email: Text of the constituent's email.
        labels: Candidate labels. Defaults to the module-level
            ``potential_labels`` loaded from the database at startup.
        threshold: Exclusive minimum score for a label to be included.
        classifier: Callable invoked as
            ``classifier(text, candidate_labels, multi_label=True)`` and
            returning a mapping with parallel ``"labels"`` and ``"scores"``
            lists. Defaults to the module-level ``classifier_model``.

    Returns:
        str: The selected labels joined with ", ", or "" when there is no
        text to classify or no candidate labels.
    """
    candidate_labels = potential_labels if labels is None else labels
    # The zero-shot pipeline rejects an empty sequence or an empty label list;
    # answer with no labels instead of surfacing an error in the UI.
    if not constituent_email or not constituent_email.strip() or not candidate_labels:
        return ""
    if classifier is None:
        classifier = classifier_model
    print("classifying email")
    model_out = classifier(constituent_email, candidate_labels, multi_label=True)
    print("classification complete")
    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > threshold]
    if not top_labels:
        # Nothing is confident enough: fall back to the single best guess.
        # max() keeps the first of equal scores, like list.index(max(...)).
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)
# Function to handle saving data
def save_data(orig_user_email, constituent_email, labels, user_response, connect=None):
    """Persist one email exchange and its category labels to the database.

    Three rows are inserted into ``messages``:

    * volley 0: the user's original email
    * volley 1: the constituent's reply
    * volley 2: the user's response

    Each row uses app_id/org_id/person_id 0,
    communication_method_id/status_id 1, and the fixed "Email Classification
    and Response Tracking" subject and message_type. Every label that matches
    a row in ``radmap_frog12.message_categorys`` is then linked to the
    constituent's message in ``message_category_associations``. The work is
    committed once at the end and rolled back if any step fails; the
    connection is always closed.

    Args:
        orig_user_email: Body of the user's original email (volley 0).
        constituent_email: Body of the constituent's reply (volley 1).
        labels: Comma-separated category names, e.g. "Roads, Parks". Blank
            and repeated entries are ignored; unknown names are skipped.
        user_response: Body of the user's response (volley 2).
        connect: Optional zero-argument callable returning a DB-API
            connection. Defaults to ``mysql.connector.connect`` with the DB_*
            settings read from the environment.

    Returns:
        str: A status message for the UI.
    """
    # Accept "A, B" as produced by classify_email as well as hand-edited
    # "A,B"; drop blanks and duplicates while keeping the original order.
    label_names = list(
        dict.fromkeys(
            name.strip() for name in (labels or "").split(",") if name.strip()
        )
    )
    insert_message_sql = (
        "INSERT INTO messages (app_id, org_id, person_id, communication_method_id, "
        "status_id, subject, body, send_date, message_type, volley) "
        "VALUES (0, 0, 0, 1, 1, 'Email Classification and Response Tracking', %s, "
        "NOW(), 'Email Classification and Response Tracking', %s)"
    )
    if connect is None:
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
    else:
        connection = connect()
    try:
        cursor = connection.cursor()
        message_ids = []
        for body, volley, description in (
            (orig_user_email, 0, "first email"),
            (constituent_email, 1, "constituent email"),
            (user_response, 2, "user response"),
        ):
            print("saving " + description)
            cursor.execute(insert_message_sql, (body, volley))
            # Use the id of the row just inserted rather than looking it up by
            # body text, which matches several rows once the same text has
            # been saved before. Assumes messages.id is AUTO_INCREMENT — TODO
            # confirm against the schema.
            message_ids.append(cursor.lastrowid)
        constituent_message_id = message_ids[1]

        for label in label_names:
            print("saving label: " + label)
            cursor.execute(
                "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
                (label,),
            )
            matches = cursor.fetchall()
            if not matches:
                continue
            print("label exists")
            # Assumes the first column of message_categorys is its id — TODO
            # confirm against the schema.
            # mysql-connector only understands %s placeholders (never %i).
            cursor.execute(
                "INSERT INTO message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                (constituent_message_id, matches[0][0]),
            )
            print("label saved")
        connection.commit()
        return "Data successfully saved to database"
    except Exception as e:  # UI boundary: report failure instead of crashing the request
        print(e)
        try:
            connection.rollback()
        except Exception as rollback_error:
            print(rollback_error)
        return "Error saving data to database"
    finally:
        connection.close()
# read auth from env vars
auth_username = os.environ.get("AUTH_USERNAME")
auth_password = os.environ.get("AUTH_PASSWORD")
# Refuse to start without credentials. Otherwise a (None, None) pair would be
# handed to Gradio's login page, which presumably no login attempt can satisfy,
# leaving the app silently unreachable.
if not auth_username or not auth_password:
    raise RuntimeError(
        "AUTH_USERNAME and AUTH_PASSWORD must both be set to launch the app"
    )
# Username/password pairs accepted by the Gradio login page.
auth = [(auth_username, auth_password)]

# Gradio interface: a title row, then two columns. The left column holds the
# label list, both incoming emails and the classify button; the right column
# holds the editable labels, the user's reply and the save button.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## Email Classification and Response Tracking")
    with gr.Row():
        with gr.Column():
            # Read-only display of the categories the classifier chooses from.
            email_labels_input = gr.Markdown(
                "## Valid Email Labels\n ### " + ", ".join(potential_labels),
            )
            original_email_input = gr.TextArea(
                placeholder="Enter the original email sent by you",
                label="Your Original Email",
            )
            # Hidden component; the name suggests it was meant as a spacer.
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                placeholder="Enter the constituent's response",
                label="Constituent's Response",
                lines=15,
            )
            classify_button = gr.Button("Classify Email")
        with gr.Column():
            # Editable so the user can correct the suggested labels before saving.
            classification_output = gr.TextArea(
                label="Current Email Labels",
                lines=1,
                interactive=True,
            )
            # Hidden component; the name suggests it was meant as a spacer.
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                placeholder="Enter your response to the constituent",
                label="Your Response",
                lines=25,
            )
            save_button = gr.Button("Save Data")
            save_output = gr.Label(label="Backend Response")
    # Classify the constituent's reply and show the suggested labels.
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    # Persist all three emails plus the (possibly edited) labels.
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
        ],
        outputs=save_output,
    )
# Launch the app behind the login page configured by `auth`.
app.launch(auth=auth, debug=True)