Spaces:
Sleeping
Sleeping
File size: 6,845 Bytes
f3718f0 60dbd41 f3718f0 5c68062 f3718f0 5c68062 f3718f0 18c459c f3718f0 18c459c f3718f0 1c081e2 d539841 18c459c 519dd2e 18c459c f3718f0 18c459c 519dd2e f3718f0 18c459c d539841 18c459c d539841 18c459c f3718f0 18c459c f3718f0 18c459c f3718f0 1c081e2 f3718f0 1c081e2 f3718f0 15549a1 f3718f0 1c081e2 f3718f0 15549a1 f3718f0 5c68062 f3718f0 5c68062 f3718f0 a4c0034 60dbd41 a4c0034 60dbd41 f3718f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
#####################################################
# Utils
#####################################################
# 本文件包含了一些用于数据处理和绘图的实用函数。
import base64
from io import BytesIO
from matplotlib import pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from config import *
def ndarray_to_base64(ndarray):
    """Plot a 1-D array as a line chart and return it as a PNG data URI.

    Args:
        ndarray: Values to plot; passed directly to ``Axes.plot`` with the
            array index on the x axis.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"`` that can
        be embedded directly in an HTML ``<img src=...>``.
    """
    # Use an explicit Figure instead of pyplot's implicit "current figure":
    # a bare plt.close() closes whatever pyplot considers current, which is
    # not guaranteed to be the figure created here.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()
        # Render the PNG into an in-memory byte buffer.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Release the figure even if plotting or saving raised, so pyplot's
        # figure registry does not keep accumulating open figures.
        plt.close(fig)
    # getvalue() returns the whole buffer regardless of stream position.
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"
def flatten_ndarray_column(df, column_name, rows_to_include, name_mapping_map:dict|None=None):
"""
将嵌套的np.ndarray列展平为多列,并只保留指定的行。
"""
def select_and_flatten(ndarray):
if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
selected = [ndarray[i] for i in rows_to_include if i < len(ndarray)]
return np.concatenate([select_and_flatten(subarray) for subarray in selected])
elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
return np.expand_dims(ndarray, axis=0)
return ndarray
selected_data = df[column_name].apply(select_and_flatten)
for i, index in enumerate(rows_to_include):
if name_mapping_map is not None and index in name_mapping_map:
df[f'{column_name}_{name_mapping_map[index]}'] = selected_data.apply(lambda x: x[i])
else:
df[f'{column_name}_{index}'] = selected_data.apply(lambda x: x[i])
return df
def _min_max_normalize(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose columns after the first are min-max scaled.

    Each value becomes ``(x - min) / (max - min)`` within its column. A column
    whose max equals its min is replaced by the constant 0 to avoid dividing
    by zero. The first column (used as the x axis) is copied unchanged.
    """
    normalized = df.copy()
    for column in df.columns[1:]:
        min_val = df[column].min()
        max_val = df[column].max()
        if max_val != min_val:
            normalized[column] = (df[column] - min_val) / (max_val - min_val)
        else:
            normalized[column] = 0
    return normalized


def create_plot(dfs: list[pd.DataFrame], ids: list[str], interval: list[int] | None = None) -> go.Figure:
    """Build a single Plotly line chart from every column of every DataFrame.

    The first column of each DataFrame is the x axis. Every other column
    produces two traces: a min-max normalized curve (visible by default) and
    the raw curve (hidden until enabled from the legend).

    Args:
        dfs: DataFrames to plot.
        ids: Labels paired positionally with ``dfs``; used in trace names.
        interval: Optional ``[start, end]`` positional row range, both ends
            inclusive. ``None`` plots every row.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    for df, df_id in zip(dfs, ids):
        start = 0
        if interval:
            start = interval[0]
            # The end index is inclusive, hence the +1.
            df = df.iloc[interval[0]:interval[1] + 1]
        df_normalized = _min_max_normalize(df)
        # Row positions (offset by the interval start), one per plotted point,
        # shown as hover text.
        hover_rows = list(range(start, start + len(df)))
        x_values = df[df.columns[0]]
        for column in df.columns[1:]:
            # Normalized curve (visible by default).
            fig.add_trace(go.Scatter(
                x=x_values,
                y=df_normalized[column],
                mode='lines',
                name=f"Normalized {df_id} - {column}",
                hovertext=hover_rows,
                hoverinfo="x+text+y",
                visible=True,
            ))
            # Raw curve (hidden by default; toggle it from the legend).
            fig.add_trace(go.Scatter(
                x=x_values,
                y=df[column],
                mode='lines',
                name=f"Raw {df_id} - {column}",
                hovertext=hover_rows,
                hoverinfo="x+text+y",
                visible='legendonly',
            ))
    # Horizontal legend centered below the plot area; applied once for the
    # whole figure.
    fig.update_layout(
        legend=dict(
            title="Variables",
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        xaxis=dict(
            title="Timestamp",
        ),
        yaxis_title='Values'
    )
    return fig
def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame:
"""
计算数据集列表的统计信息。
"""
stats_list = []
for df, id in zip(dfs, ids):
total_rows = len(df)
if interval:
df = df.iloc[interval[0]:interval[1]]
df_values = df.iloc[:, 1:]
# 计算统计值
mean_values = df_values.mean()
std_values = df_values.std()
max_values = df_values.max()
min_values = df_values.min()
# 将这些统计信息合并成一个新的DataFrame
stats_df = pd.DataFrame({
'Variables': [f"{id}_{col}" for col in df_values.columns],
'mean': mean_values.values,
'std': std_values.values,
'max': max_values.values,
'min': min_values.values,
'total_sample_num': total_rows
})
stats_list.append(stats_df)
# 合并所有统计信息DataFrame
combined_stats_df = pd.concat(stats_list, ignore_index=True)
combined_stats_df = combined_stats_df.applymap(lambda x: round(x, 2) if isinstance(x, (int, float)) else x)
return combined_stats_df
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int], name_mapping_map: dict | None = None) -> pd.DataFrame:
    """Flatten the nested ``target`` column of a time-series DataFrame.

    Steps:
      1. Add a ``timestamp`` column holding, for every record, a list of
         Python datetimes built by ``pd.date_range`` from the record's
         ``start`` and ``freq``. The number of periods is ``len(target[0])``
         when ``target[0]`` is an ndarray, otherwise ``len(target)``.
      2. Expand ``target`` into one column per index in ``rows_to_include``
         (processed in ascending order) via ``flatten_ndarray_column``.
      3. Drop ``start``, ``freq`` and ``target``, plus
         ``past_feat_dynamic_real`` if that column exists.

    Args:
        df: DataFrame with at least ``start``, ``freq`` and ``target``.
        rows_to_include: Indices into each ``target`` to keep.
        name_mapping_map: Optional ``{index: label}`` for naming the
            expanded columns.

    Returns:
        The cleaned DataFrame; ``df`` itself is modified in place as well.
    """
    ordered_rows = sorted(rows_to_include)

    def _timestamps_for(record):
        begin = record['start']
        series = record['target']
        head = series[0]
        n_points = len(head) if isinstance(head, np.ndarray) else len(series)
        span = pd.date_range(start=begin, periods=n_points, freq=record['freq'])
        return span.to_pydatetime().tolist()

    df['timestamp'] = df.apply(_timestamps_for, axis=1)
    df = flatten_ndarray_column(df, 'target', ordered_rows, name_mapping_map)
    # The source columns are no longer needed once expanded.
    df.drop(columns=['start', 'freq', 'target'], inplace=True)
    df.drop(columns=['past_feat_dynamic_real'], inplace=True, errors='ignore')
    return df
def get_question_info(df: pd.DataFrame, info_columns:list|None=None) -> pd.DataFrame:
"""
从数据集中提取问题信息。
"""
if info_columns is None:
info_columns = [COLUMN_DOMAIN, COLUMN_SOURCE, COLUMN_QA_TYPE, COLUMN_TASK_TYPE]
question_info = df[info_columns]
question_info = question_info.drop_duplicates()
return question_info
if __name__ == '__main__':
    # Build two small demo datasets that share a time axis.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30]
    }
    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25]
    }
    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)
    # Convert the time column to datetime dtype.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])
    # create_plot expects a list of DataFrames plus a parallel list of labels;
    # the interval is an inclusive [start, end] row range covering every row.
    fig = create_plot([df1, df2], ['df1', 'df2'], interval=[0, len(df1) - 1])
    # Display the chart.
    fig.show()