Liu Yiwen commited on
Commit
60dbd41
·
1 Parent(s): 15549a1

更新了处理逻辑

Browse files
__pycache__/config.cpython-311.pyc ADDED
Binary file (751 Bytes). View file
 
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
app.py CHANGED
@@ -15,8 +15,9 @@ from hffs.fs import HfFileSystem
15
  from datasets import Features, Image, Audio, Sequence
16
  from typing import List, Tuple, Callable
17
 
18
- from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
19
  from comm_utils import save_to_file, send_msg_to_server, save_score
 
20
 
21
  class AppError(RuntimeError):
22
  pass
@@ -183,12 +184,12 @@ def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int
183
  df = copy.deepcopy(df_)
184
 
185
  unsupported_columns = []
186
- if dataset == 'Salesforce/lotsa_data':
187
  # 对Salesforce/lotsa_data数据集进行特殊处理
188
  info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
189
  return df, max_page, info
190
 
191
- elif dataset == 'YY26/TS_DATASETS':
192
  # 对YY26/TS_DATASETS数据集进行特殊处理
193
  info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
194
  return df, max_page, info
@@ -242,7 +243,7 @@ def process_salesforce_data(dataset: str, config: str, split: str, page: List[st
242
  with gr.Blocks() as demo:
243
  # 初始化组件
244
  gr.Markdown("A tool for interactive observation of lotsa dataset, extended from lhoestq/datasets-explorer")
245
- cp_dataset = gr.Textbox("YY26/TS_DATASETS", label="Pick a dataset", interactive=False)
246
  cp_go = gr.Button("Explore")
247
  cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
248
  cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
@@ -267,6 +268,7 @@ with gr.Blocks() as demo:
267
  statistics_textbox = gr.DataFrame()
268
  with gr.Column(scale=3):
269
  plot = gr.Plot()
 
270
  with gr.Row():
271
  user_input_box = gr.Textbox(label="question", interactive=False)
272
  user_output_box = gr.Textbox(label="answer", interactive=False)
@@ -274,7 +276,7 @@ with gr.Blocks() as demo:
274
  # "statistics_textbox": statistics_textbox,
275
  # "user_input_box": user_input_box,
276
  # "plot": plot})
277
- score_slider = gr.Slider(1, 5, 1, label="Score for answer", interactive=True)
278
  with gr.Row():
279
  with gr.Column(scale=2):
280
  user_submit_button = gr.Button("submit", interactive=True)
@@ -295,25 +297,28 @@ with gr.Blocks() as demo:
295
  def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]=['all']) -> dict:
296
  try:
297
  ret = {}
298
- if dataset == 'Salesforce/lotsa_data':
299
- # 对Salesforce/lotsa_data数据集进行特殊处理
300
  if type(page) == str:
301
  page = [page]
302
  df_list, id_list = process_salesforce_data(dataset, config, split, page, sub_targets)
303
  ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
304
  ret[plot] = gr.update(value=create_plot(df_list, id_list))
305
- elif dataset == 'YY26/TS_DATASETS':
306
  df, max_page, info = get_page(dataset, config, split, page)
307
- ret[qusetion_id_box] = gr.update(value = df['num'][0])
308
- # TODO: 修改lotsa_config的读取逻辑
309
- lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
310
- start_index, end_index = df['start_index'][0], df['end_index'][0]
311
- # lotsa_subtargets = eval(df['target_id'][0])
312
- df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
313
- ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=[start_index, end_index]))
314
- ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=[start_index, end_index]))
315
- ret[user_input_box] = gr.update(value=df['question'][0])
316
- ret[user_output_box] = gr.update(value=df['answer'][0])
 
 
 
 
317
  ret[submit_info_box] = gr.update(value="")
318
  else:
319
  markdown_result, max_page, info = get_page(dataset, config, split, page)
@@ -385,13 +390,14 @@ with gr.Blocks() as demo:
385
  statistics_textbox, plot,
386
  qusetion_id_box,
387
  user_input_box, user_output_box,
388
- submit_info_box]
 
389
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
390
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
391
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
392
  cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
393
  cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
394
- user_submit_button.click(save_score, inputs=["none", qusetion_id_box, score_slider], outputs=[submit_info_box])
395
  # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
396
 
397
 
@@ -401,4 +407,4 @@ if __name__ == "__main__":
401
  host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
402
  # import subprocess
403
  # subprocess.Popen(["python", "test_server.py"])
404
- uvicorn.run(app, host=host, port=7860)
 
15
  from datasets import Features, Image, Audio, Sequence
16
  from typing import List, Tuple, Callable
17
 
18
+ from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot, get_question_info
19
  from comm_utils import save_to_file, send_msg_to_server, save_score
20
+ from config import *
21
 
22
  class AppError(RuntimeError):
23
  pass
 
184
  df = copy.deepcopy(df_)
185
 
186
  unsupported_columns = []
187
+ if dataset == TARGET_DATASET:
188
  # 对Salesforce/lotsa_data数据集进行特殊处理
189
  info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
190
  return df, max_page, info
191
 
192
+ elif dataset == BENCHMARK_DATASET:
193
  # 对YY26/TS_DATASETS数据集进行特殊处理
194
  info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
195
  return df, max_page, info
 
243
  with gr.Blocks() as demo:
244
  # 初始化组件
245
  gr.Markdown("A tool for interactive observation of lotsa dataset, extended from lhoestq/datasets-explorer")
246
+ cp_dataset = gr.Textbox(BENCHMARK_DATASET, label="Pick a dataset", interactive=False)
247
  cp_go = gr.Button("Explore")
248
  cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
249
  cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
 
268
  statistics_textbox = gr.DataFrame()
269
  with gr.Column(scale=3):
270
  plot = gr.Plot()
271
+ question_info_textbox = gr.DataFrame()
272
  with gr.Row():
273
  user_input_box = gr.Textbox(label="question", interactive=False)
274
  user_output_box = gr.Textbox(label="answer", interactive=False)
 
276
  # "statistics_textbox": statistics_textbox,
277
  # "user_input_box": user_input_box,
278
  # "plot": plot})
279
+ score_slider = gr.Slider(1, 5, 1, step=0.5, label="Score for answer", interactive=True)
280
  with gr.Row():
281
  with gr.Column(scale=2):
282
  user_submit_button = gr.Button("submit", interactive=True)
 
297
  def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]=['all']) -> dict:
298
  try:
299
  ret = {}
300
+ if dataset == TARGET_DATASET:
 
301
  if type(page) == str:
302
  page = [page]
303
  df_list, id_list = process_salesforce_data(dataset, config, split, page, sub_targets)
304
  ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
305
  ret[plot] = gr.update(value=create_plot(df_list, id_list))
306
+ elif dataset == BENCHMARK_DATASET:
307
  df, max_page, info = get_page(dataset, config, split, page)
308
+ question_info = get_question_info(df)
309
+ ret[qusetion_id_box] = gr.update(value = df[COLUMN_ID][0])
310
+
311
+ lotsa_config, lotsa_split, lotsa_page = str(df[COLUMN_SOURCE][0]).split('/')[-1], 'train', eval(df[COLUMN_TS_ID][0])
312
+ start_index, end_index = df[COLUMN_START_INDEX][0], df[COLUMN_END_INDEX][0]
313
+ interval = None if np.isnan(start_index) or np.isnan(end_index) else [start_index, end_index]
314
+ lotsa_subtargets = eval(df[COLUMN_TARGET_ID][0])
315
+ df_list, id_list = process_salesforce_data(TARGET_DATASET, lotsa_config, lotsa_split, lotsa_page, lotsa_subtargets)
316
+
317
+ ret[question_info_textbox] = gr.update(value=question_info)
318
+ ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=interval))
319
+ ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=interval))
320
+ ret[user_input_box] = gr.update(value=df[COLUMN_QUESTION][0])
321
+ ret[user_output_box] = gr.update(value=df[COLUMN_ANSWER][0])
322
  ret[submit_info_box] = gr.update(value="")
323
  else:
324
  markdown_result, max_page, info = get_page(dataset, config, split, page)
 
390
  statistics_textbox, plot,
391
  qusetion_id_box,
392
  user_input_box, user_output_box,
393
+ submit_info_box,
394
+ question_info_textbox]
395
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
396
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
397
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
398
  cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
399
  cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
400
+ user_submit_button.click(save_score, inputs=[score_slider, qusetion_id_box, score_slider], outputs=[submit_info_box])
401
  # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
402
 
403
 
 
407
  host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
408
  # import subprocess
409
  # subprocess.Popen(["python", "test_server.py"])
410
+ # uvicorn.run(app, host=host, port=7860)
config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BENCHMARK_DATASET = 'YY26/TS_benchmark'
2
+ TARGET_DATASET = 'Salesforce/lotsa_data'
3
+
4
+ # columns
5
+ COLUMN_ID = 'id'
6
+ COLUMN_TS_ID = 'ts_id'
7
+ COLUMN_TARGET_ID = 'target_id'
8
+ COLUMN_START_INDEX = 'start_index'
9
+ COLUMN_END_INDEX = 'end_index'
10
+ COLUMN_QA_TYPE = 'qa_type'
11
+ COLUMN_TASK_TYPE = 'task_type'
12
+ COLUMN_QUESTION = 'question'
13
+ COLUMN_OPTION = 'option'
14
+ COLUMN_ANSWER = 'answer'
15
+ COLUMN_DOMAIN = 'domain'
16
+ COLUMN_SOURCE = 'source'
17
+ COLUMN_LOCAL_OVERALL = 'local_overall'
18
+ COLUMN_OPTIONS = 'options'
score.json CHANGED
@@ -23,5 +23,10 @@
23
  "user_id": 2.86,
24
  "question_id": "2",
25
  "score": 2.86
 
 
 
 
 
26
  }
27
  ]
 
23
  "user_id": 2.86,
24
  "question_id": "2",
25
  "score": 2.86
26
+ },
27
+ {
28
+ "user_id": 3,
29
+ "question_id": "1",
30
+ "score": 3
31
  }
32
  ]
utils.py CHANGED
@@ -10,6 +10,8 @@ import pandas as pd
10
  import plotly.graph_objects as go
11
  import numpy as np
12
 
 
 
13
 
14
  def ndarray_to_base64(ndarray):
15
  """
@@ -66,7 +68,8 @@ def create_plot(dfs:list[pd.DataFrame], ids:list[str], interval:list[int, int]=N
66
  y=df[column],
67
  mode='lines',
68
  name=f"item_{df_id} - {column}",
69
- visible=True if i == 0 else 'legendonly'
 
70
  ))
71
 
72
  # 配置图例
@@ -135,6 +138,14 @@ def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
135
  df.drop(columns=['past_feat_dynamic_real'], inplace=True)
136
  return df
137
 
 
 
 
 
 
 
 
 
138
  if __name__ == '__main__':
139
 
140
  # 创建测试数据
 
10
  import plotly.graph_objects as go
11
  import numpy as np
12
 
13
+ from config import *
14
+
15
 
16
  def ndarray_to_base64(ndarray):
17
  """
 
68
  y=df[column],
69
  mode='lines',
70
  name=f"item_{df_id} - {column}",
71
+ # visible=True if i == 0 else 'legendonly'
72
+ visible=True
73
  ))
74
 
75
  # 配置图例
 
138
  df.drop(columns=['past_feat_dynamic_real'], inplace=True)
139
  return df
140
 
141
+ def get_question_info(df: pd.DataFrame) -> pd.DataFrame:
142
+ """
143
+ 从数据集中提取问题信息。
144
+ """
145
+ question_info = df[[COLUMN_DOMAIN, COLUMN_SOURCE, COLUMN_QA_TYPE, COLUMN_TASK_TYPE]]
146
+ question_info = question_info.drop_duplicates()
147
+ return question_info
148
+
149
  if __name__ == '__main__':
150
 
151
  # 创建测试数据