Liu Yiwen committed on
Commit
1c081e2
·
1 Parent(s): 40e5362

增加了根据起止位置绘图和统计的功能

Browse files
Files changed (3) hide show
  1. __pycache__/utils.cpython-311.pyc +0 -0
  2. app.py +8 -6
  3. utils.py +9 -3
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
app.py CHANGED
@@ -267,9 +267,9 @@ with gr.Blocks() as demo:
267
  with gr.Column(scale=3):
268
  plot = gr.Plot()
269
  with gr.Row():
270
- user_input_box = gr.Textbox(placeholder="输入一些内容", label="输入", lines=5, interactive=True)
271
- user_output_box = gr.Textbox(label="回答", lines=5, interactive=False)
272
- user_io_buttom = gr.Button("发送", interactive=True)
273
  # componets.append({"select_sample_box": select_sample_box,
274
  # "statistics_textbox": statistics_textbox,
275
  # "user_input_box": user_input_box,
@@ -299,11 +299,13 @@ with gr.Blocks() as demo:
299
  ret[plot] = gr.update(value=create_plot(df_list, id_list))
300
  elif dataset == 'YY26/TS_DATASETS':
301
  df, max_page, info = get_page(dataset, config, split, page)
 
302
  lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
 
303
  # lotsa_subtargets = eval(df['target_id'][0])
304
  df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
305
- ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
306
- ret[plot] = gr.update(value=create_plot(df_list, id_list))
307
  ret[user_input_box] = gr.update(value=df['question'][0])
308
  ret[user_output_box] = gr.update(value=df['answer'][0])
309
  else:
@@ -380,7 +382,7 @@ with gr.Blocks() as demo:
380
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
381
  cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
382
  cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
383
- user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
384
  # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
385
 
386
 
 
267
  with gr.Column(scale=3):
268
  plot = gr.Plot()
269
  with gr.Row():
270
+ user_input_box = gr.Textbox(label="question", interactive=False)
271
+ user_output_box = gr.Textbox(label="answer", interactive=False)
272
+ # user_io_buttom = gr.Button("发送", interactive=True)
273
  # componets.append({"select_sample_box": select_sample_box,
274
  # "statistics_textbox": statistics_textbox,
275
  # "user_input_box": user_input_box,
 
299
  ret[plot] = gr.update(value=create_plot(df_list, id_list))
300
  elif dataset == 'YY26/TS_DATASETS':
301
  df, max_page, info = get_page(dataset, config, split, page)
302
+ # TODO: 修改lotsa_config的读取逻辑
303
  lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
304
+ start_index, end_index = df['start_index'][0], df['end_index'][0]
305
  # lotsa_subtargets = eval(df['target_id'][0])
306
  df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
307
+ ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=[start_index, end_index]))
308
+ ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=[start_index, end_index]))
309
  ret[user_input_box] = gr.update(value=df['question'][0])
310
  ret[user_output_box] = gr.update(value=df['answer'][0])
311
  else:
 
382
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
383
  cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
384
  cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
385
+ # user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
386
  # select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
387
 
388
 
utils.py CHANGED
@@ -52,12 +52,14 @@ def flatten_ndarray_column(df, column_name, rows_to_include):
52
 
53
  return df
54
 
55
- def create_plot(dfs:list[pd.DataFrame], ids:list[str]):
56
  """
57
  创建一个包含所有传入 DataFrame 的线图。
58
  """
59
  fig = go.Figure()
60
  for df, df_id in zip(dfs, ids):
 
 
61
  for i, column in enumerate(df.columns[1:]):
62
  fig.add_trace(go.Scatter(
63
  x=df[df.columns[0]],
@@ -82,13 +84,16 @@ def create_plot(dfs:list[pd.DataFrame], ids:list[str]):
82
  )
83
  return fig
84
 
85
- def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
86
  """
87
  计算数据集列表的统计信息。
88
  """
89
  stats_list = []
90
 
91
  for df, id in zip(dfs, ids):
 
 
 
92
  df_values = df.iloc[:, 1:]
93
  # 计算统计值
94
  mean_values = df_values.mean().round(2)
@@ -102,7 +107,8 @@ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
102
  'mean': mean_values.values,
103
  'std': std_values.values,
104
  'max': max_values.values,
105
- 'min': min_values.values
 
106
  })
107
  stats_list.append(stats_df)
108
 
 
52
 
53
  return df
54
 
55
+ def create_plot(dfs:list[pd.DataFrame], ids:list[str], interval:list[int, int]=None) -> go.Figure:
56
  """
57
  创建一个包含所有传入 DataFrame 的线图。
58
  """
59
  fig = go.Figure()
60
  for df, df_id in zip(dfs, ids):
61
+ if interval:
62
+ df = df.iloc[interval[0]:interval[1]]
63
  for i, column in enumerate(df.columns[1:]):
64
  fig.add_trace(go.Scatter(
65
  x=df[df.columns[0]],
 
84
  )
85
  return fig
86
 
87
+ def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame:
88
  """
89
  计算数据集列表的统计信息。
90
  """
91
  stats_list = []
92
 
93
  for df, id in zip(dfs, ids):
94
+ total_rows = len(df)
95
+ if interval:
96
+ df = df.iloc[interval[0]:interval[1]]
97
  df_values = df.iloc[:, 1:]
98
  # 计算统计值
99
  mean_values = df_values.mean().round(2)
 
107
  'mean': mean_values.values,
108
  'std': std_values.values,
109
  'max': max_values.values,
110
+ 'min': min_values.values,
111
+ 'total_sample_num': total_rows
112
  })
113
  stats_list.append(stats_df)
114