Spaces:
Build error
Build error
| import gradio as gr | |
| from app.utils import ( | |
| create_input_instruction, | |
| format_prediction_ouptut, | |
| display_sentiment_score_table, | |
| sentiment_flow_plot, | |
| EXAMPLE_CONVERSATIONS, | |
| ) | |
| import sys | |
| sys.path.insert(0, "../") # neccesary to load modules outside of app | |
| from app import deberta_model, tokenizer | |
| from preprocessing import preprocess | |
| from Model.DeBERTa.deberta import predict, decode_deberta_label | |
def deberta_preprocess(input):
    """Split raw conversation text into parallel speaker and message lists.

    The text is handed to ``preprocess.process_user_input``, which returns a
    dict with a ``"success"`` flag. On failure its ``"message"`` is surfaced
    to the UI as a ``gr.Error``. On success every row of ``"data"`` is read
    at index 1 (speaker) and index 2 (message text). Index 0 is not used
    here; presumably it is a line/turn id, to be confirmed against
    ``process_user_input``.

    Args:
        input: The conversation text entered by the user.

    Returns:
        A ``(speakers, messages)`` tuple of two equally long lists.

    Raises:
        gr.Error: If preprocessing reports ``success`` as falsy.
    """
    parsed = preprocess.process_user_input(input)
    if parsed["success"]:
        rows = parsed["data"]
        return [row[1] for row in rows], [row[2] for row in rows]
    raise gr.Error(parsed["message"])
def deberta_classifier(input):
    """Label every message of a conversation with the DeBERTa model.

    The conversation is parsed with ``deberta_preprocess``. All messages are
    passed in one call to ``predict`` together with the module-level
    ``deberta_model`` and ``tokenizer``. Each raw prediction is then mapped
    to a label with ``decode_deberta_label``.

    Args:
        input: The conversation text entered by the user.

    Returns:
        Whatever ``format_prediction_ouptut`` builds from the speakers,
        messages and labels. It is displayed in a ``gr.Textbox``, so it is
        presumably a string.

    Raises:
        gr.Error: Propagated from ``deberta_preprocess`` when parsing fails.
    """
    speakers, messages = deberta_preprocess(input)
    # The input is treated as a single conversation: every message goes to
    # ``predict`` in one list.
    raw_predictions = predict(deberta_model, tokenizer, messages)
    labels = list(map(decode_deberta_label, raw_predictions))
    return format_prediction_ouptut(speakers, messages, labels)
def deberta_ui():
    """Build the Gradio Blocks page for the DeBERTa sentiment classifier.

    Layout:
      * a Markdown description of the model, followed by the shared input
        instructions from ``create_input_instruction()``;
      * an input column where a conversation is typed, or loaded from
        ``EXAMPLE_CONVERSATIONS`` through a dropdown;
      * an output column whose read-only textbox is filled by
        ``deberta_classifier`` when "Submit" is clicked;
      * a "Sentiment Flow Plot" section that shows
        ``display_sentiment_score_table()`` and plots
        ``sentiment_flow_plot`` of the predicted-labels text.

    Editing the conversation clears both the predicted labels and the plot.

    Returns:
        gr.Blocks: The assembled (not launched) interface.
    """
    no_example_choice = "-- Not Selected --"

    # Named ``demo`` rather than ``deberta_model`` so it does not shadow the
    # model object imported from ``app`` at module level.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DeBERTa

            Building upon the DeBERTa architecture, the model was customized and
            retrained on Epik data to classify messages between Visitors and
            Agents into corresponding sentiment labels. At the time of training
            by the team prior to the Fall 2023 semester, the model was trained
            on 15 labels: Openness, Anxiety, Confusion, Disapproval, Remorse,
            Accusation, Denial, Obscenity, Disinterest, Annoyance, Information,
            Greeting, Interest, Curiosity, and Acceptance.

            The primary difference between DeBERTa and COSMIC is that DeBERTa's
            prediction for each message is based solely on that message's own
            context, while COSMIC uses the context of the entire conversation
            (i.e., all messages from the chat history of the conversation).
            """
        )
        create_input_instruction()

        with gr.Row():
            with gr.Column():
                example_dropdown = gr.Dropdown(
                    choices=[no_example_choice] + list(EXAMPLE_CONVERSATIONS.keys()),
                    value=no_example_choice,
                    label="Select an example",
                )
                gr.Markdown('<p style="text-align: center;color: gray;">--- OR ---</p>')
                conversation_input = gr.TextArea(
                    value="",
                    label="Input your conversation",
                    placeholder="Please input your conversation here",
                    lines=15,
                    max_lines=15,
                )

                def on_example_change(choice):
                    # Load the chosen example. The "not selected" entry (or any
                    # unknown key) clears the text area.
                    return EXAMPLE_CONVERSATIONS.get(choice, "")

                # ``.input`` fires only on user interaction with the dropdown.
                example_dropdown.input(
                    on_example_change,
                    inputs=example_dropdown,
                    outputs=conversation_input,
                )

            with gr.Column():
                output = gr.Textbox(
                    value="",
                    label="Predicted Sentiment Labels",
                    lines=22,
                    max_lines=22,
                    interactive=False,
                )
                submit_btn = gr.Button(value="Submit")
                submit_btn.click(deberta_classifier, conversation_input, output)

        gr.Markdown("# Sentiment Flow Plot")
        with gr.Row():
            with gr.Column(scale=1):
                display_sentiment_score_table()
            with gr.Column(scale=2):
                plot_box = gr.Plot(label="Analysis Plot")
                plot_btn = gr.Button(value="Plot Sentiment Flow")
                # The plot is derived from the predicted-labels text, not the
                # raw conversation.
                plot_btn.click(sentiment_flow_plot, inputs=[output], outputs=[plot_box])

        # Any edit to the conversation invalidates previous results. One handler
        # resets both the labels and the plot so stale output is never shown
        # next to new input.
        conversation_input.change(
            lambda _: ("", None),
            conversation_input,
            outputs=[output, plot_box],
        )

    return demo