# Scraped page header (not code): author AMead10, commit 0f9c891 ("add in new connection"), file size 6.57 kB.
import gradio as gr
import mysql.connector
import os
# Use a pipeline as a high-level helper
from transformers import pipeline
# Zero-shot text classifier; classify_email() scores each email against the
# category names loaded from the database with it.
classifier_model = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1",
)
# Database settings are read from environment variables (None when unset).
db_host, db_user, db_pass, db_name = (
    os.environ.get(var_name)
    for var_name in ("DB_HOST", "DB_USER", "DB_PASS", "DB_NAME")
)
# Module-level connection and cursor; get_potential_labels() queries through
# this cursor when the label list is loaded.
db_connection = mysql.connector.connect(
    database=db_name,
    host=db_host,
    password=db_pass,
    user=db_user,
)
db_cursor = db_connection.cursor()
def get_potential_labels(cursor=None):
    """Return every message category name stored in the database.

    Args:
        cursor: DB-API cursor to run the query on. Defaults to the
            module-level ``db_cursor``.

    Returns:
        A list of ``message_category_name`` values, in the order the query
        returns them (empty list when there are no categories).
    """
    cursor = db_cursor if cursor is None else cursor
    # execute() is called only for its side effect; the rows come from fetchall().
    cursor.execute(
        "SELECT message_category_name FROM radmap_frog12.message_categorys"
    )
    return [row[0] for row in cursor.fetchall()]
# Category names are fetched once, when the module is loaded. Both
# classify_email() and the label list shown in the UI read this snapshot, and
# nothing in this file refreshes it, so categories added to the database later
# only appear after a restart.
potential_labels = get_potential_labels()
# Function to handle the classification
def classify_email(constituent_email, threshold=0.95):
    """Return the label(s) the zero-shot model assigns to an email.

    The email is scored against every entry of the module-level
    ``potential_labels`` with ``multi_label=True``. Every label scoring above
    ``threshold`` is kept. When none does, the single highest-scoring label is
    returned instead, so a non-empty email always gets at least one label.

    Args:
        constituent_email: text to classify.
        threshold: exclusive minimum score for a label to be kept.

    Returns:
        The selected label names joined with ", ", or "" when the email is
        empty, None, or only whitespace.
    """
    # The zero-shot pipeline raises ValueError for an empty sequence. An empty
    # text box should simply yield no labels instead of an error in the UI.
    if not constituent_email or not constituent_email.strip():
        return ""
    print("classifying email")
    model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
    print("classification complete")
    scored = list(zip(model_out["labels"], model_out["scores"]))
    top_labels = [label for label, score in scored if score > threshold]
    if not top_labels:
        # Nothing is confident enough, so fall back to the best-scoring label.
        # max() returns the first maximal pair, matching the original tie order.
        best_label, _ = max(scored, key=lambda pair: pair[1])
        return best_label
    return ", ".join(top_labels)
# Function to handle saving data

# INSERT shared by the three conversation rows. Only the message body and its
# position in the exchange (volley) differ between them.
_INSERT_MESSAGE_SQL = (
    "INSERT INTO messages (app_id, org_id, person_id, communication_method_id, "
    "status_id, subject, body, send_date, message_type, volley) "
    "VALUES (0, 0, 0, 1, 1, 'Email Classification and Response Tracking', %s, "
    "NOW(), 'Email Classification and Response Tracking', %s)"
)


def _insert_message(cursor, body, volley):
    """Insert one row into ``messages`` and return the id generated for it."""
    cursor.execute(_INSERT_MESSAGE_SQL, (body, volley))
    # lastrowid is the AUTO_INCREMENT value produced by this INSERT.
    # NOTE(review): assumes messages.id is AUTO_INCREMENT (the INSERT does not
    # supply it) - confirm against the schema.
    return cursor.lastrowid


def _parse_labels(labels):
    """Split a comma-separated label string into unique, non-empty names."""
    names = (name.strip() for name in (labels or "").split(","))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(name for name in names if name))


def save_data(orig_user_email, constituent_email, labels, user_response):
    """Persist one classified email exchange and its labels to the database.

    Three rows are written to ``messages``, each with app_id/org_id/person_id
    0 and subject "Email Classification and Response Tracking":
    ``orig_user_email`` as volley 0, ``constituent_email`` as volley 1 and
    ``user_response`` as volley 2. For each label that names an existing row
    of ``message_categorys``, a row linking that category to the constituent
    email is written to ``message_category_associations``.

    All writes use one fresh connection and one transaction. It is committed
    only if every statement succeeds, otherwise rolled back, and the
    connection is always closed.

    Args:
        orig_user_email: text of the email the user originally sent.
        constituent_email: the constituent's reply, which is the message the
            labels are attached to.
        labels: comma-separated label names, e.g. the output of classify_email
            (possibly edited by the user).
        user_response: the user's reply to the constituent.

    Returns:
        A status string for display in the UI.
    """
    connection = None
    try:
        connection = mysql.connector.connect(
            host=db_host,
            user=db_user,
            password=db_pass,
            database=db_name,
        )
        cursor = connection.cursor()

        print("saving first email")
        _insert_message(cursor, orig_user_email, 0)
        print("saving constituent email")
        # Keep the generated id. Looking the row up again by body text would
        # match every earlier message with the same text.
        constituent_message_id = _insert_message(cursor, constituent_email, 1)
        print("saving user response")
        _insert_message(cursor, user_response, 2)

        for label in _parse_labels(labels):
            print("saving label: " + label)
            cursor.execute(
                "SELECT * FROM radmap_frog12.message_categorys WHERE message_category_name = %s",
                (label,),
            )
            matches = cursor.fetchall()
            if matches:
                print("label exists")
                # NOTE(review): assumes the first column of message_categorys
                # is its id - confirm against the schema.
                category_id = matches[0][0]
                # mysql.connector only substitutes %s placeholders. The former
                # %i was never replaced, so this INSERT always raised.
                cursor.execute(
                    "INSERT INTO message_category_associations (message_id, message_category_id) VALUES (%s, %s)",
                    (constituent_message_id, category_id),
                )
                print("label saved")

        connection.commit()
        return "Data successfully saved to database"
    except Exception as e:
        # UI boundary: report the failure in the app instead of raising.
        print(e)
        # Rolling back on a dropped connection would raise again.
        if connection is not None and connection.is_connected():
            connection.rollback()
        return "Error saving data to database"
    finally:
        if connection is not None:
            connection.close()
# Login credentials come from the environment (None when a variable is unset).
auth_username, auth_password = (
    os.environ.get(var_name) for var_name in ("AUTH_USERNAME", "AUTH_PASSWORD")
)
# A single (username, password) pair, passed to app.launch(auth=...).
auth = [(auth_username, auth_password)]
# Gradio interface: a title row, then two columns. The left column holds the
# emails being classified; the right column holds labels, reply and saving.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## Email Classification and Response Tracking")

    with gr.Row():
        # Left column: valid-label reference, the two incoming emails, and the
        # classify action.
        with gr.Column():
            email_labels_input = gr.Markdown(
                f"## Valid Email Labels\n ### {', '.join(potential_labels)}",
            )
            original_email_input = gr.TextArea(
                label="Your Original Email",
                placeholder="Enter the original email sent by you",
            )
            spacer1 = gr.Label(visible=False)
            constituent_response_input = gr.TextArea(
                label="Constituent's Response",
                placeholder="Enter the constituent's response",
                lines=15,
            )
            classify_button = gr.Button("Classify Email")

        # Right column: editable label box, the user's reply, and the save
        # action with its status output.
        with gr.Column():
            classification_output = gr.TextArea(
                label="Current Email Labels",
                lines=1,
                interactive=True,
            )
            spacer2 = gr.Label(visible=False)
            user_response_input = gr.TextArea(
                label="Your Response",
                placeholder="Enter your response to the constituent",
                lines=25,
            )
            save_button = gr.Button("Save Data")
            save_output = gr.Label(label="Backend Response")

    # Event wiring. Classifying fills the editable label box. Saving sends both
    # emails, the (possibly edited) labels and the reply to save_data.
    classify_button.click(
        fn=classify_email,
        inputs=constituent_response_input,
        outputs=classification_output,
    )
    save_button.click(
        fn=save_data,
        inputs=[
            original_email_input,
            constituent_response_input,
            classification_output,
            user_response_input,
        ],
        outputs=save_output,
    )

# Launch the app, requiring the credentials held in ``auth``.
app.launch(debug=True, auth=auth)