# Spaces: Runtime error (status text captured with this source; not part of the program)
import time | |
import streamlit as st | |
from annotated_text import annotated_text | |
from flair.data import Sentence | |
from flair.models import SequenceTagger | |
# Model identifiers offered in the "Choose model" select box.
checkpoints = [
    "flair/pos-english",
]

# Highlight color (hex string) for each part-of-speech tag label.
colors = {
    "ADD": "#b9d9a6",
    "AFX": "#eddc92",
    "CC": "#95e9d7",
    "CD": "#e797db",
    "DT": "#9ff48b",
    "EX": "#ed92b4",
    "FW": "#decfa1",
    "HYPH": "#ada7d7",
    "IN": "#85fad8",
    "JJ": "#8ba4f4",
    "JJR": "#e7a498",
    "JJS": "#e5c79a",
    "LS": "#eb94b6",
    "MD": "#e698ae",
    "NFP": "#d9d1a6",
    "NN": "#96e89f",
    "NNP": "#e698c6",
    "NNPS": "#ddbfa2",
    "NNS": "#f788cd",
    "PDT": "#f19c8d",
    "POS": "#8ed5f0",
    "PRP": "#c4d8a6",
    "PRP$": "#e39bdc",
    "RB": "#8df1e2",
    "RBR": "#d7f787",
    "RBS": "#f986f0",
    "RP": "#878df8",
    "SYM": "#83fe80",
    "TO": "#a6d8c9",
    "UH": "#d9a6cc",
    "VB": "#a1deda",
    "VBD": "#8fefe1",
    "VBG": "#e3c79b",
    "VBN": "#fb81fe",
    "VBP": "#d5fe81",
    "VBZ": "#8084ff",
    "WDT": "#dd80fe",
    "WP": "#9ce3e3",
    "WP$": "#9fbddf",
    "WRB": "#dea1b5",
    "XX": "#93b8ec",
}
def get_model(model_name):
    """Load and return the Flair sequence tagger named ``model_name``.

    Args:
        model_name: Identifier passed straight to ``SequenceTagger.load``
            (for example one of the entries in ``checkpoints``).

    Returns:
        The loaded ``SequenceTagger`` instance.
    """
    # NOTE(review): main() calls this on every run, so the model is loaded
    # again each time. Consider Streamlit's resource caching if reloads prove
    # slow -- confirm the installed Streamlit version supports it first.
    return SequenceTagger.load(model_name)
def getPos(s: Sentence):
    """Collect token texts and their predicted labels from a tagged sentence.

    For every token, one ``(text, label)`` pair is recorded per annotation
    layer present on that token, using the first label of that layer. Tokens
    without any annotation layer contribute nothing.

    Args:
        s: A sentence whose tokens have already been tagged.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists of strings.
    """
    texts = []
    labels = []
    for token in s.tokens:
        for layer in token.annotation_layers.keys():
            texts.append(token.text)
            # Only the first label of each layer is kept.
            labels.append(token.get_labels(layer)[0].value)
    return texts, labels
def getDictFromPOS(texts, labels):
    """Pair token texts with labels as a list of JSON-friendly dicts.

    Args:
        texts: Iterable of token strings.
        labels: Iterable of tag labels, parallel to ``texts``.

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts, one per pair. Pairing
        uses ``zip``, so the result stops at the shorter input.
    """
    return [
        {"text": text, "label": label}
        for text, label in zip(texts, labels)
    ]
def getAnnotatedFromPOS(texts, labels, default_color="#d3d3d3"):
    """Build ``(text, label, color)`` tuples for ``annotated_text``.

    Args:
        texts: Iterable of token strings.
        labels: Iterable of tag labels, parallel to ``texts``.
        default_color: Color used for any label that has no entry in
            ``colors``.

    Returns:
        A list of ``(text, label, color)`` tuples, one per pair. Pairing uses
        ``zip``, so the result stops at the shorter input.
    """
    # A plain ``colors[label]`` lookup raises KeyError for any tag missing
    # from the table, which aborts rendering of the whole sentence.
    # NOTE(review): presumably this is what happened for punctuation tags,
    # given the UI note about commas and periods -- confirm against the
    # model's tag set.
    return [
        (text, label, colors.get(label, default_color))
        for text, label in zip(texts, labels)
    ]
def main():
    """Render the Streamlit page and tag the submitted text.

    Shows a model selector and a text area. When "Submit" is pressed, the
    text is wrapped in a ``Sentence``, tagged with the selected model, and
    displayed both as color-annotated text and as JSON, followed by the
    elapsed time.
    """
    st.title("Part of Speech Categorizer")
    st.write("Paste or type text, submit and the machine will attempt to identify parts of speech. Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods.")

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)

    default_text = "Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    # Stays None unless a prediction was requested in this run.
    start = None
    if st.button("Submit"):
        start = time.time()
        with st.spinner("Computing"):
            # Sentence construction and prediction sit inside the try so a
            # tagging failure is reported through st.error as well, rather
            # than escaping as an unhandled exception.
            try:
                # Build Sentence
                s = Sentence(input_text)
                # predict tags
                model.predict(s)

                texts, labels = getPos(s)

                st.header("Labels:")
                anns = getAnnotatedFromPOS(texts, labels)
                annotated_text(*anns)

                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))
            except Exception as e:
                # Top-level UI boundary: show the failure and end this run.
                st.error(f"Some error occurred! {e}")
                st.stop()

    st.write("---")
    if start is not None:
        # Elapsed time covers everything since "Submit", rendering included.
        st.text(f"prediction took {time.time() - start:.2f}s")
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()