vaivskku commited on
Commit
0eab365
ยท
1 Parent(s): 1d46afa
Files changed (1) hide show
  1. app.py +935 -0
app.py ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor
3
+ from PIL import Image
4
+ import torch
5
+ import warnings
6
+ import re
7
+ import json
8
+ import os
9
+ import numpy as np
10
+ import pandas as pd
11
+ from tqdm import tqdm
12
+ import argparse
13
+ from scipy import optimize
14
+ from typing import Optional
15
+ import dataclasses
16
+ import editdistance
17
+ import itertools
18
+ import sys
19
+ import time
20
+ import logging
21
+
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger()
24
+
25
+ warnings.filterwarnings('ignore')
26
+ MAX_PATCHES = 512
27
+ # Load the models and processor
28
+ #device = torch.device("cpu")
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ # Paths to the models
32
+ ko_deplot_model_path = './deplot_model_ver_kor_24.7.25_refinetuning_epoch1.bin'
33
+ aihub_deplot_model_path='./deplot_k.pt'
34
+ t5_model_path = './ke_t5.pt'
35
+
36
+ # Load first model ko-deplot
37
+ processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
38
+ model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
39
+ model1.load_state_dict(torch.load(ko_deplot_model_path, map_location=device))
40
+ model1.to(device)
41
+
42
+ # Load second model aihub-deplot
43
+ processor2 = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
44
+ model2 = Pix2StructForConditionalGeneration.from_pretrained("ybelkada/pix2struct-base")
45
+ model2.load_state_dict(torch.load(aihub_deplot_model_path, map_location=device))
46
+
47
+
48
+ tokenizer = T5Tokenizer.from_pretrained("KETI-AIR/ke-t5-base")
49
+ t5_model = T5ForConditionalGeneration.from_pretrained("KETI-AIR/ke-t5-base")
50
+ t5_model.load_state_dict(torch.load(t5_model_path, map_location=device))
51
+
52
+ model2.to(device)
53
+ t5_model.to(device)
54
+
55
+
56
+ #ko-deplot ์ถ”๋ก ํ•จ์ˆ˜
57
+ # Function to format output
58
+ def format_output(prediction):
59
+ return prediction.replace('<0x0A>', '\n')
60
+
61
+ # First model prediction ko-deplot
62
+ def predict_model1(image):
63
+ images = [image]
64
+ inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
65
+ inputs = {k: v.to(device) for k, v in inputs.items()} # Move to GPU
66
+
67
+ model1.eval()
68
+ with torch.no_grad():
69
+ predictions = model1.generate(**inputs, max_new_tokens=4096)
70
+ outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]
71
+
72
+ formatted_output = format_output(outputs[0])
73
+ return formatted_output
74
+
75
+
76
+ def replace_unk(text):
77
+ # 1. '์ œ๋ชฉ:', '์œ ํ˜•:' ๊ธ€์ž ์•ž์— ์žˆ๋Š” <unk>๋Š” \n๋กœ ๋ฐ”๊ฟˆ
78
+ text = re.sub(r'<unk>(?=์ œ๋ชฉ:|์œ ํ˜•:)', '\n', text)
79
+ # 2. '์„ธ๋กœ ' ๋˜๋Š” '๊ฐ€๋กœ '์™€ '๋Œ€ํ˜•' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ ""๋กœ ๋ฐ”๊ฟˆ
80
+ text = re.sub(r'(?<=์„ธ๋กœ |๊ฐ€๋กœ )<unk>(?=๋Œ€ํ˜•)', '', text)
81
+ # 3. ์ˆซ์ž์™€ ํ…์ŠคํŠธ ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
82
+ text = re.sub(r'(\d)<unk>([^\d])', r'\1\n\2', text)
83
+ # 4. %, ์›, ๊ฑด, ๋ช… ๋’ค์— ๋‚˜์˜ค๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
84
+ text = re.sub(r'(?<=[%์›๊ฑด๋ช…\)])<unk>', '\n', text)
85
+ # 5. ์ˆซ์ž์™€ ์ˆซ์ž ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
86
+ text = re.sub(r'(\d)<unk>(\d)', r'\1\n\2', text)
87
+ # 6. 'ํ˜•'์ด๋ผ๋Š” ๊ธ€์ž์™€ ' |' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
88
+ text = re.sub(r'ํ˜•<unk>(?= \|)', 'ํ˜•\n', text)
89
+ # 7. ๋‚˜๋จธ์ง€ <unk>๋ฅผ ๋ชจ๋‘ ""๋กœ ๋ฐ”๊ฟˆ
90
+ text = text.replace('<unk>', '')
91
+ return text
92
+
93
+ # Second model prediction aihub_deplot
94
+ def predict_model2(image):
95
+ image = image.convert("RGB")
96
+ inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
97
+
98
+ flattened_patches = inputs.flattened_patches.to(device)
99
+ attention_mask = inputs.attention_mask.to(device)
100
+
101
+ model2.eval()
102
+ t5_model.eval()
103
+ with torch.no_grad():
104
+ deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
105
+ generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
106
+ generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
107
+ refined_table = replace_unk(generated_datatable)
108
+ return refined_table
109
+
110
+ #function for converting aihub dataset labeling json file to ko-deplot data table
111
+ def process_json_file(input_file):
112
+ with open(input_file, 'r', encoding='utf-8') as file:
113
+ data = json.load(file)
114
+
115
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
116
+ chart_type = data['metadata']['chart_sub']
117
+ title = data['annotations'][0]['title']
118
+ x_axis = data['annotations'][0]['axis_label']['x_axis']
119
+ y_axis = data['annotations'][0]['axis_label']['y_axis']
120
+ legend = data['annotations'][0]['legend']
121
+ data_labels = data['annotations'][0]['data_label']
122
+ is_legend = data['annotations'][0]['is_legend']
123
+
124
+ # ์›ํ•˜๋Š” ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
125
+ formatted_string = f"TITLE | {title} <0x0A> "
126
+ if '๊ฐ€๋กœ' in chart_type:
127
+ if is_legend:
128
+ # ๊ฐ€๋กœ ์ฐจํŠธ ์ฒ˜๋ฆฌ
129
+ formatted_string += " | ".join(legend) + " <0x0A> "
130
+ for i in range(len(y_axis)):
131
+ row = [y_axis[i]]
132
+ for j in range(len(legend)):
133
+ if i < len(data_labels[j]):
134
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
135
+ else:
136
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
137
+ formatted_string += " | ".join(row) + " <0x0A> "
138
+ else:
139
+ # is_legend๊ฐ€ False์ธ ๊ฒฝ์šฐ
140
+ for i in range(len(y_axis)):
141
+ row = [y_axis[i], str(data_labels[0][i])]
142
+ formatted_string += " | ".join(row) + " <0x0A> "
143
+ elif chart_type == "์›ํ˜•":
144
+ # ์›ํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
145
+ if legend:
146
+ used_labels = legend
147
+ else:
148
+ used_labels = x_axis
149
+
150
+ formatted_string += " | ".join(used_labels) + " <0x0A> "
151
+ row = [data_labels[0][i] for i in range(len(used_labels))]
152
+ formatted_string += " | ".join(row) + " <0x0A> "
153
+ elif chart_type == "ํ˜ผํ•ฉํ˜•":
154
+ # ํ˜ผํ•ฉํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
155
+ all_legends = [ann['legend'][0] for ann in data['annotations']]
156
+ formatted_string += " | ".join(all_legends) + " <0x0A> "
157
+
158
+ combined_data = []
159
+ for i in range(len(x_axis)):
160
+ row = [x_axis[i]]
161
+ for ann in data['annotations']:
162
+ if i < len(ann['data_label'][0]):
163
+ row.append(str(ann['data_label'][0][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
164
+ else:
165
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
166
+ combined_data.append(" | ".join(row))
167
+
168
+ formatted_string += " <0x0A> ".join(combined_data) + " <0x0A> "
169
+ else:
170
+ # ๊ธฐํƒ€ ์ฐจํŠธ ์ฒ˜๋ฆฌ
171
+ if is_legend:
172
+ formatted_string += " | ".join(legend) + " <0x0A> "
173
+ for i in range(len(x_axis)):
174
+ row = [x_axis[i]]
175
+ for j in range(len(legend)):
176
+ if i < len(data_labels[j]):
177
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
178
+ else:
179
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
180
+ formatted_string += " | ".join(row) + " <0x0A> "
181
+ else:
182
+ for i in range(len(x_axis)):
183
+ if i < len(data_labels[0]):
184
+ formatted_string += f"{x_axis[i]} | {str(data_labels[0][i])} <0x0A> "
185
+ else:
186
+ formatted_string += f"{x_axis[i]} | <0x0A> " # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
187
+
188
+ # ๋งˆ์ง€๋ง‰ "<0x0A> " ์ œ๊ฑฐ
189
+ formatted_string = formatted_string[:-8]
190
+ return format_output(formatted_string)
191
+
192
+ def chart_data(data):
193
+ datatable = []
194
+ num = len(data)
195
+ for n in range(num):
196
+ title = data[n]['title'] if data[n]['is_title'] else ''
197
+ legend = data[n]['legend'] if data[n]['is_legend'] else ''
198
+ datalabel = data[n]['data_label'] if data[n]['is_datalabel'] else [0]
199
+ unit = data[n]['unit'] if data[n]['is_unit'] else ''
200
+ base = data[n]['base'] if data[n]['is_base'] else ''
201
+ x_axis_title = data[n]['axis_title']['x_axis']
202
+ y_axis_title = data[n]['axis_title']['y_axis']
203
+ x_axis = data[n]['axis_label']['x_axis'] if data[n]['is_axis_label_x_axis'] else [0]
204
+ y_axis = data[n]['axis_label']['y_axis'] if data[n]['is_axis_label_y_axis'] else [0]
205
+
206
+ if len(legend) > 1:
207
+ datalabel = np.array(datalabel).transpose().tolist()
208
+
209
+ datatable.append([title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis])
210
+
211
+ return datatable
212
+
213
+ def datatable(data, chart_type):
214
+ data_table = ''
215
+ num = len(data)
216
+
217
+ if len(data) == 2:
218
+ temp = []
219
+ temp.append(f"๋Œ€์ƒ: {data[0][4]}")
220
+ temp.append(f"์ œ๋ชฉ: {data[0][0]}")
221
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
222
+ temp.append(f"{data[0][5]} | {data[0][1][0]}({data[0][3]}) | {data[1][1][0]}({data[1][3]})")
223
+
224
+ x_axis = data[0][7]
225
+ for idx, x in enumerate(x_axis):
226
+ temp.append(f"{x} | {data[0][2][0][idx]} | {data[1][2][0][idx]}")
227
+
228
+ data_table = '\n'.join(temp)
229
+ else:
230
+ for n in range(num):
231
+ temp = []
232
+
233
+ title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis = data[n]
234
+ legend = [element + f"({unit})" for element in legend]
235
+
236
+ if len(legend) > 1:
237
+ temp.append(f"๋Œ€์ƒ: {base}")
238
+ temp.append(f"์ œ๋ชฉ: {title}")
239
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
240
+ temp.append(f"{x_axis_title} | {' | '.join(legend)}")
241
+
242
+ if chart_type[2] == "์›ํ˜•":
243
+ datalabel = sum(datalabel, [])
244
+ temp.append(f"{' | '.join([str(d) for d in datalabel])}")
245
+ data_table = '\n'.join(temp)
246
+ else:
247
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
248
+ for idx, (x, d) in enumerate(zip(axis, datalabel)):
249
+ temp_d = [str(e) for e in d]
250
+ temp_d = " | ".join(temp_d)
251
+ row = f"{x} | {temp_d}"
252
+ temp.append(row)
253
+ data_table = '\n'.join(temp)
254
+ else:
255
+ temp.append(f"๋Œ€์ƒ: {base}")
256
+ temp.append(f"์ œ๋ชฉ: {title}")
257
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
258
+ temp.append(f"{x_axis_title} | {unit}")
259
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
260
+ datalabel = datalabel[0]
261
+
262
+ for idx, x in enumerate(axis):
263
+ row = f"{x} | {str(datalabel[idx])}"
264
+ temp.append(row)
265
+ data_table = '\n'.join(temp)
266
+
267
+ return data_table
268
+
269
+ #function for converting aihub dataset labeling json file to aihub-deplot data table
270
+ def process_json_file2(input_file):
271
+ with open(input_file, 'r', encoding='utf-8') as file:
272
+ data = json.load(file)
273
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
274
+ chart_multi = data['metadata']['chart_multi']
275
+ chart_main = data['metadata']['chart_main']
276
+ chart_sub = data['metadata']['chart_sub']
277
+ chart_type = [chart_multi, chart_sub, chart_main]
278
+ chart_annotations = data['annotations']
279
+
280
+ charData = chart_data(chart_annotations)
281
+ dataTable = datatable(charData, chart_type)
282
+ return dataTable
283
+
284
+ # RMS
285
+ def _to_float(text): # ๋‹จ์œ„ ๋–ผ๊ณ  ์ˆซ์ž๋งŒ..?
286
+ try:
287
+ if text.endswith("%"):
288
+ # Convert percentages to floats.
289
+ return float(text.rstrip("%")) / 100.0
290
+ else:
291
+ return float(text)
292
+ except ValueError:
293
+ return None
294
+
295
+
296
+ def _get_relative_distance(
297
+ target, prediction, theta = 1.0
298
+ ):
299
+ """Returns min(1, |target-prediction|/|target|)."""
300
+ if not target:
301
+ return int(not prediction)
302
+ distance = min(abs((target - prediction) / target), 1)
303
+ return distance if distance < theta else 1
304
+
305
+ def anls_metric(target: str, prediction: str, theta: float = 0.5):
306
+ edit_distance = editdistance.eval(target, prediction)
307
+ normalize_ld = edit_distance / max(len(target), len(prediction))
308
+ return 1 - normalize_ld if normalize_ld < theta else 0
309
+
310
+ def _permute(values, indexes):
311
+ return tuple(values[i] if i < len(values) else "" for i in indexes)
312
+
313
+
314
+ @dataclasses.dataclass(frozen=True)
315
+ class Table:
316
+ """Helper class for the content of a markdown table."""
317
+
318
+ base: Optional[str] = None
319
+ title: Optional[str] = None
320
+ chartType: Optional[str] = None
321
+ headers: tuple[str, Ellipsis] = dataclasses.field(default_factory=tuple)
322
+ rows: tuple[tuple[str, Ellipsis], Ellipsis] = dataclasses.field(default_factory=tuple)
323
+
324
+ def permuted(self, indexes):
325
+ """Builds a version of the table changing the column order."""
326
+ return Table(
327
+ base=self.base,
328
+ title=self.title,
329
+ chartType=self.chartType,
330
+ headers=_permute(self.headers, indexes),
331
+ rows=tuple(_permute(row, indexes) for row in self.rows),
332
+ )
333
+
334
+ def aligned(
335
+ self, headers, text_theta = 0.5
336
+ ):
337
+ """Builds a column permutation with headers in the most correct order."""
338
+ if len(headers) != len(self.headers):
339
+ raise ValueError(f"Header length {headers} must match {self.headers}.")
340
+ distance = []
341
+ for h2 in self.headers:
342
+ distance.append(
343
+ [
344
+ 1 - anls_metric(h1, h2, text_theta)
345
+ for h1 in headers
346
+ ]
347
+ )
348
+ cost_matrix = np.array(distance)
349
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
350
+ permutation = [idx for _, idx in sorted(zip(col_ind, row_ind))]
351
+ score = (1 - cost_matrix)[permutation[1:], range(1, len(row_ind))].prod()
352
+ return self.permuted(permutation), score
353
+
354
+ def _parse_table(text, transposed = False): # ํ‘œ ์ œ๋ชฉ, ์—ด ์ด๋ฆ„, ํ–‰ ์ฐพ๊ธฐ
355
+ """Builds a table from a markdown representation."""
356
+ lines = text.lower().splitlines()
357
+ if not lines:
358
+ return Table()
359
+
360
+ if lines[0].startswith("๋Œ€์ƒ: "):
361
+ base = lines[0][len("๋Œ€์ƒ: ") :].strip()
362
+ offset = 1 #
363
+ else:
364
+ base = None
365
+ offset = 0
366
+ if lines[1].startswith("์ œ๋ชฉ: "):
367
+ title = lines[1][len("์ œ๋ชฉ: ") :].strip()
368
+ offset = 2 #
369
+ else:
370
+ title = None
371
+ offset = 1
372
+ if lines[2].startswith("์œ ํ˜•: "):
373
+ chartType = lines[2][len("์œ ํ˜•: ") :].strip()
374
+ offset = 3 #
375
+ else:
376
+ chartType = None
377
+
378
+ if len(lines) < offset + 1:
379
+ return Table(base=base, title=title, chartType=chartType)
380
+
381
+ rows = []
382
+ for line in lines[offset:]:
383
+ rows.append(tuple(v.strip() for v in line.split(" | ")))
384
+ if transposed:
385
+ rows = [tuple(row) for row in itertools.zip_longest(*rows, fillvalue="")]
386
+ return Table(base=base, title=title, chartType=chartType, headers=rows[0], rows=tuple(rows[1:]))
387
+
388
+ def _get_table_datapoints(table):
389
+ datapoints = {}
390
+ if table.base is not None:
391
+ datapoints["๋Œ€์ƒ"] = table.base
392
+ if table.title is not None:
393
+ datapoints["์ œ๋ชฉ"] = table.title
394
+ if table.chartType is not None:
395
+ datapoints["์œ ํ˜•"] = table.chartType
396
+ if not table.rows or len(table.headers) <= 1:
397
+ return datapoints
398
+ for row in table.rows:
399
+ for header, cell in zip(table.headers[1:], row[1:]):
400
+ #print(f"{row[0]} {header} >> {cell}")
401
+ datapoints[f"{row[0]} {header}"] = cell #
402
+ return datapoints
403
+
404
+ def _get_datapoint_metric( #
405
+ target,
406
+ prediction,
407
+ text_theta=0.5,
408
+ number_theta=0.1,
409
+ ):
410
+ """Computes a metric that scores how similar two datapoint pairs are."""
411
+ key_metric = anls_metric(
412
+ target[0], prediction[0], text_theta
413
+ )
414
+ pred_float = _to_float(prediction[1]) # ์ˆซ์ž์ธ์ง€ ํ™•์ธ
415
+ target_float = _to_float(target[1])
416
+ if pred_float is not None and target_float:
417
+ return key_metric * (
418
+ 1 - _get_relative_distance(target_float, pred_float, number_theta) # ์ˆซ์ž๋ฉด ์ƒ๋Œ€์  ๊ฑฐ๋ฆฌ๊ฐ’ ๊ณ„์‚ฐ
419
+ )
420
+ elif target[1] == prediction[1]:
421
+ return key_metric
422
+ else:
423
+ return key_metric * anls_metric(
424
+ target[1], prediction[1], text_theta
425
+ )
426
+
427
+ def _table_datapoints_precision_recall_f1( # ์ฐ ๊ณ„์‚ฐ
428
+ target_table,
429
+ prediction_table,
430
+ text_theta = 0.5,
431
+ number_theta = 0.1,
432
+ ):
433
+ """Calculates matching similarity between two tables as dicts."""
434
+ target_datapoints = list(_get_table_datapoints(target_table).items())
435
+ prediction_datapoints = list(_get_table_datapoints(prediction_table).items())
436
+ if not target_datapoints and not prediction_datapoints:
437
+ return 1, 1, 1
438
+ if not target_datapoints:
439
+ return 0, 1, 0
440
+ if not prediction_datapoints:
441
+ return 1, 0, 0
442
+ distance = []
443
+ for t, _ in target_datapoints:
444
+ distance.append(
445
+ [
446
+ 1 - anls_metric(t, p, text_theta)
447
+ for p, _ in prediction_datapoints
448
+ ]
449
+ )
450
+ cost_matrix = np.array(distance)
451
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
452
+ score = 0
453
+ for r, c in zip(row_ind, col_ind):
454
+ score += _get_datapoint_metric(
455
+ target_datapoints[r], prediction_datapoints[c], text_theta, number_theta
456
+ )
457
+ if score == 0:
458
+ return 0, 0, 0
459
+ precision = score / len(prediction_datapoints)
460
+ recall = score / len(target_datapoints)
461
+ return precision, recall, 2 * precision * recall / (precision + recall)
462
+
463
+ def table_datapoints_precision_recall_per_point( # ๊ฐ๊ฐ ๊ณ„์‚ฐ...
464
+ targets,
465
+ predictions,
466
+ text_theta = 0.5,
467
+ number_theta = 0.1,
468
+ ):
469
+ """Computes precisin recall and F1 metrics given two flattened tables.
470
+
471
+ Parses each string into a dictionary of keys and values using row and column
472
+ headers. Then we match keys between the two dicts as long as their relative
473
+ levenshtein distance is below a threshold. Values are also compared with
474
+ ANLS if strings or relative distance if they are numeric.
475
+
476
+ Args:
477
+ targets: list of list of strings.
478
+ predictions: list of strings.
479
+ text_theta: relative edit distance above this is set to the maximum of 1.
480
+ number_theta: relative error rate above this is set to the maximum of 1.
481
+
482
+ Returns:
483
+ Dictionary with per-point precision, recall and F1
484
+ """
485
+ assert len(targets) == len(predictions)
486
+ per_point_scores = {"precision": [], "recall": [], "f1": []}
487
+ for pred, target in zip(predictions, targets):
488
+ all_metrics = []
489
+ for transposed in [True, False]:
490
+ pred_table = _parse_table(pred, transposed=transposed)
491
+ target_table = _parse_table(target, transposed=transposed)
492
+
493
+ all_metrics.extend([_table_datapoints_precision_recall_f1(target_table, pred_table, text_theta, number_theta)])
494
+
495
+ p, r, f = max(all_metrics, key=lambda x: x[-1])
496
+ per_point_scores["precision"].append(p)
497
+ per_point_scores["recall"].append(r)
498
+ per_point_scores["f1"].append(f)
499
+ return per_point_scores
500
+
501
+ def table_datapoints_precision_recall( # deplot ์„ฑ๋Šฅ์ง€ํ‘œ
502
+ targets,
503
+ predictions,
504
+ text_theta = 0.5,
505
+ number_theta = 0.1,
506
+ ):
507
+ """Aggregated version of table_datapoints_precision_recall_per_point().
508
+
509
+ Same as table_datapoints_precision_recall_per_point() but returning aggregated
510
+ scores instead of per-point scores.
511
+
512
+ Args:
513
+ targets: list of list of strings.
514
+ predictions: list of strings.
515
+ text_theta: relative edit distance above this is set to the maximum of 1.
516
+ number_theta: relative error rate above this is set to the maximum of 1.
517
+
518
+ Returns:
519
+ Dictionary with aggregated precision, recall and F1
520
+ """
521
+ score_dict = table_datapoints_precision_recall_per_point(
522
+ targets, predictions, text_theta, number_theta
523
+ )
524
+ return {
525
+ "table_datapoints_precision": (
526
+ sum(score_dict["precision"]) / len(targets)
527
+ ),
528
+ "table_datapoints_recall": (
529
+ sum(score_dict["recall"]) / len(targets)
530
+ ),
531
+ "table_datapoints_f1": sum(score_dict["f1"]) / len(targets),
532
+ }
533
+
534
+ def evaluate_rms(generated_table,label_table):
535
+ predictions=[generated_table]
536
+ targets=[label_table]
537
+ RMS = table_datapoints_precision_recall(targets, predictions)
538
+ return RMS
539
+
540
+ def ko_deplot_convert_to_dataframe(generated_table_str):
541
+ lines = generated_table_str.strip().split(" \n")
542
+ headers=[]
543
+ data=[]
544
+ for i in range(len(lines[1].split(" | "))):
545
+ headers.append(f"{i}")
546
+ for line in lines[1:len(lines)-1]:
547
+ data.append(line.split("| "))
548
+ df = pd.DataFrame(data, columns=headers)
549
+ return df
550
+
551
+ def ko_deplot_convert_to_dataframe2(label_table_str):
552
+ lines = label_table_str.strip().split(" \n")
553
+ headers=[]
554
+ data=[]
555
+ for i in range(len(lines[1].split(" | "))):
556
+ headers.append(f"{i}")
557
+ for line in lines[1:]:
558
+ data.append(line.split("| "))
559
+ df = pd.DataFrame(data, columns=headers)
560
+ return df
561
+
562
+ def aihub_deplot_convert_to_dataframe(table_str):
563
+ lines = table_str.strip().split("\n")
564
+ headers = []
565
+ if(len(lines[3].split(" | "))>len(lines[4].split(" | "))):
566
+ category=lines[3].split(" | ")
567
+ del category[0]
568
+ value=lines[4].split(" | ")
569
+ df=pd.DataFrame({"๋ฒ”๋ก€":category,"๊ฐ’":value})
570
+ return df
571
+ else:
572
+ for i in range(len(lines[3].split(" | "))):
573
+ headers.append(f"{i}")
574
+ data = [line.split(" | ") for line in lines[3:]]
575
+ df = pd.DataFrame(data, columns=headers)
576
+ return df
577
+
578
+ class Highlighter:
579
+ def __init__(self):
580
+ self.row = 0
581
+ self.col = 0
582
+
583
+ def compare_and_highlight(self, pred_table_elem, target_table, pred_table_row, props=''):
584
+ if self.row >= pred_table_row:
585
+ self.col += 1
586
+ self.row = 0
587
+ if pred_table_elem != target_table.iloc[self.row, self.col]:
588
+ self.row += 1
589
+ return props
590
+ else:
591
+ self.row += 1
592
+ return None
593
+
594
+ # 1. ๋ฐ์ดํ„ฐ ๋กœ๋“œ
595
+ aihub_deplot_result_df = pd.read_csv('./aihub_deplot_result.csv')
596
+ ko_deplot_result= './ko-deplot-base-pred-epoch1-refinetuning.json'
597
+
598
+ # 2. ์ฒดํฌํ•ด์•ผ ํ•˜๋Š” ์ด๋ฏธ์ง€ ํŒŒ์ผ ๋กœ๋“œ
599
+ def load_image_checklist(file):
600
+ with open(file, 'r') as f:
601
+ #image_names = [f'"{line.strip()}"' for line in f]
602
+ image_names = f.read().splitlines()
603
+ return image_names
604
+
605
+ # 3. ํ˜„์žฌ ์ธ๋ฑ์Šค๋ฅผ ์ถ”์ ํ•˜๊ธฐ ์œ„ํ•œ ๋ณ€์ˆ˜
606
+ current_index = 0
607
+ image_names = []
608
+ def show_image(current_idx):
609
+ image_name=image_names[current_idx]
610
+ image_path = f"./images/{image_name}.jpg"
611
+ if not os.path.exists(image_path):
612
+ raise FileNotFoundError(f"Image file not found: {image_path}")
613
+ return Image.open(image_path)
614
+
615
+ # 4. ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
616
+ def non_real_time_check(file):
617
+ highlighter1 = Highlighter()
618
+ highlighter2 = Highlighter()
619
+ #global image_names, current_index
620
+ #image_names = load_image_checklist(file)
621
+ #current_index = 0
622
+ #image=show_image(current_index)
623
+ file_name =image_names[current_index].replace("Source","Label")
624
+
625
+ json_path="./ko_deplot_labeling_data.json"
626
+ with open(json_path, 'r', encoding='utf-8') as file:
627
+ json_data = json.load(file)
628
+ for key, value in json_data.items():
629
+ if key == file_name:
630
+ ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
631
+ ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].replace("TITLE | ","์ œ๋ชฉ:")
632
+ break
633
+
634
+ ko_deplot_rms_path="./ko_deplot_rms.txt"
635
+
636
+ with open(ko_deplot_rms_path,'r',encoding='utf-8') as file:
637
+ lines=file.readlines()
638
+ flag=0
639
+ for line in lines:
640
+ parts=line.strip().split(", ")
641
+ if(len(parts)==2 and parts[0]==image_names[current_index]):
642
+ ko_deplot_rms=parts[1]
643
+ flag=1
644
+ break
645
+ if(flag==0):
646
+ ko_deplot_rms="none"
647
+ ko_deplot_generated_title,ko_deplot_generated_table=ko_deplot_display_results(current_index)
648
+ aihub_deplot_generated_table,aihub_deplot_label_table,aihub_deplot_generated_title,aihub_deplot_label_title=aihub_deplot_display_results(current_index)
649
+ #ko_deplot_RMS=evaluate_rms(ko_deplot_generated_table,ko_deplot_labeling_str)
650
+ aihub_deplot_RMS=evaluate_rms(aihub_deplot_generated_table,aihub_deplot_label_table)
651
+
652
+
653
+ if flag == 1:
654
+ value = [round(float(ko_deplot_rms), 1)]
655
+ else:
656
+ value = [0]
657
+
658
+ ko_deplot_score_table = pd.DataFrame({
659
+ 'category': ['f1'],
660
+ 'value': value
661
+ })
662
+
663
+ aihub_deplot_score_table=pd.DataFrame({
664
+ 'category': ['precision', 'recall', 'f1'],
665
+ 'value': [
666
+ round(aihub_deplot_RMS['table_datapoints_precision'],1),
667
+ round(aihub_deplot_RMS['table_datapoints_recall'],1),
668
+ round(aihub_deplot_RMS['table_datapoints_f1'],1)
669
+ ]
670
+ })
671
+ ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
672
+ aihub_deplot_generated_df=aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table)
673
+ ko_deplot_labeling_df=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
674
+ aihub_deplot_labeling_df=aihub_deplot_convert_to_dataframe(aihub_deplot_label_table)
675
+
676
+ ko_deplot_generated_df_row=ko_deplot_generated_df.shape[0]
677
+ aihub_deplot_generated_df_row=aihub_deplot_generated_df.shape[0]
678
+
679
+
680
+ styled_ko_deplot_table=ko_deplot_generated_df.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_labeling_df,pred_table_row=ko_deplot_generated_df_row,props='color:red')
681
+
682
+
683
+ styled_aihub_deplot_table=aihub_deplot_generated_df.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_labeling_df,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
684
+
685
+ #return ko_deplot_convert_to_dataframe(ko_deplot_generated_table), aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table), aihub_deplot_convert_to_dataframe(label_table), ko_deplot_score_table, aihub_deplot_score_table
686
+ return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(ko deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_labeling_df,label=ko_deplot_label_title+"(ko deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"), gr.DataFrame(aihub_deplot_labeling_df,label=aihub_deplot_label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table
687
+
688
+ def ko_deplot_display_results(index):
689
+ filename=image_names[index]+".jpg"
690
+ with open(ko_deplot_result, 'r', encoding='utf-8') as f:
691
+ data = json.load(f)
692
+ for entry in data:
693
+ if entry['filename'].endswith(filename):
694
+ #return entry['table']
695
+ parts=entry['table'].split("\n",1)
696
+ return parts[0].replace("TITLE | ","์ œ๋ชฉ:"),entry['table']
697
+
698
+ def aihub_deplot_display_results(index):
699
+ if index < 0 or index >= len(image_names):
700
+ return "Index out of range", None, None
701
+ image_name = image_names[index]
702
+ image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == image_name]
703
+ if not image_row.empty:
704
+ generated_table = image_row['generated_table'].values[0]
705
+ generated_title=generated_table.split("\n")[1]
706
+ label_table = image_row['label_table'].values[0]
707
+ label_title=label_table.split("\n")[1]
708
+ return generated_table, label_table, generated_title, label_title
709
+ else:
710
+ return "No results found for the image", None, None
711
+
712
+ def previous_image():
713
+ global current_index
714
+ if current_index>0:
715
+ current_index-=1
716
+ image=show_image(current_index)
717
+ return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
718
+
719
+ def next_image():
720
+ global current_index
721
+ if current_index<len(image_names)-1:
722
+ current_index+=1
723
+ image=show_image(current_index)
724
+ return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
725
+
726
+ def real_time_check(image_file):
727
+ highlighter1 = Highlighter()
728
+ highlighter2 = Highlighter()
729
+ image = Image.open(image_file)
730
+ result_model1 = predict_model1(image)
731
+ parts=result_model1.split("\n")
732
+ del parts[-1]
733
+ result_model1="\n".join(parts)
734
+ ko_deplot_generated_title=result_model1.split("\n")[0].split(" | ")[1]
735
+ ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)
736
+
737
+ result_model2 = predict_model2(image)
738
+ aihub_deplot_generated_title=result_model2.split("\n")[1].split(":")[1]
739
+ aihub_deplot_table=aihub_deplot_convert_to_dataframe(result_model2)
740
+ image_base_name = os.path.basename(image_file.name).replace("Source","Label")
741
+ file_name, _ = os.path.splitext(image_base_name)
742
+
743
+ aihub_labeling_data_json="./labeling_data/"+file_name+".json"
744
+
745
+ json_path="./ko_deplot_labeling_data.json"
746
+ with open(json_path, 'r', encoding='utf-8') as file:
747
+ json_data = json.load(file)
748
+ for key, value in json_data.items():
749
+ if key == file_name:
750
+ ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
751
+ ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].split(" | ")[1]
752
+ break
753
+
754
+ ko_deplot_label_table=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
755
+
756
+ aihub_deplot_labeling_str=process_json_file2(aihub_labeling_data_json)
757
+ aihub_deplot_label_title=aihub_deplot_labeling_str.split("\n")[1].split(":")[1]
758
+ aihub_deplot_label_table=aihub_deplot_convert_to_dataframe(aihub_deplot_labeling_str)
759
+
760
+ ko_deplot_RMS=evaluate_rms(result_model1,ko_deplot_labeling_str)
761
+ aihub_deplot_RMS=evaluate_rms(result_model2,aihub_deplot_labeling_str)
762
+
763
+ ko_deplot_score_table=pd.DataFrame({
764
+ 'category': ['precision', 'recall', 'f1'],
765
+ 'value': [
766
+ round(ko_deplot_RMS['table_datapoints_precision'],1),
767
+ round(ko_deplot_RMS['table_datapoints_recall'],1),
768
+ round(ko_deplot_RMS['table_datapoints_f1'],1)
769
+ ]
770
+ })
771
+ aihub_deplot_score_table=pd.DataFrame({
772
+ 'category': ['precision', 'recall', 'f1'],
773
+ 'value': [
774
+ round(aihub_deplot_RMS['table_datapoints_precision'],1),
775
+ round(aihub_deplot_RMS['table_datapoints_recall'],1),
776
+ round(aihub_deplot_RMS['table_datapoints_f1'],1)
777
+ ]
778
+ })
779
+
780
+ ko_deplot_generated_df_row=ko_deplot_table.shape[0]
781
+ aihub_deplot_generated_df_row=aihub_deplot_table.shape[0]
782
+ styled_ko_deplot_table=ko_deplot_table.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_label_table,pred_table_row=ko_deplot_generated_df_row,props='color:red')
783
+ styled_aihub_deplot_table=aihub_deplot_table.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_label_table,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
784
+
785
+ return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(kodeplot ์ถ”๋ก  ๊ฒฐ๊ณผ)") , gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_label_table,label=ko_deplot_label_title+"(kodeplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),gr.DataFrame(aihub_deplot_label_table,label=aihub_deplot_label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table
786
+ #return ko_deplot_table,aihub_deplot_table,aihub_deplot_label_table,ko_deplot_score_table,aihub_deplot_score_table
787
+ def inference(mode,image_uploader,file_uploader):
788
+ if(mode=="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"):
789
+ ko_deplot_table, aihub_deplot_table, ko_deplot_label_table,aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table = real_time_check(image_uploader)
790
+ return ko_deplot_table, aihub_deplot_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table
791
+ else:
792
+ styled_ko_deplot_table, styled_aihub_deplot_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table =non_real_time_check(file_uploader)
793
+ return styled_ko_deplot_table, styled_aihub_deplot_table, ko_deplot_label_table,aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table
794
+
795
+ def interface_selector(selector):
796
+ if selector == "์ด๋ฏธ์ง€ ์—…๋กœ๋“œ":
797
+ return gr.update(visible=True),gr.update(visible=False),gr.State("image_upload"),gr.update(visible=False),gr.update(visible=False)
798
+ elif selector == "ํŒŒ์ผ ์—…๋กœ๋“œ":
799
+ return gr.update(visible=False),gr.update(visible=True),gr.State("file_upload"), gr.update(visible=True),gr.update(visible=True)
800
+
801
+ def file_selector(selector):
802
+ if selector == "low score ์ฐจํŠธ":
803
+ return gr.File("./new_bottom_20_percent_images.txt")
804
+ elif selector == "high score ์ฐจํŠธ":
805
+ return gr.File("./new_top_20_percent_images.txt")
806
+
807
+ def update_results(model_type):
808
+ if "ko_deplot" == model_type:
809
+ return gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False)
810
+ elif "aihub_deplot" == model_type:
811
+ return gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=True)
812
+ else:
813
+ return gr.update(visible=True), gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True)
814
+
815
+ def display_image(image_file):
816
+ image=Image.open(image_file)
817
+ return image, os.path.basename(image_file)
818
+
819
+ def display_image_in_file(image_checklist):
820
+ global image_names, current_index
821
+ image_names = load_image_checklist(image_checklist)
822
+ image=show_image(current_index)
823
+ return image,image_names[current_index]
824
+
825
+ def update_file_based_on_chart_type(chart_type, all_file_path):
826
+ with open(all_file_path, 'r', encoding='utf-8') as file:
827
+ lines = file.readlines()
828
+ filtered_lines=[]
829
+ if chart_type == "์ „์ฒด":
830
+ filtered_lines = lines
831
+ elif chart_type == "์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
832
+ filtered_lines = [line for line in lines if "_horizontal bar_standard" in line]
833
+ elif chart_type=="๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
834
+ filtered_lines = [line for line in lines if "_horizontal bar_accumulation" in line]
835
+ elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
836
+ filtered_lines = [line for line in lines if "_horizontal bar_100per accumulation" in line]
837
+ elif chart_type=="์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
838
+ filtered_lines = [line for line in lines if "_vertical bar_standard" in line]
839
+ elif chart_type=="๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
840
+ filtered_lines = [line for line in lines if "_vertical bar_accumulation" in line]
841
+ elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
842
+ filtered_lines = [line for line in lines if "_vertical bar_100per accumulation" in line]
843
+ elif chart_type=="์„ ํ˜•":
844
+ filtered_lines = [line for line in lines if "_line_standard" in line]
845
+ elif chart_type=="์›ํ˜•":
846
+ filtered_lines = [line for line in lines if "_pie_standard" in line]
847
+ elif chart_type=="๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•":
848
+ filtered_lines = [line for line in lines if "_etc_radial" in line]
849
+ elif chart_type=="๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•":
850
+ filtered_lines = [line for line in lines if "_etc_mix" in line]
851
+ # ์ƒˆ๋กœ์šด ํŒŒ์ผ์— ๊ธฐ๋ก
852
+ new_file_path = "./filtered_chart_images.txt"
853
+ with open(new_file_path, 'w', encoding='utf-8') as file:
854
+ file.writelines(filtered_lines)
855
+
856
+ return new_file_path
857
+
858
+ def handle_chart_type_change(chart_type,all_file_path):
859
+ new_file_path = update_file_based_on_chart_type(chart_type, all_file_path)
860
+ global image_names, current_index
861
+ image_names = load_image_checklist(new_file_path)
862
+ current_index=0
863
+ image=show_image(current_index)
864
+ return image,image_names[current_index]
865
+
866
+ with gr.Blocks() as iface:
867
+ mode=gr.State("image_upload")
868
+ with gr.Row():
869
+ with gr.Column():
870
+ #mode_label=gr.Text("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ๊ฐ€ ์„ ํƒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
871
+ upload_option = gr.Radio(choices=["์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", "ํŒŒ์ผ ์—…๋กœ๋“œ"], value="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", label="์—…๋กœ๋“œ ์˜ต์…˜")
872
+ #with gr.Row():
873
+ #image_button = gr.Button("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
874
+ #file_button = gr.Button("ํŒŒ์ผ ์—…๋กœ๋“œ")
875
+
876
+ # ์ด๋ฏธ์ง€์™€ ํŒŒ์ผ ์—…๋กœ๋“œ ์ปดํฌ๋„ŒํŠธ (์ดˆ๊ธฐ์—๋Š” ์ˆจ๊น€ ์ƒํƒœ)
877
+ # global image_uploader,file_uploader
878
+ image_uploader= gr.File(file_count="single",file_types=["image"],visible=True)
879
+ file_uploader= gr.File(file_count="single", file_types=[".txt"], visible=False)
880
+ file_upload_option=gr.Radio(choices=["low score ์ฐจํŠธ","high score ์ฐจํŠธ"],label="ํŒŒ์ผ ์—…๋กœ๋“œ ์˜ต์…˜",visible=False)
881
+ chart_type = gr.Dropdown(["์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•", "์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","์„ ํ˜•", "์›ํ˜•", "๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•", "๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•", "์ „์ฒด"], label="Chart Type", value="all")
882
+ model_type=gr.Dropdown(["ko_deplot","aihub_deplot","all"],label="model")
883
+ image_displayer=gr.Image(visible=True)
884
+ with gr.Row():
885
+ pre_button=gr.Button("์ด์ „",interactive="False")
886
+ next_button=gr.Button("๋‹ค์Œ")
887
+ image_name=gr.Text("์ด๋ฏธ์ง€ ์ด๋ฆ„",visible=False)
888
+ #image_button.click(interface_selector, inputs=gr.State("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"), outputs=[image_uploader,file_uploader,mode,mode_label,image_name])
889
+ #file_button.click(interface_selector, inputs=gr.State("ํŒŒ์ผ ์—…๋กœ๋“œ"), outputs=[image_uploader, file_uploader,mode,mode_label,image_name])
890
+ inference_button=gr.Button("์ถ”๋ก ")
891
+ with gr.Column():
892
+ ko_deplot_generated_table=gr.DataFrame(visible=False,label="ko-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
893
+ aihub_deplot_generated_table=gr.DataFrame(visible=False,label="aihub-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
894
+ with gr.Column():
895
+ ko_deplot_label_table=gr.DataFrame(visible=False,label="ko-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
896
+ aihub_deplot_label_table=gr.DataFrame(visible=False,label="aihub-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
897
+ with gr.Column():
898
+ ko_deplot_score_table=gr.DataFrame(visible=False,label="ko_deplot ์ ์ˆ˜")
899
+ aihub_deplot_score_table=gr.DataFrame(visible=False,label="aihub_deplot ์ ์ˆ˜")
900
+ model_type.change(
901
+ update_results,
902
+ inputs=[model_type],
903
+ outputs=[ko_deplot_generated_table,ko_deplot_score_table,aihub_deplot_generated_table,aihub_deplot_score_table,ko_deplot_label_table,aihub_deplot_label_table]
904
+ )
905
+
906
+ upload_option.change(
907
+ interface_selector,
908
+ inputs=[upload_option],
909
+ outputs=[image_uploader, file_uploader, mode, image_name,file_upload_option]
910
+ )
911
+
912
+ file_upload_option.change(
913
+ file_selector,
914
+ inputs=[file_upload_option],
915
+ outputs=[file_uploader]
916
+ )
917
+
918
+ chart_type.change(handle_chart_type_change, inputs=[chart_type,file_uploader],outputs=[image_displayer,image_name])
919
+ image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
920
+ file_uploader.change(display_image_in_file,inputs=[file_uploader],outputs=[image_displayer,image_name])
921
+ pre_button.click(previous_image, outputs=[image_displayer,image_name,pre_button,next_button])
922
+ next_button.click(next_image, outputs=[image_displayer,image_name,pre_button,next_button])
923
+ inference_button.click(inference,inputs=[upload_option,image_uploader,file_uploader],outputs=[ko_deplot_generated_table, aihub_deplot_generated_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table])
924
+
925
+
926
+ if __name__ == "__main__":
927
+ print("Launching Gradio interface...")
928
+ sys.stdout.flush() # stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
929
+ iface.launch(share=True)
930
+ time.sleep(2) # Gradio URL์ด ์ถœ๋ ฅ๋  ๋•Œ๊นŒ์ง€ ์ž ์‹œ ๊ธฐ๋‹ค๋ฆฝ๋‹ˆ๋‹ค.
931
+ sys.stdout.flush() # ๋‹ค์‹œ stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
932
+ # Gradio๊ฐ€ ์ œ๊ณตํ•˜๋Š” URLs์„ ํŒŒ์ผ์— ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค.
933
+ with open("gradio_url.log", "w") as f:
934
+ print(iface.local_url, file=f)
935
+ print(iface.share_url, file=f)