Spaces:
Build error
Build error
import gradio as gr | |
from app.utils import ( | |
create_input_instruction, | |
format_prediction_ouptut, | |
display_sentiment_score_table, | |
sentiment_flow_plot, | |
EXAMPLE_CONVERSATIONS, | |
) | |
import sys | |
sys.path.insert(0, "../") # neccesary to load modules outside of app | |
from app import deberta_model, tokenizer | |
from preprocessing import preprocess | |
from Model.DeBERTa.deberta import predict, decode_deberta_label | |
def deberta_preprocess(input):
    """Split a raw conversation into parallel speaker and message lists.

    Args:
        input: The conversation text submitted from the UI; it is handed to
            ``preprocess.process_user_input`` for parsing.

    Returns:
        tuple[list, list]: ``(speakers, messages)``, built from positions 1
        and 2 respectively of every parsed row.

    Raises:
        gr.Error: When the parser reports ``success`` as falsy; the error
            carries the parser's ``"message"`` text.
    """
    outcome = preprocess.process_user_input(input)

    # Guard clause: turn a parse failure into a Gradio error for the UI.
    if not outcome["success"]:
        raise gr.Error(outcome["message"])

    rows = outcome["data"]
    return [row[1] for row in rows], [row[2] for row in rows]
def deberta_classifier(input):
    """Predict a sentiment label for each message of a conversation.

    Args:
        input: The conversation text submitted from the UI.

    Returns:
        The prediction text produced by ``format_prediction_ouptut`` from the
        speakers, messages and decoded labels.
    """
    speakers, messages = deberta_preprocess(input)

    # NOTE: the whole input is treated as a single conversation.
    raw_predictions = predict(deberta_model, tokenizer, messages)
    decoded_labels = list(map(decode_deberta_label, raw_predictions))

    return format_prediction_ouptut(speakers, messages, decoded_labels)
def deberta_ui():
    """Build the Gradio interface for the DeBERTa sentiment classifier.

    Layout:
      * a Markdown description of the model;
      * left column: an example-conversation dropdown and a free-text
        conversation area;
      * right column: a read-only box with the predicted labels and a
        Submit button that runs ``deberta_classifier``;
      * a "Sentiment Flow Plot" section with the sentiment score table and a
        plot generated by ``sentiment_flow_plot`` from the predicted labels.

    Editing the conversation clears both the predictions and the plot, so
    stale results are never shown next to new input.

    Returns:
        gr.Blocks: The assembled interface (not launched).
    """
    # Dropdown entry meaning "no example chosen"; selecting it clears the input.
    not_selected = "-- Not Selected --"

    # Named ``demo`` rather than ``deberta_model`` so it cannot be confused with
    # the model object imported from ``app`` that ``deberta_classifier`` uses.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Deberta

            Building upon the DeBERTa architecture, the model was customized and
            retrained on Epik data to classify messages between Visitors and Agents into
            corresponding sentiment labels. At the time of training by the team prior to
            the Fall 2023 semester, the model was trained on 15 labels, namely
            Openness, Anxiety, Confusion, Disapproval, Remorse, Accusation, Denial,
            Obscenity, Disinterest, Annoyance, Information, Greeting, Interest,
            Curiosity, and Acceptance.

            The primary difference between DeBERTa and COSMIC is that while DeBERTa's
            prediction is solely based on its own context, COSMIC uses the context of
            the entire conversation (i.e., all messages from the chat history of the
            conversation).
            """
        )
        create_input_instruction()

        with gr.Row():
            with gr.Column():
                example_dropdown = gr.Dropdown(
                    choices=[not_selected, *EXAMPLE_CONVERSATIONS],
                    value=not_selected,
                    label="Select an example",
                )
                gr.Markdown('<p style="text-align: center;color: gray;">--- OR ---</p>')
                conversation_input = gr.TextArea(
                    value="",
                    label="Input your conversation",
                    placeholder="Please input your conversation here",
                    lines=15,
                    max_lines=15,
                )

                def on_example_change(choice):
                    # Load the chosen example; the sentinel (or any unknown
                    # key) yields "" and so clears the text area.
                    return EXAMPLE_CONVERSATIONS.get(choice, "")

                example_dropdown.input(
                    on_example_change,
                    inputs=example_dropdown,
                    outputs=conversation_input,
                )

            with gr.Column():
                output = gr.Textbox(
                    value="",
                    label="Predicted Sentiment Labels",
                    lines=22,
                    max_lines=22,
                    interactive=False,
                )
                submit_btn = gr.Button(value="Submit")
                submit_btn.click(
                    deberta_classifier,
                    inputs=conversation_input,
                    outputs=output,
                )

        gr.Markdown("# Sentiment Flow Plot")
        with gr.Row():
            with gr.Column(scale=1):
                display_sentiment_score_table()
            with gr.Column(scale=2):
                plot_box = gr.Plot(label="Analysis Plot")
                plot_btn = gr.Button(value="Plot Sentiment Flow")
                # The plot is derived from the predicted-labels text box.
                plot_btn.click(sentiment_flow_plot, inputs=[output], outputs=[plot_box])

        # A single handler resets both the predictions and the plot whenever
        # the conversation changes. It is registered here, after ``plot_box``
        # exists, so one listener can clear both outputs.
        conversation_input.change(
            lambda _: ("", None),
            inputs=conversation_input,
            outputs=[output, plot_box],
        )

    return demo