File size: 3,506 Bytes
b4b95a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#####################################################
# Utils
#####################################################
# 本文件包含了一些用于数据处理和绘图的实用函数。

import base64
from io import BytesIO
from matplotlib import pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import numpy as np


def ndarray_to_base64(ndarray):
    """
    将一维np.ndarray绘图并转换为Base64编码。
    """
    # 创建绘图
    plt.figure(figsize=(8, 4))
    plt.plot(ndarray)
    plt.title("Vector Plot")
    plt.xlabel("Index")
    plt.ylabel("Value")
    plt.tight_layout()
    
    # 保存图像到内存字节流
    buffer = BytesIO()
    plt.savefig(buffer, format="png")
    plt.close()
    buffer.seek(0)
    
    # 转换为Base64字符串
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"

def flatten_ndarray_column(df, column_name):
    """
    将嵌套的np.ndarray列展平为多列。
    """
    def flatten_ndarray(ndarray):
        if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
            return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
        elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
            return np.expand_dims(ndarray, axis=0)
        return ndarray

    flattened_data = df[column_name].apply(flatten_ndarray)
    max_length = max(flattened_data.apply(len))

    for i in range(max_length):
        df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan)

    return df

def create_plot(df):
    """
    创建一个包含所有列的线图。
    """
    fig = go.Figure()
    for i, column in enumerate(df.columns[1:]):
        fig.add_trace(go.Scatter(
            x=df[df.columns[0]], 
            y=df[column], 
            mode='lines', 
            name=column,
            visible=True if i == 0 else 'legendonly'
        ))
    
    # 配置图例
    fig.update_layout(
        legend=dict(
            title="Variables",
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        xaxis_title='Time',
        yaxis_title='Values'
    )
    return fig

def create_statistic(df):
    """
    计算数据集的统计信息。
    """
    df_values = df.iloc[:, 1:]
    # 计算统计值
    mean_values = df_values.mean()
    std_values = df_values.std()
    max_values = df_values.max()
    min_values = df_values.min()

    # 将这些统计信息合并成一个新的DataFrame
    stats_df = pd.DataFrame({
        'Variables': df_values.columns,
        'mean': mean_values.values,
        'std': std_values.values,
        'max': max_values.values,
        'min': min_values.values
    })
    return stats_df

def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    清理数据集,将嵌套的np.ndarray列展平为多列。
    """
    df['timestamp'] = df.apply(lambda row: pd.date_range(
        start=row['start'], 
        periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']), 
        freq=row['freq']
    ).to_pydatetime().tolist(), axis=1)
    df = flatten_ndarray_column(df, 'target')
    # 删除原始的start和freq列
    df.drop(columns=['start', 'freq', 'target'], inplace=True)
    if 'past_feat_dynamic_real' in df.columns:
        df.drop(columns=['past_feat_dynamic_real'], inplace=True)
    return df