Spaces:
Running
Running
Liu Yiwen
commited on
Commit
·
1c081e2
1
Parent(s):
40e5362
增加了根据起止位置绘图和统计的功能
Browse files- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +8 -6
- utils.py +9 -3
__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
|
|
app.py
CHANGED
@@ -267,9 +267,9 @@ with gr.Blocks() as demo:
|
|
267 |
with gr.Column(scale=3):
|
268 |
plot = gr.Plot()
|
269 |
with gr.Row():
|
270 |
-
user_input_box = gr.Textbox(
|
271 |
-
user_output_box = gr.Textbox(label="
|
272 |
-
user_io_buttom = gr.Button("发送", interactive=True)
|
273 |
# componets.append({"select_sample_box": select_sample_box,
|
274 |
# "statistics_textbox": statistics_textbox,
|
275 |
# "user_input_box": user_input_box,
|
@@ -299,11 +299,13 @@ with gr.Blocks() as demo:
|
|
299 |
ret[plot] = gr.update(value=create_plot(df_list, id_list))
|
300 |
elif dataset == 'YY26/TS_DATASETS':
|
301 |
df, max_page, info = get_page(dataset, config, split, page)
|
|
|
302 |
lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
|
|
|
303 |
# lotsa_subtargets = eval(df['target_id'][0])
|
304 |
df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
|
305 |
-
ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
|
306 |
-
ret[plot] = gr.update(value=create_plot(df_list, id_list))
|
307 |
ret[user_input_box] = gr.update(value=df['question'][0])
|
308 |
ret[user_output_box] = gr.update(value=df['answer'][0])
|
309 |
else:
|
@@ -380,7 +382,7 @@ with gr.Blocks() as demo:
|
|
380 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
381 |
cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
382 |
cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
383 |
-
user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
|
384 |
# select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
385 |
|
386 |
|
|
|
267 |
with gr.Column(scale=3):
|
268 |
plot = gr.Plot()
|
269 |
with gr.Row():
|
270 |
+
user_input_box = gr.Textbox(label="question", interactive=False)
|
271 |
+
user_output_box = gr.Textbox(label="answer", interactive=False)
|
272 |
+
# user_io_buttom = gr.Button("发送", interactive=True)
|
273 |
# componets.append({"select_sample_box": select_sample_box,
|
274 |
# "statistics_textbox": statistics_textbox,
|
275 |
# "user_input_box": user_input_box,
|
|
|
299 |
ret[plot] = gr.update(value=create_plot(df_list, id_list))
|
300 |
elif dataset == 'YY26/TS_DATASETS':
|
301 |
df, max_page, info = get_page(dataset, config, split, page)
|
302 |
+
# TODO: 修改lotsa_config的读取逻辑
|
303 |
lotsa_config, lotsa_split, lotsa_page = 'traffic_hourly', 'train', eval(df['ts_id'][0])
|
304 |
+
start_index, end_index = df['start_index'][0], df['end_index'][0]
|
305 |
# lotsa_subtargets = eval(df['target_id'][0])
|
306 |
df_list, id_list = process_salesforce_data('Salesforce/lotsa_data', lotsa_config, lotsa_split, lotsa_page, [1])
|
307 |
+
ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=[start_index, end_index]))
|
308 |
+
ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=[start_index, end_index]))
|
309 |
ret[user_input_box] = gr.update(value=df['question'][0])
|
310 |
ret[user_output_box] = gr.update(value=df['answer'][0])
|
311 |
else:
|
|
|
382 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
383 |
cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
384 |
cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
385 |
+
# user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
|
386 |
# select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
387 |
|
388 |
|
utils.py
CHANGED
@@ -52,12 +52,14 @@ def flatten_ndarray_column(df, column_name, rows_to_include):
|
|
52 |
|
53 |
return df
|
54 |
|
55 |
-
def create_plot(dfs:list[pd.DataFrame], ids:list[str]):
|
56 |
"""
|
57 |
创建一个包含所有传入 DataFrame 的线图。
|
58 |
"""
|
59 |
fig = go.Figure()
|
60 |
for df, df_id in zip(dfs, ids):
|
|
|
|
|
61 |
for i, column in enumerate(df.columns[1:]):
|
62 |
fig.add_trace(go.Scatter(
|
63 |
x=df[df.columns[0]],
|
@@ -82,13 +84,16 @@ def create_plot(dfs:list[pd.DataFrame], ids:list[str]):
|
|
82 |
)
|
83 |
return fig
|
84 |
|
85 |
-
def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
|
86 |
"""
|
87 |
计算数据集列表的统计信息。
|
88 |
"""
|
89 |
stats_list = []
|
90 |
|
91 |
for df, id in zip(dfs, ids):
|
|
|
|
|
|
|
92 |
df_values = df.iloc[:, 1:]
|
93 |
# 计算统计值
|
94 |
mean_values = df_values.mean().round(2)
|
@@ -102,7 +107,8 @@ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
|
|
102 |
'mean': mean_values.values,
|
103 |
'std': std_values.values,
|
104 |
'max': max_values.values,
|
105 |
-
'min': min_values.values
|
|
|
106 |
})
|
107 |
stats_list.append(stats_df)
|
108 |
|
|
|
52 |
|
53 |
return df
|
54 |
|
55 |
+
def create_plot(dfs:list[pd.DataFrame], ids:list[str], interval:list[int, int]=None) -> go.Figure:
|
56 |
"""
|
57 |
创建一个包含所有传入 DataFrame 的线图。
|
58 |
"""
|
59 |
fig = go.Figure()
|
60 |
for df, df_id in zip(dfs, ids):
|
61 |
+
if interval:
|
62 |
+
df = df.iloc[interval[0]:interval[1]]
|
63 |
for i, column in enumerate(df.columns[1:]):
|
64 |
fig.add_trace(go.Scatter(
|
65 |
x=df[df.columns[0]],
|
|
|
84 |
)
|
85 |
return fig
|
86 |
|
87 |
+
def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame:
|
88 |
"""
|
89 |
计算数据集列表的统计信息。
|
90 |
"""
|
91 |
stats_list = []
|
92 |
|
93 |
for df, id in zip(dfs, ids):
|
94 |
+
total_rows = len(df)
|
95 |
+
if interval:
|
96 |
+
df = df.iloc[interval[0]:interval[1]]
|
97 |
df_values = df.iloc[:, 1:]
|
98 |
# 计算统计值
|
99 |
mean_values = df_values.mean().round(2)
|
|
|
107 |
'mean': mean_values.values,
|
108 |
'std': std_values.values,
|
109 |
'max': max_values.values,
|
110 |
+
'min': min_values.values,
|
111 |
+
'total_sample_num': total_rows
|
112 |
})
|
113 |
stats_list.append(stats_df)
|
114 |
|