Spaces:
Sleeping
Sleeping
File size: 6,845 Bytes
f3718f0 60dbd41 f3718f0 5c68062 f3718f0 5c68062 f3718f0 18c459c f3718f0 18c459c f3718f0 1c081e2 d539841 18c459c 519dd2e 18c459c f3718f0 18c459c 519dd2e f3718f0 18c459c d539841 18c459c d539841 18c459c f3718f0 18c459c f3718f0 18c459c f3718f0 1c081e2 f3718f0 1c081e2 f3718f0 15549a1 f3718f0 1c081e2 f3718f0 15549a1 f3718f0 5c68062 f3718f0 5c68062 f3718f0 a4c0034 60dbd41 a4c0034 60dbd41 f3718f0 |
|
#####################################################
# Utils
#####################################################
# 本文件包含了一些用于数据处理和绘图的实用函数。
import base64
from io import BytesIO
from matplotlib import pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from config import *
def ndarray_to_base64(ndarray):
    """
    Plot a one-dimensional ``np.ndarray`` and return the image as a Base64 data URI.

    The values are drawn as a line plot on an 8x4 inch figure titled
    "Vector Plot", rendered to PNG in memory and Base64-encoded.

    Args:
        ndarray: The values handed to Matplotlib's ``plot``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    # Use an explicit figure instead of pyplot's implicit "current figure"
    # so that exactly this figure is drawn on, saved, and closed.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render the image into an in-memory byte stream.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # pyplot keeps every open figure registered until it is closed;
        # release it even when plotting or saving raises.
        plt.close(fig)

    # Encode the PNG bytes as a Base64 string.
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"
def flatten_ndarray_column(df, column_name, rows_to_include, name_mapping_map:dict|None=None):
"""
将嵌套的np.ndarray列展平为多列,并只保留指定的行。
"""
def select_and_flatten(ndarray):
if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
selected = [ndarray[i] for i in rows_to_include if i < len(ndarray)]
return np.concatenate([select_and_flatten(subarray) for subarray in selected])
elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
return np.expand_dims(ndarray, axis=0)
return ndarray
selected_data = df[column_name].apply(select_and_flatten)
for i, index in enumerate(rows_to_include):
if name_mapping_map is not None and index in name_mapping_map:
df[f'{column_name}_{name_mapping_map[index]}'] = selected_data.apply(lambda x: x[i])
else:
df[f'{column_name}_{index}'] = selected_data.apply(lambda x: x[i])
return df
def create_plot(dfs: list[pd.DataFrame], ids: list[str], interval: list[int] | None = None) -> go.Figure:
    """
    Build a line chart containing every value column of the given DataFrames.

    The first column of each DataFrame is used as the x axis. Every other
    column produces two traces: a min-max normalized curve (visible by
    default) and the raw curve (hidden until enabled in the legend).
    A column whose minimum equals its maximum is normalized to 0.

    Args:
        dfs: DataFrames to plot.
        ids: One label per DataFrame, used in the trace names.
        interval: Optional ``[start, end]`` row positions; when given, only
            rows ``start`` through ``end`` (inclusive) are plotted. When
            omitted, all rows are plotted.

    Returns:
        The configured Plotly figure.
    """
    fig = go.Figure()
    # Position of the first plotted row in the unsliced frame; hover labels
    # are numbered from here. Without an interval the frame starts at 0.
    offset = interval[0] if interval else 0

    for df, df_id in zip(dfs, ids):
        if interval:
            df = df.iloc[interval[0]:interval[1] + 1]

        # One hover label per plotted row.
        row_numbers = list(range(offset, offset + len(df)))
        df_normalized = df.copy()

        for column in df.columns[1:]:
            min_val = df[column].min()
            max_val = df[column].max()
            # Guard against division by zero for constant columns.
            df_normalized[column] = (df[column] - min_val) / (max_val - min_val) if max_val != min_val else 0

            # Normalized curve (visible by default).
            fig.add_trace(go.Scatter(
                x=df[df.columns[0]],
                y=df_normalized[column],
                mode='lines',
                name=f"Normalized {df_id} - {column}",
                hovertext=row_numbers,
                hoverinfo="x+text+y",
                visible=True
            ))

            # Raw curve (hidden by default).
            fig.add_trace(go.Scatter(
                x=df[df.columns[0]],
                y=df[column],
                mode='lines',
                name=f"Raw {df_id} - {column}",
                hovertext=row_numbers,
                hoverinfo="x+text+y",
                visible='legendonly'
            ))

    # Legend and axis configuration.
    fig.update_layout(
        legend=dict(
            title="Variables",
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        xaxis=dict(
            title="Timestamp",
        ),
        yaxis_title='Values'
    )
    return fig
def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval:list[int, int]=None) -> pd.DataFrame:
"""
计算数据集列表的统计信息。
"""
stats_list = []
for df, id in zip(dfs, ids):
total_rows = len(df)
if interval:
df = df.iloc[interval[0]:interval[1]]
df_values = df.iloc[:, 1:]
# 计算统计值
mean_values = df_values.mean()
std_values = df_values.std()
max_values = df_values.max()
min_values = df_values.min()
# 将这些统计信息合并成一个新的DataFrame
stats_df = pd.DataFrame({
'Variables': [f"{id}_{col}" for col in df_values.columns],
'mean': mean_values.values,
'std': std_values.values,
'max': max_values.values,
'min': min_values.values,
'total_sample_num': total_rows
})
stats_list.append(stats_df)
# 合并所有统计信息DataFrame
combined_stats_df = pd.concat(stats_list, ignore_index=True)
combined_stats_df = combined_stats_df.applymap(lambda x: round(x, 2) if isinstance(x, (int, float)) else x)
return combined_stats_df
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int], name_mapping_map:dict|None=None) -> pd.DataFrame:
    """
    Clean a dataset by expanding its nested ``target`` arrays into columns.

    A ``timestamp`` column is added first, holding for each row a list of
    datetimes generated from the row's ``start`` and ``freq`` values with
    one entry per sample of the target. The ``target`` column is then
    expanded via ``flatten_ndarray_column`` using the sorted
    ``rows_to_include``, and the ``start``, ``freq``, ``target`` and (if
    present) ``past_feat_dynamic_real`` columns are dropped.

    The passed DataFrame is modified in place and also returned.
    """
    def series_length(target):
        # When the first element is itself an array, its length is the
        # number of samples; otherwise the target's own length is used.
        return len(target[0]) if isinstance(target[0], np.ndarray) else len(target)

    def build_timestamps(row):
        stamps = pd.date_range(
            start=row['start'],
            periods=series_length(row['target']),
            freq=row['freq'],
        )
        return stamps.to_pydatetime().tolist()

    ordered_rows = sorted(rows_to_include)
    df['timestamp'] = df.apply(build_timestamps, axis=1)
    df = flatten_ndarray_column(df, 'target', ordered_rows, name_mapping_map)

    # The source columns are no longer needed once expanded.
    df.drop(columns=['start', 'freq', 'target'], inplace=True)
    if 'past_feat_dynamic_real' in df.columns:
        df.drop(columns=['past_feat_dynamic_real'], inplace=True)
    return df
def get_question_info(df: pd.DataFrame, info_columns:list|None=None) -> pd.DataFrame:
"""
从数据集中提取问题信息。
"""
if info_columns is None:
info_columns = [COLUMN_DOMAIN, COLUMN_SOURCE, COLUMN_QA_TYPE, COLUMN_TASK_TYPE]
question_info = df[info_columns]
question_info = question_info.drop_duplicates()
return question_info
if __name__ == '__main__':
    # Small manual demo: build two datasets and show them in one chart.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30]
    }
    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25]
    }
    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)

    # Convert the time column to datetime values.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])

    # create_plot expects a list of frames plus one label per frame; the
    # interval is given explicitly and covers every row of the demo data.
    fig = create_plot([df1, df2], ['dataset1', 'dataset2'], interval=[0, len(df1) - 1])

    # Display the chart.
    fig.show()