# Hugging Face upload metadata (non-code, commented out so the file parses):
# abdullahmubeen10 — "Upload 5 files" — commit 510d114 (verified)
import streamlit as st
import sparknlp
import os
import pandas as pd
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
from annotated_text import annotated_text
from streamlit_tags import st_tags
# ---------- Page setup ----------
st.set_page_config(layout="wide", initial_sidebar_state="auto")

# Custom CSS: centered blue page title plus light-grey "section" info boxes.
_PAGE_CSS = """
<style>
    .main-title {
        font-size: 36px;
        color: #4A90E2;
        font-weight: bold;
        text-align: center;
    }
    .section {
        background-color: #f9f9f9;
        padding: 10px;
        border-radius: 10px;
        margin-top: 10px;
    }
    .section p, .section ul {
        color: #666666;
    }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
@st.cache_resource
def init_spark():
    """Start (or attach to) the Spark NLP session.

    Cached with ``st.cache_resource`` so the JVM/Spark session is created
    once per Streamlit server process, not on every script rerun.
    """
    session = sparknlp.start()
    return session
@st.cache_resource
def create_pipeline(zeroShotLables=None):
    """Build the zero-shot classification pipeline.

    Args:
        zeroShotLables: candidate labels for the zero-shot classifier.
            ``None`` (the default) falls back to ``['']``, preserving the
            original default behavior while avoiding a mutable default
            argument shared across calls.

    Returns:
        A (not yet fitted) ``pyspark.ml.Pipeline`` with three stages:
        document assembler -> tokenizer -> XLM-RoBERTa zero-shot classifier.
    """
    # Sentinel instead of a mutable list default (Python anti-pattern).
    if zeroShotLables is None:
        zeroShotLables = ['']

    document_assembler = DocumentAssembler() \
        .setInputCol('text') \
        .setOutputCol('document')

    tokenizer = Tokenizer() \
        .setInputCols(['document']) \
        .setOutputCol('token')

    # Multilingual XLM-RoBERTa model fine-tuned on XNLI/ANLI; labels are
    # supplied at inference time via setCandidateLabels.
    zeroShotClassifier = XlmRoBertaForZeroShotClassification \
        .pretrained('xlm_roberta_large_zero_shot_classifier_xnli_anli', 'xx') \
        .setInputCols(['token', 'document']) \
        .setOutputCol('class') \
        .setCaseSensitive(False) \
        .setMaxSentenceLength(512) \
        .setCandidateLabels(zeroShotLables)

    return Pipeline(stages=[document_assembler, tokenizer, zeroShotClassifier])
def fit_data(pipeline, data):
    """Fit *pipeline* on an empty frame and annotate *data*.

    Relies on the module-level ``spark`` session created near the bottom of
    the script. Fitting on a one-row empty DataFrame only materializes the
    pipeline; the actual text is processed by the LightPipeline.

    Returns:
        The list produced by ``LightPipeline.fullAnnotate(data)``.
    """
    empty_df = spark.createDataFrame([['']]).toDF('text')
    light = LightPipeline(pipeline.fit(empty_df))
    return light.fullAnnotate(data)
def annotate(data):
    """Render a document with its NER chunks highlighted inline.

    Args:
        data: mapping with keys ``"Document"`` (full text), ``"NER Chunk"``
            (list of matched substrings, in document order) and
            ``"NER Label"`` (parallel list of labels).

    Side effect: calls ``annotated_text`` to draw the result in Streamlit.
    """
    document = data["Document"]
    chunks = data["NER Chunk"]
    labels = data["NER Label"]

    annotated_words = []
    for chunk, label in zip(chunks, labels):
        # Bug fix: the original used document.split(chunk, 1) and indexed
        # parts[1], which raises IndexError when a chunk does not occur in
        # the remaining text. str.partition never raises; an absent chunk
        # yields an empty separator and is skipped.
        before, found, after = document.partition(chunk)
        if not found:
            continue
        if before:
            annotated_words.append(before)
        annotated_words.append((chunk, label))
        document = after

    if document:
        annotated_words.append(document)

    annotated_text(*annotated_words)
# Per-task metadata driving the sidebar: available pretrained model names and
# a short user-facing description of each task. Only one task is defined here.
tasks_models_descriptions = {
    "Zero-Shot Classification": {
        "models": ["xlm_roberta_large_zero_shot_classifier_xnli_anli"],
        "description": "The 'xlm_roberta_large_zero_shot_classifier_xnli_anli' model provides flexible text classification without needing training data for specific categories. It is ideal for dynamic scenarios where text needs to be categorized into topics like urgent issues, technology, or sports without prior labeling."
    }
}
# ---------- Sidebar ----------
task = st.sidebar.selectbox("Choose the task", list(tasks_models_descriptions.keys()))
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    tasks_models_descriptions[task]["models"],
    help="For more info about the models visit: https://sparknlp.org/models",
)

# Reference notebook link in sidebar
link = """
<a href="https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/357691d18373d6e8f13b5b1015137a398fd0a45f/Spark_NLP_Udemy_MOOC/Open_Source/17.01.Transformers-based_Embeddings.ipynb#L103">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
</a>
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)

# ---------- Page header ----------
# Bug fix: the title previously read "DeBERTa for {task}", but the pipeline
# loads an XLM-RoBERTa model (xlm_roberta_large_zero_shot_classifier_xnli_anli).
title, sub_title = (f'XLM-RoBERTa for {task}', tasks_models_descriptions[task]["description"])
st.markdown(f'<div class="main-title">{title}</div>', unsafe_allow_html=True)
container = st.container(border=True)
container.write(sub_title)
# Example sentences per task; the trailing comment on each line names the
# label the sentence is expected to match.
examples_mapping = {
    "Zero-Shot Classification" : [
        "In today’s world, staying updated with urgent information is crucial as events can unfold rapidly and require immediate attention.", # Urgent
        "Mobile technology has become indispensable, allowing us to access news, updates, and connect with others no matter where we are.", # Mobile
        "For those who love to travel, the convenience of mobile apps has transformed how we plan and experience trips, providing real-time updates on flights, accommodations, and local attractions.", # Travel
        "The entertainment industry continually offers new movies that captivate audiences with their storytelling and visuals, providing a wide range of genres to suit every taste.", # Movie
        "Music is an integral part of modern life, with streaming platforms making it easy to discover new artists and enjoy favorite tunes anytime, anywhere.", # Music
        "Sports enthusiasts follow games and matches closely, with live updates and detailed statistics available at their fingertips, enhancing the excitement of every game.", # Sport
        "Weather forecasts play a vital role in daily planning, offering accurate and timely information to help us prepare for various weather conditions and adjust our plans accordingly.", # Weather
        "Technology continues to evolve rapidly, driving innovation across all sectors and improving our everyday lives through smarter devices, advanced software, and enhanced connectivity." # Technology
    ]
}
# Example picker for the currently selected task.
examples = examples_mapping[task]
selected_text = st.selectbox("Select an example", examples)
custom_input = st.text_input("Try it with your own Sentence!")

if task == 'Zero-Shot Classification':
    # Default candidate labels; the st_tags widget lets the user edit the set.
    zeroShotLables = ["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"]
    # Bug fix: the widget's return value was previously assigned to an unused
    # variable (`lables`), so user edits never reached the classifier. Feed it
    # back into the list used to build the pipeline below.
    zeroShotLables = st_tags(
        label='Select labels',
        text='Press enter to add more',
        value=zeroShotLables,
        suggestions=[
            "Positive", "Negative", "Neutral",
            "Urgent", "Mobile", "Travel", "Movie", "Music", "Sport", "Weather", "Technology",
            "Happiness", "Sadness", "Anger", "Fear", "Surprise", "Disgust",
            "Informational", "Navigational", "Transactional", "Commercial Investigation",
            "Politics", "Business", "Sports", "Entertainment", "Health", "Science",
            "Product Quality", "Delivery Experience", "Customer Service", "Pricing", "Return Policy",
            "Education", "Finance", "Lifestyle", "Fashion", "Food", "Art", "History",
            "Culture", "Environment", "Real Estate", "Automotive", "Travel", "Fitness", "Career"],
        maxtags=-1)

try:
    # Prefer the user's own sentence; fall back to the selected example.
    text_to_analyze = custom_input if custom_input else selected_text
    st.subheader('Full example text')
    HTML_WRAPPER = """<div class="scroll entities" style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem; white-space:pre-wrap">{}</div>"""
    st.markdown(HTML_WRAPPER.format(text_to_analyze), unsafe_allow_html=True)
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate;
    # on any rendering failure fall back to the example sentence.
    text_to_analyze = selected_text

# Initialize Spark and create pipeline, then classify the chosen text.
spark = init_spark()
pipeline = create_pipeline(zeroShotLables)
output = fit_data(pipeline, text_to_analyze)

# Display the predicted class for the document.
st.subheader("Prediction:")
st.markdown(f"Document Classified as: **{output[0]['class'][0].result}**")