ajitrajasekharan commited on
Commit
a1dc9c7
·
1 Parent(s): 562b4f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -120,20 +120,26 @@ def perform_inference(text,display_area):
120
  display_area.text("Initializing BERT module...")
121
  st.session_state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")
122
 
123
-
124
-
 
125
 
126
 
127
  display_area.text("Getting predictions from BERT model...")
128
  phi_results = st.session_state['phi_model'].get_descriptors(text,pos_arr)
129
  display_area.text("Computing NER results...")
130
 
 
131
  phi_ner = st.session_state['ner_phi'].tag_sentence_service(text,phi_results)
 
 
 
 
132
 
133
- return phi_ner
134
 
135
 
136
  sent_arr = [
 
137
  "John Doe flew from New York to Rio De Janiro ",
138
  "In 2020, John participated in the Winter Olympics and came third in Ice hockey",
139
  "Stanford called",
@@ -149,6 +155,7 @@ sent_arr = [
149
 
150
 
151
  sent_arr_masked = [
 
152
  "John:__entity__ Doe:__entity__ flew from New:__entity__ York:__entity__ to Rio:__entity__ De:__entity__ Janiro:__entity__ ",
153
  "In 2020:__entity__, Catherine:__entity__ Zeta:__entity__ Jones:__entity__ participated in the Winter:__entity__ Olympics:__entity__ and came third in Ice:__entity__ hockey:__entity__",
154
  "Stanford:__entity__ called",
 
120
  display_area.text("Initializing BERT module...")
121
  st.session_state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")
122
 
123
+ if (st.session_state['aggr'] is None):
124
+ display_area.text("Initializing Aggregation modeule...")
125
+ st.session_state['aggr'] = aggr.AggregateNER("./ensemble_config.json")
126
 
127
 
128
  display_area.text("Getting predictions from BERT model...")
129
  phi_results = st.session_state['phi_model'].get_descriptors(text,pos_arr)
130
  display_area.text("Computing NER results...")
131
 
132
+ display_area.text("Consolidating responses...")
133
  phi_ner = st.session_state['ner_phi'].tag_sentence_service(text,phi_results)
134
+ obj = phi_ner
135
+ combined_arr = [obj,obj]
136
+ aggregate_results = st.session_state['aggr'].fetch_all(text,combined_arr)
137
+ return aggregate_results
138
 
 
139
 
140
 
141
  sent_arr = [
142
+ "Washington who resigned from Washington flew to Washington",
143
  "John Doe flew from New York to Rio De Janiro ",
144
  "In 2020, John participated in the Winter Olympics and came third in Ice hockey",
145
  "Stanford called",
 
155
 
156
 
157
  sent_arr_masked = [
158
+ "Washington:__entity__ who resigned from Washington:__entity__ flew to Washington:__entity__",
159
  "John:__entity__ Doe:__entity__ flew from New:__entity__ York:__entity__ to Rio:__entity__ De:__entity__ Janiro:__entity__ ",
160
  "In 2020:__entity__, Catherine:__entity__ Zeta:__entity__ Jones:__entity__ participated in the Winter:__entity__ Olympics:__entity__ and came third in Ice:__entity__ hockey:__entity__",
161
  "Stanford:__entity__ called",