import streamlit as st
import sparknlp
import os
import pandas as pd

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
from annotated_text import annotated_text
from streamlit_tags import st_tags

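# Streamlit page configuration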
st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto"
)

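# Custom CSS for the main title and the description card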
st.markdown("""
|
|
<style>
|
|
.main-title {
|
|
font-size: 36px;
|
|
color: #4A90E2;
|
|
font-weight: bold;
|
|
text-align: center;
|
|
}
|
|
.section {
|
|
background-color: #f9f9f9;
|
|
padding: 10px;
|
|
border-radius: 10px;
|
|
margin-top: 10px;
|
|
}
|
|
.section p, .section ul {
|
|
color: #666666;
|
|
}
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
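# Start (and cache) the Spark NLP session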
@st.cache_resource
def init_spark():
    return sparknlp.start()

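# Build the zero-shot classification pipeline:
# document assembler -> tokenizer -> XLM-RoBERTa zero-shot classifier with the given candidate labels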
@st.cache_resource
def create_pipeline(zeroShotLabels=['']):
    document_assembler = DocumentAssembler() \
        .setInputCol('text') \
        .setOutputCol('document')

    tokenizer = Tokenizer() \
        .setInputCols(['document']) \
        .setOutputCol('token')

    zeroShotClassifier = XlmRoBertaForZeroShotClassification \
        .pretrained('xlm_roberta_large_zero_shot_classifier_xnli_anli', 'xx') \
        .setInputCols(['token', 'document']) \
        .setOutputCol('class') \
        .setCaseSensitive(False) \
        .setMaxSentenceLength(512) \
        .setCandidateLabels(zeroShotLabels)

    pipeline = Pipeline(stages=[document_assembler, tokenizer, zeroShotClassifier])
    return pipeline

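# Fit the pipeline on an empty DataFrame, then annotate the input text with a LightPipeline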
def fit_data(pipeline, data):
    empty_df = spark.createDataFrame([['']]).toDF('text')
    pipeline_model = pipeline.fit(empty_df)
    model = LightPipeline(pipeline_model)
    result = model.fullAnnotate(data)
    return result

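# Helper that renders NER-style chunk annotations with annotated_text (not called in the zero-shot flow below)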
def annotate(data):
    document, chunks, labels = data["Document"], data["NER Chunk"], data["NER Label"]
    annotated_words = []
    for chunk, label in zip(chunks, labels):
        parts = document.split(chunk, 1)
        if parts[0]:
            annotated_words.append(parts[0])
        annotated_words.append((chunk, label))
        document = parts[1]
    if document:
        annotated_words.append(document)
    annotated_text(*annotated_words)

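# Available models and a short description for each task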
tasks_models_descriptions = {
    "Zero-Shot Classification": {
        "models": ["xlm_roberta_large_zero_shot_classifier_xnli_anli"],
        "description": "The 'xlm_roberta_large_zero_shot_classifier_xnli_anli' model provides flexible text classification without needing training data for specific categories. It is ideal for dynamic scenarios where text needs to be categorized into topics like urgent issues, technology, or sports without prior labeling."
    }
}

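# Sidebar: task and pretrained model selection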
task = st.sidebar.selectbox("Choose the task", list(tasks_models_descriptions.keys()))
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    tasks_models_descriptions[task]["models"],
    help="For more info about the models visit: https://sparknlp.org/models"
)

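# Reference notebook link shown in the sidebar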
link = """
|
|
<a href="https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/357691d18373d6e8f13b5b1015137a398fd0a45f/Spark_NLP_Udemy_MOOC/Open_Source/17.01.Transformers-based_Embeddings.ipynb#L103">
|
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
|
|
</a>
|
|
"""
|
|
st.sidebar.markdown('Reference notebook:')
|
|
st.sidebar.markdown(link, unsafe_allow_html=True)
|
|
|
|
|
|
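# Page title and task description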
title, sub_title = (f'XLM-RoBERTa for {task}', tasks_models_descriptions[task]["description"])
st.markdown(f'<div class="main-title">{title}</div>', unsafe_allow_html=True)
container = st.container(border=True)
container.write(sub_title)

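# Example sentences for each task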
examples_mapping = {
    "Zero-Shot Classification": [
        "In today’s world, staying updated with urgent information is crucial as events can unfold rapidly and require immediate attention.",
        "Mobile technology has become indispensable, allowing us to access news, updates, and connect with others no matter where we are.",
        "For those who love to travel, the convenience of mobile apps has transformed how we plan and experience trips, providing real-time updates on flights, accommodations, and local attractions.",
        "The entertainment industry continually offers new movies that captivate audiences with their storytelling and visuals, providing a wide range of genres to suit every taste.",
        "Music is an integral part of modern life, with streaming platforms making it easy to discover new artists and enjoy favorite tunes anytime, anywhere.",
        "Sports enthusiasts follow games and matches closely, with live updates and detailed statistics available at their fingertips, enhancing the excitement of every game.",
        "Weather forecasts play a vital role in daily planning, offering accurate and timely information to help us prepare for various weather conditions and adjust our plans accordingly.",
        "Technology continues to evolve rapidly, driving innovation across all sectors and improving our everyday lives through smarter devices, advanced software, and enhanced connectivity."
    ]
}

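# Let the user pick an example sentence or type a custom one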
examples = examples_mapping[task]
selected_text = st.selectbox("Select an example", examples)
custom_input = st.text_input("Try it with your own Sentence!")

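# Candidate labels for zero-shot classification, editable via the tags widget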
if task == 'Zero-Shot Classification':
    zeroShotLabels = ["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"]
    zeroShotLabels = st_tags(
        label='Select labels',
        text='Press enter to add more',
        value=zeroShotLabels,
        suggestions=[
            "Positive", "Negative", "Neutral",
            "Urgent", "Mobile", "Travel", "Movie", "Music", "Sport", "Weather", "Technology",
            "Happiness", "Sadness", "Anger", "Fear", "Surprise", "Disgust",
            "Informational", "Navigational", "Transactional", "Commercial Investigation",
            "Politics", "Business", "Sports", "Entertainment", "Health", "Science",
            "Product Quality", "Delivery Experience", "Customer Service", "Pricing", "Return Policy",
            "Education", "Finance", "Lifestyle", "Fashion", "Food", "Art", "History",
            "Culture", "Environment", "Real Estate", "Automotive", "Travel", "Fitness", "Career"],
        maxtags=-1)

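# Show the text that will be analyzed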
try:
    text_to_analyze = custom_input if custom_input else selected_text
    st.subheader('Full example text')
    HTML_WRAPPER = """<div class="scroll entities" style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem; white-space:pre-wrap">{}</div>"""
    st.markdown(HTML_WRAPPER.format(text_to_analyze), unsafe_allow_html=True)
except Exception:
    text_to_analyze = selected_text

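# Start Spark, build the pipeline with the selected labels, and run it on the chosen text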
spark = init_spark()
pipeline = create_pipeline(zeroShotLabels)
output = fit_data(pipeline, text_to_analyze)

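# Display the predicted class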
st.subheader("Prediction:")
|
|
st.markdown(f"Document Classified as: **{output[0]['class'][0].result}**")
|
|
|