Wootang01's picture
Update app.py
46170b1
raw
history blame
2.92 kB
import time
import streamlit as st
from annotated_text import annotated_text
from flair.data import Sentence
from flair.models import SequenceTagger
# Flair checkpoints offered in the "Choose model" select box.
checkpoints = ["flair/pos-english"]

# Highlight colour (hex string) for each part-of-speech tag, keyed by the tag
# value. Tags that are not listed here have no colour assigned.
colors = {
    'ADD': '#b9d9a6',
    'AFX': '#eddc92',
    'CC': '#95e9d7',
    'CD': '#e797db',
    'DT': '#9ff48b',
    'EX': '#ed92b4',
    'FW': '#decfa1',
    'HYPH': '#ada7d7',
    'IN': '#85fad8',
    'JJ': '#8ba4f4',
    'JJR': '#e7a498',
    'JJS': '#e5c79a',
    'LS': '#eb94b6',
    'MD': '#e698ae',
    'NFP': '#d9d1a6',
    'NN': '#96e89f',
    'NNP': '#e698c6',
    'NNPS': '#ddbfa2',
    'NNS': '#f788cd',
    'PDT': '#f19c8d',
    'POS': '#8ed5f0',
    'PRP': '#c4d8a6',
    'PRP$': '#e39bdc',
    'RB': '#8df1e2',
    'RBR': '#d7f787',
    'RBS': '#f986f0',
    'RP': '#878df8',
    'SYM': '#83fe80',
    'TO': '#a6d8c9',
    'UH': '#d9a6cc',
    'VB': '#a1deda',
    'VBD': '#8fefe1',
    'VBG': '#e3c79b',
    'VBN': '#fb81fe',
    'VBP': '#d5fe81',
    'VBZ': '#8084ff',
    'WDT': '#dd80fe',
    'WP': '#9ce3e3',
    'WP$': '#9fbddf',
    'WRB': '#dea1b5',
    'XX': '#93b8ec',
}
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name):
    """Load and return the flair ``SequenceTagger`` named ``model_name``.

    The function is wrapped in ``st.cache`` so Streamlit can reuse the
    returned tagger instead of calling ``SequenceTagger.load`` again for the
    same ``model_name``.
    """
    tagger = SequenceTagger.load(model_name)
    return tagger
def getPos(s: Sentence):
    """Collect token texts and their tag values from a tagged sentence.

    For every token in ``s.tokens`` and every annotation layer present on
    that token, one entry is produced: the token's text and the value of the
    first label in that layer. A token with several layers therefore appears
    several times; a token with no layers does not appear at all.

    Returns:
        A ``(texts, labels)`` pair of equally long lists, aligned by index.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers.keys()
    ]
    texts = [text for text, _ in pairs]
    labels = [tag for _, tag in pairs]
    return texts, labels
def getDictFromPOS(texts, labels):
    """Pair each text with its label as a ``{"text": ..., "label": ...}`` dict.

    ``texts`` and ``labels`` are walked in lockstep; if their lengths differ,
    the extra items of the longer one are ignored.

    Returns:
        A list of dicts, one per (text, label) pair, in input order.
    """
    return [dict(text=word, label=tag) for word, tag in zip(texts, labels)]
def getAnnotatedFromPOS(texts, labels, palette=None, default_color="#dddddd"):
    """Build ``(text, label, color)`` tuples for ``annotated_text``.

    ``texts`` and ``labels`` are walked in lockstep (extra items of the longer
    one are ignored). Each label is looked up in the colour palette; a label
    with no entry gets ``default_color`` instead of raising ``KeyError``. The
    module palette has no entries for punctuation tags such as ``,`` or ``.``,
    so without the fallback any text containing them could not be rendered.

    Args:
        texts: Token texts.
        labels: Tag value for each text.
        palette: Optional mapping of tag -> colour string. Defaults to the
            module-level ``colors`` table.
        default_color: Colour used for tags missing from the palette.

    Returns:
        A list of ``(text, label, color)`` tuples in input order.
    """
    if palette is None:
        palette = colors
    return [
        (word, tag, palette.get(tag, default_color))
        for word, tag in zip(texts, labels)
    ]
def main():
    """Render the page: intro, model picker, text box, and tagging results.

    When "Submit" is pressed the input text is wrapped in a ``Sentence``,
    tagged with the selected model, and shown twice: as colour-annotated text
    and as JSON. Any exception raised while rendering the results is reported
    with ``st.error`` and the script run is stopped. After a successful
    submit the elapsed time since the button press is printed.
    """
    st.title("Part of Speech Categorizer")
    st.write("Paste or type text, submit and the machine will attempt to identify parts of speech. Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods.")

    # Model selection and text entry.
    tagger = get_model(st.selectbox("Choose model", checkpoints))
    input_text = st.text_area(
        label="Original text",
        value="Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods",
    )

    # Stays None unless the button was pressed during this run.
    started_at = None
    if st.button("Submit"):
        started_at = time.time()

        with st.spinner("Computing"):
            sentence = Sentence(input_text)
            tagger.predict(sentence)

        # NOTE(review): the original indentation was lost; this assumes the
        # result rendering ran after the spinner block rather than inside it.
        try:
            texts, labels = getPos(sentence)

            st.header("Labels:")
            annotated_text(*getAnnotatedFromPOS(texts, labels))

            st.header("JSON:")
            st.json(getDictFromPOS(texts, labels))
        except Exception as err:
            st.error("Some error occured!" + str(err))
            st.stop()

    st.write("---")

    if started_at is None:
        return
    st.text(f"prediction took {time.time() - started_at:.2f}s")
# Build the page only when this file is executed as the main module.
if __name__ == "__main__":
    main()