Kamarov commited on
Commit
18c459c
·
1 Parent(s): a25c8f7

修复了问题无法显示的bug

Browse files
__pycache__/comm_utils.cpython-310.pyc ADDED
Binary file (1.84 kB). View file
 
__pycache__/config.cpython-310.pyc ADDED
Binary file (1.14 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
app.py CHANGED
@@ -331,7 +331,11 @@ with gr.Blocks() as demo:
331
  ret[question_info_textbox_p2] = gr.update(value=question_info_p2)
332
  ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=interval))
333
  ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=interval))
334
- ret[user_input_box] = gr.update(value=df[COLUMN_QUESTION][0])
 
 
 
 
335
  ret[user_output_box] = gr.update(value=df[COLUMN_ANSWER][0])
336
  ret[submit_info_box] = gr.update(value="")
337
  else:
@@ -419,8 +423,8 @@ with gr.Blocks() as demo:
419
  if __name__ == "__main__":
420
 
421
  app = gr.mount_gradio_app(app, demo, path="/")
422
- # host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
423
- host = "0.0.0.0"
424
  # import subprocess
425
  # subprocess.Popen(["python", "test_server.py"])
426
  uvicorn.run(app, host=host, port=7860)
 
331
  ret[question_info_textbox_p2] = gr.update(value=question_info_p2)
332
  ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list, interval=interval))
333
  ret[plot] = gr.update(value=create_plot(df_list, id_list, interval=interval))
334
+ if df[COLUMN_OPTION][0] is not None:
335
+ user_input_box_value = df[COLUMN_QUESTION][0] + '\n\nOptions:\n' + df[COLUMN_OPTION][0]
336
+ else:
337
+ user_input_box_value = df[COLUMN_QUESTION][0]
338
+ ret[user_input_box] = gr.update(value=user_input_box_value)
339
  ret[user_output_box] = gr.update(value=df[COLUMN_ANSWER][0])
340
  ret[submit_info_box] = gr.update(value="")
341
  else:
 
423
  if __name__ == "__main__":
424
 
425
  app = gr.mount_gradio_app(app, demo, path="/")
426
+ host = "127.0.0.1"
427
+ # host = "0.0.0.0"
428
  # import subprocess
429
  # subprocess.Popen(["python", "test_server.py"])
430
  uvicorn.run(app, host=host, port=7860)
utils.py CHANGED
@@ -57,25 +57,44 @@ def flatten_ndarray_column(df, column_name, rows_to_include, name_mapping_map:di
57
 
58
  return df
59
 
60
- def create_plot(dfs:list[pd.DataFrame], ids:list[str], interval:list[int, int]=None) -> go.Figure:
61
  """
62
  创建一个包含所有传入 DataFrame 的线图。
63
  """
64
  fig = go.Figure()
 
65
  for df, df_id in zip(dfs, ids):
66
  if interval:
67
  df = df.iloc[interval[0]:interval[1]]
 
68
  df_normalized = df.copy()
69
- if len(df.columns) > 2:
70
- df_normalized[df.columns[1:]] = (df[df.columns[1:]] - df[df.columns[1:]].mean()) / df[df.columns[1:]].std()
71
- for i, column in enumerate(df.columns[1:]):
 
 
 
 
 
72
  fig.add_trace(go.Scatter(
73
- x=list(range(len(df[df.columns[0]]))),
74
  y=df_normalized[column],
75
  mode='lines',
76
- name=f"item_{df_id} - {column}",
77
- # visible=True if i == 0 else 'legendonly'
78
- visible=True
 
 
 
 
 
 
 
 
 
 
 
 
79
  ))
80
 
81
  # 配置图例
@@ -88,9 +107,12 @@ def create_plot(dfs:list[pd.DataFrame], ids:list[str], interval:list[int, int]=N
88
  xanchor="center",
89
  x=0.5
90
  ),
91
- xaxis_title='Time',
 
 
92
  yaxis_title='Values'
93
  )
 
94
  return fig
95
 
96
  def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame:
 
57
 
58
  return df
59
 
60
+ def create_plot(dfs: list[pd.DataFrame], ids: list[str], interval: list[int, int] = None) -> go.Figure:
61
  """
62
  创建一个包含所有传入 DataFrame 的线图。
63
  """
64
  fig = go.Figure()
65
+
66
  for df, df_id in zip(dfs, ids):
67
  if interval:
68
  df = df.iloc[interval[0]:interval[1]]
69
+
70
  df_normalized = df.copy()
71
+ if len(df.columns) > 1:
72
+ for column in df.columns[1:]:
73
+ min_val = df[column].min()
74
+ max_val = df[column].max()
75
+ df_normalized[column] = (df[column] - min_val) / (max_val - min_val) if max_val != min_val else 0
76
+
77
+ for column in df.columns[1:]:
78
+ # 归一化数据曲线(默认可见)
79
  fig.add_trace(go.Scatter(
80
+ x=df[df.columns[0]],
81
  y=df_normalized[column],
82
  mode='lines',
83
+ name=f"Normalized {df_id} - {column}",
84
+ hovertext=list(range(len(df))),
85
+ hoverinfo="x+text+y",
86
+ visible=True # 归一化数据默认可见
87
+ ))
88
+
89
+ # 原始数据曲线(默认隐藏)
90
+ fig.add_trace(go.Scatter(
91
+ x=df[df.columns[0]],
92
+ y=df[column],
93
+ mode='lines',
94
+ name=f"Raw {df_id} - {column}",
95
+ hovertext=list(range(len(df))),
96
+ hoverinfo="x+text+y",
97
+ visible='legendonly' # 原始数据默认隐藏
98
  ))
99
 
100
  # 配置图例
 
107
  xanchor="center",
108
  x=0.5
109
  ),
110
+ xaxis=dict(
111
+ title="Timestamp",
112
+ ),
113
  yaxis_title='Values'
114
  )
115
+
116
  return fig
117
 
118
  def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame: