import streamlit as st
import sparknlp
import os
import pandas as pd
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
from annotated_text import annotated_text
from streamlit_tags import st_tags
# Page configuration
st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto"
)

# CSS for styling
st.markdown("""
""", unsafe_allow_html=True)
@st.cache_resource
def init_spark():
    """Start the Spark NLP session once and cache it across Streamlit reruns."""
    return sparknlp.start()

@st.cache_resource
def create_pipeline(zeroShotLabels=['']):
    """Build a Spark NLP pipeline: document assembly, tokenization, then
    XLM-RoBERTa zero-shot classification against the given candidate labels."""
    document_assembler = DocumentAssembler() \
        .setInputCol('text') \
        .setOutputCol('document')

    tokenizer = Tokenizer() \
        .setInputCols(['document']) \
        .setOutputCol('token')

    zeroShotClassifier = XlmRoBertaForZeroShotClassification \
        .pretrained('xlm_roberta_large_zero_shot_classifier_xnli_anli', 'xx') \
        .setInputCols(['token', 'document']) \
        .setOutputCol('class') \
        .setCaseSensitive(False) \
        .setMaxSentenceLength(512) \
        .setCandidateLabels(zeroShotLabels)

    pipeline = Pipeline(stages=[document_assembler, tokenizer, zeroShotClassifier])
    return pipeline
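
# A minimal usage sketch (hypothetical labels and text; in the app the
# candidate labels come from user input):
# pipeline = create_pipeline(['Urgent', 'Technology', 'Sports'])
# results = fit_data(pipeline, 'The match went to penalties after extra time.')
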
def fit_data(pipeline, data):
    """Fit the pipeline on an empty DataFrame (the model is pretrained), then
    annotate the input text with a LightPipeline for fast single-document
    inference. Assumes `spark` is the session returned by init_spark()."""
    empty_df = spark.createDataFrame([['']]).toDF('text')
    pipeline_model = pipeline.fit(empty_df)
    model = LightPipeline(pipeline_model)
    result = model.fullAnnotate(data)
    return result
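
# fullAnnotate returns one dictionary per input text; the predicted label is
# available under the 'class' output column (a sketch, assuming a single
# input string):
# prediction = fit_data(pipeline, text)[0]['class'][0].result
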
def annotate(data):
    """Render the document with annotated_text, highlighting each NER chunk
    with its label while keeping the surrounding text plain."""
    document, chunks, labels = data["Document"], data["NER Chunk"], data["NER Label"]
    annotated_words = []
    for chunk, label in zip(chunks, labels):
        # Split off the text before the first occurrence of the chunk, emit it
        # as plain text, then emit the chunk itself as a (text, label) pair.
        parts = document.split(chunk, 1)
        if parts[0]:
            annotated_words.append(parts[0])
        annotated_words.append((chunk, label))
        document = parts[1]
    if document:
        annotated_words.append(document)
    annotated_text(*annotated_words)
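
# Expected input shape for annotate (hypothetical values):
# annotate({"Document": "Apple was founded by Steve Jobs.",
#           "NER Chunk": ["Apple", "Steve Jobs"],
#           "NER Label": ["ORG", "PER"]})
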
tasks_models_descriptions = {
    "Zero-Shot Classification": {
        "models": ["xlm_roberta_large_zero_shot_classifier_xnli_anli"],
        "description": "The 'xlm_roberta_large_zero_shot_classifier_xnli_anli' model classifies text into arbitrary, user-defined categories without requiring any training data for those categories. It is well suited to dynamic scenarios where text must be sorted into topics such as urgent issues, technology, or sports without prior labeling."
    }
}

# Sidebar content
task = st.sidebar.selectbox("Choose the task", list(tasks_models_descriptions.keys()))
model = st.sidebar.selectbox("Choose the pretrained model", tasks_models_descriptions[task]["models"], help="For more info about the models visit: https://sparknlp.org/models")
# Reference notebook link in sidebar
link = """
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)
# Page content
title, sub_title = (f'XLM-RoBERTa for {task}', tasks_models_descriptions[task]["description"])
st.markdown(f'