File size: 2,921 Bytes
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1bcc63
46170b1
5c1dd23
 
 
a89e61f
5c1dd23
 
 
 
 
 
e1bcc63
5c1dd23
e1bcc63
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import time
import streamlit as st
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger

# Model identifiers offered in the "Choose model" select box.
checkpoints = ["flair/pos-english"]

# Highlight colour (hex) for each part-of-speech tag, keyed by the tag name.
colors = {
    'ADD': '#b9d9a6',
    'AFX': '#eddc92',
    'CC': '#95e9d7',
    'CD': '#e797db',
    'DT': '#9ff48b',
    'EX': '#ed92b4',
    'FW': '#decfa1',
    'HYPH': '#ada7d7',
    'IN': '#85fad8',
    'JJ': '#8ba4f4',
    'JJR': '#e7a498',
    'JJS': '#e5c79a',
    'LS': '#eb94b6',
    'MD': '#e698ae',
    'NFP': '#d9d1a6',
    'NN': '#96e89f',
    'NNP': '#e698c6',
    'NNPS': '#ddbfa2',
    'NNS': '#f788cd',
    'PDT': '#f19c8d',
    'POS': '#8ed5f0',
    'PRP': '#c4d8a6',
    'PRP$': '#e39bdc',
    'RB': '#8df1e2',
    'RBR': '#d7f787',
    'RBS': '#f986f0',
    'RP': '#878df8',
    'SYM': '#83fe80',
    'TO': '#a6d8c9',
    'UH': '#d9a6cc',
    'VB': '#a1deda',
    'VBD': '#8fefe1',
    'VBG': '#e3c79b',
    'VBN': '#fb81fe',
    'VBP': '#d5fe81',
    'VBZ': '#8084ff',
    'WDT': '#dd80fe',
    'WP': '#9ce3e3',
    'WP$': '#9fbddf',
    'WRB': '#dea1b5',
    'XX': '#93b8ec',
}

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
  """Load and return the flair ``SequenceTagger`` identified by ``model_name``.

  The function is decorated with ``st.cache`` so that Streamlit can reuse the
  loaded tagger instead of loading it again for the same ``model_name``.
  """
  # NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour of
  # ``st.cache_resource`` -- confirm against the deployed Streamlit version.
  tagger = SequenceTagger.load(model_name)
  return tagger
  
def getPos(s: Sentence):
  """Collect token texts and their predicted label values from ``s``.

  For every token in ``s.tokens`` and every annotation layer present on that
  token, one entry is produced: the token text and the value of the first
  label in that layer.

  Returns:
    A ``(texts, labels)`` tuple of two parallel lists.
  """
  pairs = [
    (token.text, token.get_labels(layer)[0].value)
    for token in s.tokens
    for layer in token.annotation_layers
  ]
  # Split into two lists explicitly so an empty sentence still yields ([], []).
  texts = [word for word, _ in pairs]
  labels = [tag for _, tag in pairs]
  return texts, labels
  
def getDictFromPOS(texts, labels):
  """Pair each text with its label as a ``{"text": ..., "label": ...}`` dict.

  ``texts`` and ``labels`` are walked in lockstep; if their lengths differ the
  result stops at the shorter one.

  Returns:
    A list of dicts, one per (text, label) pair, in input order.
  """
  records = []
  for word, tag in zip(texts, labels):
    records.append({"text": word, "label": tag})
  return records

def getAnnotatedFromPOS(texts, labels, default_color="#dddddd"):
  """Build ``(text, label, color)`` tuples for ``annotated_text``.

  Args:
    texts: Token texts.
    labels: Tag for each text, parallel to ``texts``.
    default_color: Colour used for any tag that has no entry in the
      module-level ``colors`` table (for example punctuation tags), so an
      unknown tag no longer raises ``KeyError``.

  Returns:
    A list of ``(text, label, color)`` tuples, one per (text, label) pair,
    in input order.
  """
  return [
    (word, tag, colors.get(tag, default_color))
    for word, tag in zip(texts, labels)
  ]

def main():
  """Render the part-of-speech tagging page.

  Shows the title, a model selector and a text area. When "Submit" is
  clicked, the text is tagged with the selected model and the result is shown
  both as colour-annotated text and as JSON, followed by the elapsed time.
  Any exception raised while tagging or rendering is reported with
  ``st.error`` and the script run is stopped.
  """
  st.title("Part of Speech Categorizer")
  st.write("Paste or type text, submit and the machine will attempt to identify parts of speech. Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods.")
  checkpoint = st.selectbox("Choose model", checkpoints)
  model = get_model(checkpoint)

  default_text = "Please note that although the machine can read apostrophes it cannot read other punctuation marks such as commas or periods"
  input_text = st.text_area(
    label="Original text",
    value=default_text,
  )

  # Nothing below runs unless the Submit button was clicked.
  if not st.button("Submit"):
    return

  # Skip tagging when there is no actual text to tag.
  if not input_text.strip():
    st.warning("Please enter some text to tag.")
    return

  # perf_counter is monotonic, so the reported duration cannot go negative.
  start = time.perf_counter()
  with st.spinner("Computing"):
    try:
      # Build the sentence and predict tags inside the try block so that
      # tokenisation/model failures are reported like any other error.
      s = Sentence(input_text)
      model.predict(s)

      texts, labels = getPos(s)

      st.header("Labels:")
      annotated_text(*getAnnotatedFromPOS(texts, labels))

      st.header("JSON:")
      st.json(getDictFromPOS(texts, labels))
    except Exception as e:
      # Top-level UI boundary: surface the failure to the user and halt.
      st.error(f"Some error occurred! {e}")
      st.stop()

  st.write("---")
  st.text(f"prediction took {time.perf_counter() - start:.2f}s")
	
	
# Entry point: build the page only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()