asaduzzaman607 commited on
Commit
887060c
·
1 Parent(s): 999de44

Add application file

Browse files
Files changed (1) hide show
  1. app.py +354 -0
app.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ import pickle
4
+ from gradio import Progress
5
+ import numpy as np
6
+ import subprocess
7
+ import shutil
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.metrics import roc_curve, auc
10
+ import pandas as pd
11
+ # Define the function to process the input file and model selection
12
+
13
+ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
14
+ # progress = gr.Progress(track_tqdm=True)
15
+
16
+ progress(0, desc="Starting the processing")
17
+ # with open(file.name, 'r') as f:
18
+ # content = f.read()
19
+ # saved_test_dataset = "train.txt"
20
+ # saved_test_label = "train_label.txt"
21
+ # saved_train_info="train_info.txt"
22
+ # Save the uploaded file content to a specified location
23
+ # shutil.copyfile(file.name, saved_test_dataset)
24
+ # shutil.copyfile(label.name, saved_test_label)
25
+ # shutil.copyfile(info.name, saved_train_info)
26
+ parent_location="ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
27
+ if(model_name=="High Graduated Schools"):
28
+ finetune_task="highGRschool10"
29
+ test_info_location=parent_location+"highGRschool10/test_info.txt"
30
+ test_location=parent_location+"highGRschool10/test.txt"
31
+ elif(model_name== "Low Graduated Schools" ):
32
+ finetune_task="lowGRschoolAll"
33
+ test_info_location=parent_location+"lowGRschoolAll/test_info.txt"
34
+ test_location=parent_location+"lowGRschoolAll/test.txt"
35
+ elif(model_name=="Full Set"):
36
+ test_info_location=parent_location+"highGRschool10/test_info.txt"
37
+ test_location=parent_location+"highGRschool10/test.txt"
38
+ finetune_task="highGRschool10"
39
+ else:
40
+ finetune_task=None
41
+ # Load the test_info file and the graduation rate file
42
+ test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
43
+ grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data
44
+
45
+ # Step 1: Extract unique school numbers from test_info
46
+ unique_schools = test_info[0].unique()
47
+
48
+ # Step 2: Filter the grad_rate_data using the unique school numbers
49
+ schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]
50
+
51
+ # Define a threshold for high and low graduation rates (adjust as needed)
52
+ grad_rate_threshold = 0.9
53
+
54
+ # Step 4: Divide schools into high and low graduation rate groups
55
+ high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()
56
+ low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()
57
+
58
+ # Step 5: Sample percentage of schools from each group
59
+ high_sample = pd.Series(high_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
60
+ low_sample = pd.Series(low_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
61
+
62
+ # Step 6: Combine the sampled schools
63
+ random_schools = high_sample + low_sample
64
+
65
+ # Step 7: Get indices for the sampled schools
66
+ indices = test_info[test_info[0].isin(random_schools)].index.tolist()
67
+
68
+ # Load the test file and select rows based on indices
69
+ test = pd.read_csv(test_location, sep=',', header=None, engine='python')
70
+ selected_rows_df2 = test.loc[indices]
71
+
72
+ # Save the selected rows to a file
73
+ selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')
74
+
75
+
76
+ # For demonstration purposes, we'll just return the content with the selected model name
77
+
78
+ # print(checkpoint)
79
+ progress(0.1, desc="Files created and saved")
80
+ # if (inc_val<5):
81
+ # model_name="highGRschool10"
82
+ # elif(inc_val>=5 & inc_val<10):
83
+ # model_name="highGRschool10"
84
+ # else:
85
+ # model_name="highGRschool10"
86
+ progress(0.2, desc="Executing models")
87
+ subprocess.run([
88
+ "python", "new_test_saved_finetuned_model.py",
89
+ "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
90
+ "-finetune_task", finetune_task,
91
+ "-test_dataset_path","../../../../selected_rows.txt",
92
+ # "-test_label_path","../../../../train_label.txt",
93
+ "-finetuned_bert_classifier_checkpoint",
94
+ "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
95
+ "-e",str(1),
96
+ "-b",str(1000)
97
+ ])
98
+ progress(0.6,desc="Model execution completed")
99
+ result = {}
100
+ with open("result.txt", 'r') as file:
101
+ for line in file:
102
+ key, value = line.strip().split(': ', 1)
103
+ # print(type(key))
104
+ if key=='epoch':
105
+ result[key]=value
106
+ else:
107
+ result[key]=float(value)
108
+ # Create a plot
109
+ with open("roc_data.pkl", "rb") as f:
110
+ fpr, tpr, _ = pickle.load(f)
111
+
112
+ roc_auc = auc(fpr, tpr)
113
+ fig, ax = plt.subplots()
114
+ ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
115
+ ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
116
+ ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'ROC Curve: {model_name}')
117
+ ax.legend(loc="lower right")
118
+ ax.grid()
119
+
120
+ # Save plot to a file
121
+ plot_path = "plot.png"
122
+ fig.savefig(plot_path)
123
+ plt.close(fig)
124
+ progress(1.0)
125
+ # Prepare text output
126
+ text_output = f"Model: {model_name}\nResult:\n{result}"
127
+ # Prepare text output with HTML formatting
128
+ text_output = f"""
129
+ Model: {model_name}\n
130
+ Result Summary:\n
131
+ -----------------\n
132
+ Precision: {result['precisions']:.2f}\n
133
+ Recall: {result['recalls']:.2f}\n
134
+ Time Taken: {result['time_taken_from_start']:.2f} seconds\n
135
+ Total Schools in test: {len(unique_schools):.4f}\n
136
+ Total Schools taken: {len(random_schools):.4f}\n
137
+ High grad schools: {len(high_sample):.4f}\n
138
+ Low grad schools: {len(low_sample):.4f}\n
139
+ -----------------\n
140
+ Note: The ROC Curve is also displayed for the evaluation.
141
+ """
142
+ return text_output,plot_path
143
+
144
+ # List of models for the dropdown menu
145
+
146
+ models = ["High Graduated Schools", "Low Graduated Schools", "Full Set"]
147
+
148
+ # Create the Gradio interface
149
+ with gr.Blocks(css="""
150
+ body {
151
+ background-color: #1e1e1e!important;
152
+ font-family: 'Arial', sans-serif;
153
+ color: #f5f5f5!important;;
154
+ }
155
+ .gradio-container {
156
+ max-width: 850px!important;
157
+ margin: 0 auto!important;;
158
+ padding: 20px!important;;
159
+ background-color: #292929!important;
160
+ border-radius: 10px;
161
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
162
+ }
163
+ .gradio-container-4-44-0 .prose h1 {
164
+ font-size: var(--text-xxl);
165
+ color: #ffffff!important;
166
+ }
167
+ #title {
168
+ color: white!important;
169
+ font-size: 2.3em;
170
+ font-weight: bold;
171
+ text-align: center!important;
172
+ margin-bottom: 20px;
173
+ }
174
+ .description {
175
+ text-align: center;
176
+ font-size: 1.1em;
177
+ color: #bfbfbf;
178
+ margin-bottom: 30px;
179
+ }
180
+ .file-box {
181
+ max-width: 180px;
182
+ padding: 5px;
183
+ background-color: #444!important;
184
+ border: 1px solid #666!important;
185
+ border-radius: 6px;
186
+ height: 80px!important;;
187
+ margin: 0 auto!important;;
188
+ text-align: center;
189
+ color: transparent;
190
+ }
191
+ .file-box span {
192
+ color: #f5f5f5!important;
193
+ font-size: 1em;
194
+ line-height: 45px; /* Vertically center text */
195
+ }
196
+ .dropdown-menu {
197
+ max-width: 220px;
198
+ margin: 0 auto!important;
199
+ background-color: #444!important;
200
+ color:#444!important;
201
+ border-radius: 6px;
202
+ padding: 8px;
203
+ font-size: 1.1em;
204
+ border: 1px solid #666;
205
+ }
206
+ .button {
207
+ background-color: #4CAF50!important;
208
+ color: white!important;
209
+ font-size: 1.1em;
210
+ padding: 10px 25px;
211
+ border-radius: 6px;
212
+ cursor: pointer;
213
+ transition: background-color 0.2s ease-in-out;
214
+ }
215
+ .button:hover {
216
+ background-color: #45a049!important;
217
+ }
218
+ .output-text {
219
+ background-color: #333!important;
220
+ padding: 12px;
221
+ border-radius: 8px;
222
+ border: 1px solid #666;
223
+ font-size: 1.1em;
224
+ }
225
+ .footer {
226
+ text-align: center;
227
+ margin-top: 50px;
228
+ font-size: 0.9em;
229
+ color: #b0b0b0;
230
+ }
231
+ .svelte-12ioyct .wrap {
232
+ display: none !important;
233
+ }
234
+ .file-label-text {
235
+ display: none !important;
236
+ }
237
+
238
+ div.svelte-sfqy0y {
239
+ display: flex;
240
+ flex-direction: inherit;
241
+ flex-wrap: wrap;
242
+ gap: var(--form-gap-width);
243
+ box-shadow: var(--block-shadow);
244
+ border: var(--block-border-width) solid var(--border-color-primary);
245
+ border-radius: var(--block-radius);
246
+ background: #1f2937!important;
247
+ overflow-y: hidden;
248
+ }
249
+
250
+ .block.svelte-12cmxck {
251
+ position: relative;
252
+ margin: 0;
253
+ box-shadow: var(--block-shadow);
254
+ border-width: var(--block-border-width);
255
+ border-color: var(--block-border-color);
256
+ border-radius: var(--block-radius);
257
+ background: #1f2937!important;
258
+ width: 100%;
259
+ line-height: var(--line-sm);
260
+ }
261
+
262
+ .svelte-12ioyct .wrap {
263
+ display: none !important;
264
+ }
265
+ .file-label-text {
266
+ display: none !important;
267
+ }
268
+ input[aria-label="file upload"] {
269
+ display: none !important;
270
+ }
271
+
272
+ gradio-app .gradio-container.gradio-container-4-44-0 .contain .file-box span {
273
+ font-size: 1em;
274
+ line-height: 45px;
275
+ color: #1f2937 !important;
276
+ }
277
+ .wrap.svelte-12ioyct {
278
+ display: flex;
279
+ flex-direction: column;
280
+ justify-content: center;
281
+ align-items: center;
282
+ min-height: var(--size-60);
283
+ color: #1f2937 !important;
284
+ line-height: var(--line-md);
285
+ height: 100%;
286
+ padding-top: var(--size-3);
287
+ text-align: center;
288
+ margin: auto var(--spacing-lg);
289
+ }
290
+ span.svelte-1gfkn6j:not(.has-info) {
291
+ margin-bottom: var(--spacing-lg);
292
+ color: white!important;
293
+ }
294
+ label.float.svelte-1b6s6s {
295
+ position: relative!important;
296
+ top: var(--block-label-margin);
297
+ left: var(--block-label-margin);
298
+ }
299
+ label.svelte-1b6s6s {
300
+ display: inline-flex;
301
+ align-items: center;
302
+ z-index: var(--layer-2);
303
+ box-shadow: var(--block-label-shadow);
304
+ border: var(--block-label-border-width) solid var(--border-color-primary);
305
+ border-top: none;
306
+ border-left: none;
307
+ border-radius: var(--block-label-radius);
308
+ background: rgb(120 151 180)!important;
309
+ padding: var(--block-label-padding);
310
+ pointer-events: none;
311
+ color: #1f2937!important;
312
+ font-weight: var(--block-label-text-weight);
313
+ font-size: var(--block-label-text-size);
314
+ line-height: var(--line-sm);
315
+ }
316
+ .file.svelte-18wv37q.svelte-18wv37q {
317
+ display: block!important;
318
+ width: var(--size-full);
319
+ }
320
+
321
+ tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
322
+ background: ##7897b4!important;
323
+ color: white;
324
+ background: #aca7b2;
325
+ }
326
+ .gradio-container-4-31-4 .prose h1, .gradio-container-4-31-4 .prose h2, .gradio-container-4-31-4 .prose h3, .gradio-container-4-31-4 .prose h4, .gradio-container-4-31-4 .prose h5 {
327
+
328
+ color: white;
329
+ """) as demo:
330
+ gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
331
+ gr.Markdown("<p class='description'>Upload a .txt file and select a model from the dropdown menu.</p>")
332
+
333
+ with gr.Row():
334
+ # file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
335
+ # label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
336
+
337
+ # info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
338
+
339
+ model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task", elem_classes="dropdown-menu")
340
+
341
+
342
+ increment_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Schools Percentage", value=1)
343
+
344
+ with gr.Row():
345
+ output_text = gr.Textbox(label="Output Text")
346
+ output_image = gr.Image(label="Output Plot")
347
+
348
+ btn = gr.Button("Submit")
349
+
350
+ btn.click(fn=process_file, inputs=[model_dropdown,increment_slider], outputs=[output_text,output_image])
351
+
352
+
353
+ # Launch the app
354
+ demo.launch()