ashhadahsan commited on
Commit
ca3e8cf
·
1 Parent(s): f821f11

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -340
app.py DELETED
@@ -1,340 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- from transformers import pipeline
4
- from stqdm import stqdm
5
- from simplet5 import SimpleT5
6
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
- from transformers import BertTokenizer, TFBertForSequenceClassification
8
- from tensorflow.keras.models import load_model
9
- from tensorflow.nn import softmax
10
- import numpy as np
11
- from datetime import datetime
12
- import logging
13
- from constants import sub_themes_dict
14
-
15
# Date stamp (YYYY-MM-DD) baked into the generated download file names.
date = datetime.now().strftime(r"%Y-%m-%d")

# Classifier output index -> human-readable category name.
model_classes = dict(
    enumerate(
        [
            "Ads",
            "Apps",
            "Battery",
            "Charging",
            "Delivery",
            "Display",
            "FOS",
            "HW",
            "Order",
            "Refurb",
            "SD",
            "Setup",
            "Unknown",
            "WiFi",
        ]
    )
)
32
-
33
-
34
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_t5():
    """Load and cache the stock ``t5-base`` seq2seq model with its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` for ``t5-base``.
    """
    # NOTE(review): st.cache is deprecated in recent Streamlit releases;
    # st.cache_resource is the modern replacement — left unchanged to avoid
    # forcing a Streamlit version bump.
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
    return seq2seq, t5_tokenizer
41
-
42
-
43
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def custom_model():
    """Build and cache a summarization pipeline backed by the locally
    fine-tuned model stored under ``my_awesome_sum/``."""
    summarizer = pipeline("summarization", model="my_awesome_sum/")
    return summarizer
47
-
48
-
49
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def convert_df(df):
    """Serialize ``df`` to UTF-8 CSV bytes.

    Cached so the conversion is not recomputed on every Streamlit rerun.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
54
-
55
-
56
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` weights into ``model``.

    ``model`` is expected to be a SimpleT5 instance; the loaded state is
    cached across reruns.
    """
    return model.load_model("t5", "snrspeaks/t5-one-line-summary")
60
-
61
-
62
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def classify_category():
    """Return ``(tokenizer, model)`` for the top-level category classifier.

    The tokenizer is stock ``bert-base-uncased``; the Keras model is loaded
    from the local ``model`` directory.
    """
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    keras_model = load_model("model")
    return bert_tokenizer, keras_model
68
-
69
-
70
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def classify_sub_theme():
    """Return ``(tokenizer, model)`` for the fine-tuned sub-theme classifier."""
    repo = "ashhadahsan/amazon-subtheme-bert-base-finetuned"
    sub_tokenizer = BertTokenizer.from_pretrained(repo)
    sub_model = TFBertForSequenceClassification.from_pretrained(repo)
    return sub_tokenizer, sub_model
80
-
81
-
82
# --- Page chrome and input widgets -----------------------------------------
st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
st.title("Amazon Review Summarizer")

# Spreadsheet of reviews; must contain a "text" column (checked downstream).
uploaded_file = st.file_uploader("Choose a file", type=["xlsx", "xls", "csv"])
summarizer_option = st.selectbox(
    "Select Summarizer",
    ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
)
col1, col2, col3 = st.columns([1, 1, 1])

with col1:
    # FIX: corrected label typo ("Summrization" -> "Summarization").
    summary_yes = st.checkbox("Summarization", value=False)

with col2:
    classification = st.checkbox("Classify Category", value=True)

with col3:
    sub_theme = st.checkbox("Sub theme classification", value=True)

# Placeholder slot (reserved; not written to elsewhere in this file).
ps = st.empty()
102
-
103
if st.button("Process", type="primary"):
    # One cancel slot per summarizer branch, matching the original widget keys.
    cancel_button = st.empty()
    cancel_button2 = st.empty()
    cancel_button3 = st.empty()
    if uploaded_file is not None:
        # --- Load the uploaded spreadsheet -------------------------------
        extension = uploaded_file.name.split(".")[-1].lower()
        if extension in ("xls", "xlsx"):
            df = pd.read_excel(uploaded_file, engine="openpyxl")
        elif extension == "csv":
            # BUG FIX: the original compared the extension against ".csv",
            # but split(".")[-1] yields "csv" (no dot), so CSV uploads were
            # never read and `df` stayed undefined.
            df = pd.read_csv(uploaded_file)
        # Normalize column names so the "text" lookup below is case-proof.
        df.columns = [x.lower() for x in df.columns.values.tolist()]
        print(summarizer_option)
        output = pd.DataFrame()

        def _run_classifier(loader, label_map, spinner_label, texts):
            """Tokenize `texts`, run a TF BERT classifier, and map the argmax
            of each row's logits through `label_map` to a label name."""
            clf_tokenizer, clf_model = loader()
            batch = clf_tokenizer(
                texts,
                max_length=128,
                padding=True,
                truncation=True,
                return_tensors="tf",
            )
            with st.spinner(text=spinner_label):
                outputs = clf_model(batch)
            names = []
            with st.spinner(text="creating output file"):
                for idx in stqdm(range(len(texts))):
                    probs = softmax(outputs["logits"][idx], axis=-1)
                    names.append(label_map.get(np.argmax(probs, axis=0)))
            return names

        try:
            # Raises KeyError (handled below) when no "text" column exists.
            text = df["text"].values.tolist()
            output["text"] = text

            # --- Optional summarization, one branch per selected model ---
            if summary_yes:
                summary = []
                if summarizer_option == "Custom trained on the dataset":
                    model = custom_model()
                    for x in stqdm(range(len(text))):
                        if cancel_button.button("Cancel", key=x):
                            break
                        try:
                            summary.append(
                                model(
                                    f"summarize: {text[x]}",
                                    max_length=50,
                                    early_stopping=True,
                                )[0]["summary_text"]
                            )
                        except Exception:
                            # Best-effort: skip rows the summarizer chokes on.
                            pass
                    # FIX: original del'd `model` both on cancel and after the
                    # loop, raising NameError on cancellation; delete once.
                    del model
                elif summarizer_option == "t5-base":
                    model, tokenizer = load_t5()
                    for x in stqdm(range(len(text))):
                        if cancel_button2.button("Cancel", key=x):
                            break
                        tokens_input = tokenizer.encode(
                            "summarize: " + text[x],
                            return_tensors="pt",
                            max_length=tokenizer.model_max_length,
                            truncation=True,
                        )
                        summary_ids = model.generate(
                            tokens_input,
                            min_length=80,
                            max_length=150,
                            length_penalty=20,
                            num_beams=2,
                        )
                        summary.append(
                            tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                        )
                    del model, tokenizer
                elif summarizer_option == "t5-one-line-summary":
                    model = SimpleT5()
                    load_one_line_summarizer(model=model)
                    for x in stqdm(range(len(text))):
                        if cancel_button3.button("Cancel", key=x):
                            break
                        try:
                            summary.append(model.predict(text[x])[0])
                        except Exception:
                            pass
                    del model
                output["summary"] = summary

            # --- Classification passes (identical across all branches in the
            # original; deduplicated here — exactly one branch ever ran) ---
            if classification:
                output["category"] = _run_classifier(
                    classify_category, model_classes, "identifying theme", text
                )
            if sub_theme:
                output["sub theme"] = _run_classifier(
                    classify_sub_theme, sub_themes_dict, "identifying sub theme", text
                )

            csv = convert_df(output)
            st.download_button(
                label="Download data as CSV",
                data=csv,
                file_name=f"{summarizer_option}_{date}_df.csv",
                mime="text/csv",
            )

        except KeyError:
            st.error(
                "Please Make sure that your data must have a column named text",
                icon="🚨",
            )
            st.info("Text column must have amazon reviews", icon="ℹ️")
        # FIX: narrowed from BaseException (which swallowed KeyboardInterrupt/
        # SystemExit) and corrected the log message grammar.
        except Exception:
            logging.exception("An exception occurred")