ashhadahsan committed on
Commit
3d51c8f
·
1 Parent(s): 5521456

added sub theme

Browse files
Files changed (1) hide show
  1. app.py +162 -80
app.py CHANGED
@@ -13,9 +13,10 @@ import logging
13
  import pip
14
 
15
 
 
16
 
17
  date = datetime.now().strftime(r"%Y-%m-%d")
18
- model_classes ={
19
  0: "Ads",
20
  1: "Apps",
21
  2: "Battery",
@@ -32,7 +33,8 @@ model_classes ={
32
  13: "WiFi",
33
  }
34
 
35
- @st.cache(allow_output_mutation=True,suppress_st_warning=True)
 
36
  def load_t5():
37
  model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
38
 
@@ -40,29 +42,36 @@ def load_t5():
40
  return model, tokenizer
41
 
42
 
43
- @st.cache(allow_output_mutation=True,suppress_st_warning=True)
44
  def custom_model():
45
  return pipeline("summarization", model="my_awesome_sum/")
46
 
47
 
48
- @st.cache(allow_output_mutation=True,suppress_st_warning=True)
49
  def convert_df(df):
50
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
51
  return df.to_csv(index=False).encode("utf-8")
52
 
53
 
54
- @st.cache(allow_output_mutation=True,suppress_st_warning=True)
55
  def load_one_line_summarizer(model):
56
  return model.load_model("t5", "snrspeaks/t5-one-line-summary")
57
 
58
 
59
- @st.cache(allow_output_mutation=True,suppress_st_warning=True)
60
  def classify_category():
61
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
62
  new_model = load_model("model")
63
  return tokenizer, new_model
64
 
65
 
 
 
 
 
 
 
 
66
  st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
67
  st.title("Amazon Review Summarizer")
68
 
@@ -71,14 +80,23 @@ summarizer_option = st.selectbox(
71
  "Select Summarizer",
72
  ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
73
  )
74
- classification = st.checkbox("Classify Category", value=True)
 
 
 
 
 
 
 
 
 
75
 
76
  ps = st.empty()
77
-
78
- if st.button("Process",type="primary"):
79
- cancel_button=st.empty()
80
- cancel_button2=st.empty()
81
- cancel_button3=st.empty()
82
  if uploaded_file is not None:
83
  if uploaded_file.name.split(".")[-1] in ["xls", "xlsx"]:
84
 
@@ -89,33 +107,34 @@ if st.button("Process",type="primary"):
89
  columns = [x.lower() for x in columns]
90
  df.columns = columns
91
  print(summarizer_option)
92
-
93
  try:
94
  text = df["text"].values.tolist()
 
95
  if summarizer_option == "Custom trained on the dataset":
96
- model = custom_model()
97
-
98
- progress_text = "Summarization in progress. Please wait."
99
- summary = []
100
-
101
- for x in stqdm(range(len(text))):
102
-
103
- if cancel_button.button("Cancel",key=x):
104
- del model
105
- break
106
- try:
107
- summary.append(
108
- model(
109
- f"summarize: {text[x]}",
110
- max_length=50,
111
- early_stopping=True,
112
- )[0]["summary_text"]
113
- )
114
- except:
115
- pass
116
- output = pd.DataFrame(
117
- {"text": df["text"].values.tolist(), "summary": summary}
118
- )
119
  if classification:
120
  classification_token, classification_model = classify_category()
121
  tf_batch = classification_token(
@@ -135,6 +154,27 @@ if st.button("Process",type="primary"):
135
  keys = model_classes
136
  classes.append(keys.get(label))
137
  output["category"] = classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  csv = convert_df(output)
140
  st.download_button(
@@ -144,34 +184,34 @@ if st.button("Process",type="primary"):
144
  mime="text/csv",
145
  )
146
  if summarizer_option == "t5-base":
147
- model, tokenizer = load_t5()
148
- summary = []
149
- for x in stqdm(range(len(text))):
150
-
151
- if cancel_button2.button("Cancel",key=x):
152
- del model,tokenizer
153
- break
154
- tokens_input = tokenizer.encode(
155
- "summarize: " + text[x],
156
- return_tensors="pt",
157
- max_length=tokenizer.model_max_length,
158
- truncation=True,
159
- )
160
- summary_ids = model.generate(
161
- tokens_input,
162
- min_length=80,
163
- max_length=150,
164
- length_penalty=20,
165
- num_beams=2,
166
- )
167
- summary_gen = tokenizer.decode(
168
- summary_ids[0], skip_special_tokens=True
169
- )
170
- summary.append(summary_gen)
 
 
 
171
 
172
- output = pd.DataFrame(
173
- {"text": df["text"].values.tolist(), "summary": summary}
174
- )
175
  if classification:
176
  classification_token, classification_model = classify_category()
177
  tf_batch = classification_token(
@@ -191,7 +231,27 @@ if st.button("Process",type="primary"):
191
  keys = model_classes
192
  classes.append(keys.get(label))
193
  output["category"] = classes
194
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  csv = convert_df(output)
196
  st.download_button(
197
  label="Download data as CSV",
@@ -201,21 +261,22 @@ if st.button("Process",type="primary"):
201
  )
202
 
203
  if summarizer_option == "t5-one-line-summary":
204
- model = SimpleT5()
205
- load_one_line_summarizer(model=model)
206
-
207
- summary = []
208
- for x in stqdm(range(len(text))):
209
- if cancel_button3.button("Cancel",key=x):
210
- del model
211
- break
212
- try:
213
- summary.append(model.predict(text[x])[0])
214
- except:
215
- pass
216
- output = pd.DataFrame(
217
- {"text": df["text"].values.tolist(), "summary": summary}
218
- )
 
219
  if classification:
220
  classification_token, classification_model = classify_category()
221
  tf_batch = classification_token(
@@ -235,6 +296,27 @@ if st.button("Process",type="primary"):
235
  keys = model_classes
236
  classes.append(keys.get(label))
237
  output["category"] = classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  csv = convert_df(output)
240
  st.download_button(
@@ -251,4 +333,4 @@ if st.button("Process",type="primary"):
251
  )
252
  st.info("Text column must have amazon reviews", icon="ℹ️")
253
  except BaseException as e:
254
- logging.exception("An exception was occurred")
 
13
  import pip
14
 
15
 
16
+ import gc
17
 
18
  date = datetime.now().strftime(r"%Y-%m-%d")
19
+ model_classes = {
20
  0: "Ads",
21
  1: "Apps",
22
  2: "Battery",
 
33
  13: "WiFi",
34
  }
35
 
36
+
37
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
38
  def load_t5():
39
  model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
40
 
 
42
  return model, tokenizer
43
 
44
 
45
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def custom_model():
    """Build the summarization pipeline backed by the ``my_awesome_sum/`` model.

    Returns:
        The object produced by ``pipeline("summarization", ...)``. The script
        calls it with a prompt string and reads ``[0]["summary_text"]`` from
        the result.
    """
    # Decorated with st.cache, so the pipeline is not rebuilt on every call.
    summarizer = pipeline("summarization", model="my_awesome_sum/")
    return summarizer
48
 
49
 
50
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def convert_df(df):
    """Serialize ``df`` to CSV (index column omitted) as UTF-8 encoded bytes.

    Args:
        df: Object exposing ``to_csv``; the script passes a pandas DataFrame.

    Returns:
        The CSV text encoded with UTF-8.
    """
    # IMPORTANT: cached so the conversion is not recomputed on every rerun.
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
54
 
55
 
56
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` T5 weights through ``model``.

    Args:
        model: Object exposing ``load_model(model_type, model_name)``; the
            script passes a freshly constructed ``SimpleT5()``.

    Returns:
        Whatever ``model.load_model`` returns.
    """
    # NOTE(review): the useful effect here appears to be the loading done on
    # ``model`` itself, and the caller ignores the return value. Because the
    # function is cached, confirm that a new ``model`` instance passed on a
    # later rerun is still loaded rather than served from the cache.
    load_result = model.load_model("t5", "snrspeaks/t5-one-line-summary")
    return load_result
59
 
60
 
61
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def classify_category():
    """Load the tokenizer and the model used for category classification.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-uncased``
        ``BertTokenizer`` and the result of ``load_model("model")``.
    """
    # Build both pieces in one expression; the tokenizer is still created
    # first, exactly as before.
    return (
        BertTokenizer.from_pretrained("bert-base-uncased"),
        load_model("model"),
    )
66
 
67
 
68
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def classify_sub_theme():
    """Load the tokenizer and the model used for sub-theme classification.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-uncased``
        ``BertTokenizer`` and the result of ``load_model("sub_theme_model")``.
    """
    # Same tokenizer as the category classifier; only the model path differs.
    sub_theme_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    sub_theme_model = load_model("sub_theme_model")
    return sub_theme_tokenizer, sub_theme_model
73
+
74
+
75
  st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
76
  st.title("Amazon Review Summarizer")
77
 
 
80
  "Select Summarizer",
81
  ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
82
  )
83
+ col1, col2, col3 = st.columns([1, 1, 1])
84
+
85
+ with col1:
86
+ summary_yes = st.checkbox("Summrization", value=False)
87
+
88
+ with col2:
89
+ classification = st.checkbox("Classify Category", value=True)
90
+
91
+ with col3:
92
+ sub_theme = st.checkbox("Sub theme classification", value=True)
93
 
94
  ps = st.empty()
95
+
96
+ if st.button("Process", type="primary"):
97
+ cancel_button = st.empty()
98
+ cancel_button2 = st.empty()
99
+ cancel_button3 = st.empty()
100
  if uploaded_file is not None:
101
  if uploaded_file.name.split(".")[-1] in ["xls", "xlsx"]:
102
 
 
107
  columns = [x.lower() for x in columns]
108
  df.columns = columns
109
  print(summarizer_option)
110
+ output = pd.DataFrame()
111
  try:
112
  text = df["text"].values.tolist()
113
+ output["text"] = text
114
  if summarizer_option == "Custom trained on the dataset":
115
+ if summary_yes:
116
+ model = custom_model()
117
+
118
+ progress_text = "Summarization in progress. Please wait."
119
+ summary = []
120
+
121
+ for x in stqdm(range(len(text))):
122
+
123
+ if cancel_button.button("Cancel", key=x):
124
+ del model
125
+ break
126
+ try:
127
+ summary.append(
128
+ model(
129
+ f"summarize: {text[x]}",
130
+ max_length=50,
131
+ early_stopping=True,
132
+ )[0]["summary_text"]
133
+ )
134
+ except:
135
+ pass
136
+ output["summary"] = summary
137
+ del model
138
  if classification:
139
  classification_token, classification_model = classify_category()
140
  tf_batch = classification_token(
 
154
  keys = model_classes
155
  classes.append(keys.get(label))
156
  output["category"] = classes
157
+ del classification_token, classification_model
158
+ if sub_theme:
159
+ classification_token, classification_model = classify_sub_theme()
160
+ tf_batch = classification_token(
161
+ text,
162
+ max_length=128,
163
+ padding=True,
164
+ truncation=True,
165
+ return_tensors="tf",
166
+ )
167
+ with st.spinner(text="identifying sub theme"):
168
+ tf_outputs = classification_model(tf_batch)
169
+ classes = []
170
+ with st.spinner(text="creating output file"):
171
+ for x in stqdm(range(len(text))):
172
+ tf_o = softmax(tf_outputs["logits"][x], axis=-1)
173
+ label = np.argmax(tf_o, axis=0)
174
+ keys = model_classes
175
+ classes.append(keys.get(label))
176
+ output["sub theme"] = classes
177
+ del classification_token, classification_model
178
 
179
  csv = convert_df(output)
180
  st.download_button(
 
184
  mime="text/csv",
185
  )
186
  if summarizer_option == "t5-base":
187
+ if summary_yes:
188
+ model, tokenizer = load_t5()
189
+ summary = []
190
+ for x in stqdm(range(len(text))):
191
+
192
+ if cancel_button2.button("Cancel", key=x):
193
+ del model, tokenizer
194
+ break
195
+ tokens_input = tokenizer.encode(
196
+ "summarize: " + text[x],
197
+ return_tensors="pt",
198
+ max_length=tokenizer.model_max_length,
199
+ truncation=True,
200
+ )
201
+ summary_ids = model.generate(
202
+ tokens_input,
203
+ min_length=80,
204
+ max_length=150,
205
+ length_penalty=20,
206
+ num_beams=2,
207
+ )
208
+ summary_gen = tokenizer.decode(
209
+ summary_ids[0], skip_special_tokens=True
210
+ )
211
+ summary.append(summary_gen)
212
+ del model, tokenizer
213
+ output["summary"] = summary
214
 
 
 
 
215
  if classification:
216
  classification_token, classification_model = classify_category()
217
  tf_batch = classification_token(
 
231
  keys = model_classes
232
  classes.append(keys.get(label))
233
  output["category"] = classes
234
+ del classification_token, classification_model
235
+ if sub_theme:
236
+ classification_token, classification_model = classify_sub_theme()
237
+ tf_batch = classification_token(
238
+ text,
239
+ max_length=128,
240
+ padding=True,
241
+ truncation=True,
242
+ return_tensors="tf",
243
+ )
244
+ with st.spinner(text="identifying sub theme"):
245
+ tf_outputs = classification_model(tf_batch)
246
+ classes = []
247
+ with st.spinner(text="creating output file"):
248
+ for x in stqdm(range(len(text))):
249
+ tf_o = softmax(tf_outputs["logits"][x], axis=-1)
250
+ label = np.argmax(tf_o, axis=0)
251
+ keys = model_classes
252
+ classes.append(keys.get(label))
253
+ output["sub theme"] = classes
254
+ del classification_token, classification_model
255
  csv = convert_df(output)
256
  st.download_button(
257
  label="Download data as CSV",
 
261
  )
262
 
263
  if summarizer_option == "t5-one-line-summary":
264
+ if summary_yes:
265
+ model = SimpleT5()
266
+ load_one_line_summarizer(model=model)
267
+
268
+ summary = []
269
+ for x in stqdm(range(len(text))):
270
+ if cancel_button3.button("Cancel", key=x):
271
+ del model
272
+ break
273
+ try:
274
+ summary.append(model.predict(text[x])[0])
275
+ except:
276
+ pass
277
+ output["summary"] = summary
278
+ del model
279
+
280
  if classification:
281
  classification_token, classification_model = classify_category()
282
  tf_batch = classification_token(
 
296
  keys = model_classes
297
  classes.append(keys.get(label))
298
  output["category"] = classes
299
+ del classification_token,classification_model
300
+ if sub_theme:
301
+ classification_token, classification_model = classify_sub_theme()
302
+ tf_batch = classification_token(
303
+ text,
304
+ max_length=128,
305
+ padding=True,
306
+ truncation=True,
307
+ return_tensors="tf",
308
+ )
309
+ with st.spinner(text="identifying sub theme"):
310
+ tf_outputs = classification_model(tf_batch)
311
+ classes = []
312
+ with st.spinner(text="creating output file"):
313
+ for x in stqdm(range(len(text))):
314
+ tf_o = softmax(tf_outputs["logits"][x], axis=-1)
315
+ label = np.argmax(tf_o, axis=0)
316
+ keys = model_classes
317
+ classes.append(keys.get(label))
318
+ output["sub theme"] = classes
319
+ del classification_token, classification_model
320
 
321
  csv = convert_df(output)
322
  st.download_button(
 
333
  )
334
  st.info("Text column must have amazon reviews", icon="ℹ️")
335
  except BaseException as e:
336
+ logging.exception("An exception was occurred")