Spaces:
Running
Running
##################################################### | |
# Utils | |
##################################################### | |
# 本文件包含了一些用于数据处理和绘图的实用函数。 | |
import base64 | |
from io import BytesIO | |
from matplotlib import pyplot as plt | |
import pandas as pd | |
import plotly.graph_objects as go | |
import numpy as np | |
def ndarray_to_base64(ndarray):
    """
    Plot a one-dimensional np.ndarray as a line chart and return the image
    as a Base64-encoded PNG data URI.

    Args:
        ndarray: The values to plot; handed directly to ``Axes.plot``, so the
            x-axis is the element index.

    Returns:
        str: A string of the form ``data:image/png;base64,<payload>``.
    """
    # Work on an explicit Figure/Axes pair instead of pyplot's implicit
    # "current figure" so that exactly this figure is the one closed below.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render the image into an in-memory byte stream.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it even when plotting or saving fails to avoid leaking it.
        plt.close(fig)

    # getvalue() returns the whole buffer regardless of the stream position.
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"
def flatten_ndarray_column(df, column_name):
    """
    Expand a column of (possibly nested) np.ndarray values into several columns.

    Each cell of ``df[column_name]`` is first normalised by ``flatten_ndarray``:
    an object-dtype array is concatenated from its recursively normalised
    sub-arrays, a 1-D array becomes a single-row 2-D array, and anything else
    is kept unchanged. Entry ``i`` of the normalised cell is then written to a
    new column named ``f"{column_name}_{i}"``; cells with fewer entries than
    the longest one are padded with ``np.nan``.

    Args:
        df: DataFrame holding the column. It is modified in place.
        column_name: Name of the column to expand.

    Returns:
        The same ``df`` object with the additional columns. The original
        column is left in place. An empty DataFrame is returned unchanged.
    """
    def flatten_ndarray(ndarray):
        if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
            # Object arrays wrap further arrays: normalise each and stack them.
            return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
        elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
            # Promote a single vector to one row so every cell is indexable
            # by row number.
            return np.expand_dims(ndarray, axis=0)
        return ndarray

    flattened_data = df[column_name].apply(flatten_ndarray)

    # default=0 keeps an empty DataFrame from raising
    # "max() arg is an empty sequence".
    max_length = max((len(item) for item in flattened_data), default=0)

    for i in range(max_length):
        # Bind i as a default argument so the lambda does not rely on the
        # loop variable's late-bound value.
        df[f'{column_name}_{i}'] = flattened_data.apply(
            lambda x, i=i: x[i] if i < len(x) else np.nan
        )
    return df
def create_plot(df):
    """
    Build a Plotly line chart with one trace per column of ``df``.

    The first column supplies the x values; every remaining column becomes
    its own line trace named after the column. Only the first trace is shown
    initially, the others start as ``'legendonly'`` and can be toggled from
    the legend.

    Args:
        df: DataFrame whose first column is the x-axis and whose other
            columns are the series to draw.

    Returns:
        The configured ``go.Figure``.
    """
    fig = go.Figure()

    for position, series_name in enumerate(df.columns[1:]):
        # Show only the first series up front; keep the rest in the legend.
        if position == 0:
            visibility = True
        else:
            visibility = 'legendonly'

        trace = go.Scatter(
            x=df[df.columns[0]],
            y=df[series_name],
            mode='lines',
            name=series_name,
            visible=visibility,
        )
        fig.add_trace(trace)

    # Horizontal legend, centred and anchored by its top edge at y=-0.2.
    legend_settings = dict(
        title="Variables",
        orientation="h",
        yanchor="top",
        y=-0.2,
        xanchor="center",
        x=0.5,
    )
    fig.update_layout(
        legend=legend_settings,
        xaxis_title='Time',
        yaxis_title='Values',
    )
    return fig
def create_statistic(df):
    """
    Summarise every column of ``df`` except the first one.

    Args:
        df: DataFrame whose first column is skipped and whose remaining
            columns are summarised.

    Returns:
        pd.DataFrame: One row per summarised column, with the columns
        ``Variables`` (the source column name), ``mean``, ``std``, ``max``
        and ``min``.
    """
    value_columns = df.iloc[:, 1:]

    # Assemble the table column by column, in the order it is reported.
    summary = {'Variables': value_columns.columns}
    for stat_name in ('mean', 'std', 'max', 'min'):
        reducer = getattr(value_columns, stat_name)
        summary[stat_name] = reducer().values

    return pd.DataFrame(summary)
def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean up a dataset frame: add per-row timestamps and expand ``target``.

    For every row a ``timestamp`` list is generated with ``pd.date_range``
    from the row's ``start`` and ``freq``; its length is the length of the
    first element of ``target`` when that element is an np.ndarray, otherwise
    the length of ``target`` itself. The ``target`` column is then expanded
    with ``flatten_ndarray_column`` and the ``start``, ``freq`` and ``target``
    columns are dropped, together with ``past_feat_dynamic_real`` when it is
    present.

    Args:
        df: DataFrame with ``start``, ``freq`` and ``target`` columns. It is
            modified in place.

    Returns:
        pd.DataFrame: The cleaned DataFrame.
    """
    def build_timestamps(row):
        # One timestamp per observation of the row's series.
        start = row['start']
        target = row['target']
        first_entry = target[0]
        if isinstance(first_entry, np.ndarray):
            n_periods = len(first_entry)
        else:
            n_periods = len(target)
        stamps = pd.date_range(start=start, periods=n_periods, freq=row['freq'])
        return stamps.to_pydatetime().tolist()

    df['timestamp'] = df.apply(build_timestamps, axis=1)
    df = flatten_ndarray_column(df, 'target')

    # Drop the raw columns now that they have been converted.
    obsolete = ['start', 'freq', 'target']
    if 'past_feat_dynamic_real' in df.columns:
        obsolete.append('past_feat_dynamic_real')
    df.drop(columns=obsolete, inplace=True)
    return df