#####################################################
# Utils
#####################################################
# This file contains utility functions for data processing and plotting.

import base64
from io import BytesIO

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go

from config import *


def ndarray_to_base64(ndarray):
    """Plot a 1-D ``np.ndarray`` as a line chart and return it as a PNG data URI.

    Args:
        ndarray: Values to plot against their position (anything ``Axes.plot``
            accepts).

    Returns:
        A string ``"data:image/png;base64,<payload>"`` holding the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render into an in-memory byte buffer rather than a file on disk.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even when plotting/saving raises; pyplot
        # otherwise keeps every open figure alive.
        plt.close(fig)

    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"


def flatten_ndarray_column(df, column_name, rows_to_include, name_mapping_map: dict | None = None):
    """Expand a nested-array column into one new column per selected row.

    Cells of ``df[column_name]`` are handled as follows:

    * object-dtype ``np.ndarray`` (one sub-array per row): rows whose index is
      in ``rows_to_include`` are kept (indices ``>= len`` are skipped) and
      stacked;
    * 2-D numeric ``np.ndarray``: the rows listed in ``rows_to_include`` are
      selected (indices ``>= len`` are skipped);
    * 1-D ``np.ndarray``: treated as a single row;
    * anything else: used as-is.

    For the i-th entry ``r`` of ``rows_to_include`` a column named
    ``f"{column_name}_{name_mapping_map[r]}"`` (when ``r`` is a key of the
    mapping) or ``f"{column_name}_{r}"`` (otherwise) is added, holding the
    i-th selected row of each cell.

    Args:
        df: DataFrame to extend. It is modified in place.
        column_name: Name of the nested-array column to expand.
        rows_to_include: Row indices (within each cell) to extract.
        name_mapping_map: Optional mapping from row index to column suffix.

    Returns:
        ``df`` itself, with the new columns added.
    """

    def _flatten_nested(value):
        # Object arrays hold one sub-array per row: keep the requested rows and
        # stack their (recursively flattened) contents.
        if isinstance(value, np.ndarray) and value.dtype == 'O':
            selected = [value[r] for r in rows_to_include if r < len(value)]
            return np.concatenate([_flatten_nested(sub) for sub in selected])
        # A plain 1-D array becomes a single row of shape (1, n).
        if isinstance(value, np.ndarray) and value.ndim == 1:
            return np.expand_dims(value, axis=0)
        return value

    def _select_rows(cell):
        # A 2-D numeric array already has one row per entry; select the
        # requested rows so that position i below corresponds to
        # rows_to_include[i] instead of to raw row i.
        if isinstance(cell, np.ndarray) and cell.dtype != 'O' and cell.ndim >= 2:
            return cell[[r for r in rows_to_include if r < len(cell)]]
        return _flatten_nested(cell)

    selected_data = df[column_name].apply(_select_rows)
    for position, row_index in enumerate(rows_to_include):
        if name_mapping_map is not None and row_index in name_mapping_map:
            suffix = name_mapping_map[row_index]
        else:
            suffix = row_index
        # Bind the position explicitly so the lambda never depends on the loop
        # variable's later value.
        df[f'{column_name}_{suffix}'] = selected_data.apply(lambda rows, pos=position: rows[pos])
    return df


def create_plot(dfs: list[pd.DataFrame], ids: list[str], interval: list[int] | None = None) -> go.Figure:
    """Create one line chart containing every value column of every DataFrame.

    The first column of each DataFrame is not plotted; the x coordinate of each
    point is its row position (0, 1, 2, ...). When a DataFrame has more than one
    value column, those columns are z-score normalised (``(x - mean) / std``)
    before plotting.

    Args:
        dfs: DataFrames to plot.
        ids: One identifier per DataFrame, used in trace names
            ``item_{id} - {column}``. Paired with ``dfs`` by position; surplus
            entries in either list are ignored.
        interval: Optional ``[start, stop]`` positional row slice applied to each
            DataFrame before plotting.

    Returns:
        The plotly figure, with a horizontal legend centred below the plot.
    """
    fig = go.Figure()
    for df, df_id in zip(dfs, ids):
        if interval:
            df = df.iloc[interval[0]:interval[1]]
        value_columns = df.columns[1:]
        df_normalized = df.copy()
        # NOTE(review): a DataFrame with a single value column is plotted
        # un-normalised; confirm this is intended.
        if len(df.columns) > 2:
            values = df[value_columns]
            df_normalized[value_columns] = (values - values.mean()) / values.std()
        x_positions = list(range(len(df)))
        for column in value_columns:
            fig.add_trace(go.Scatter(
                x=x_positions,
                y=df_normalized[column],
                mode='lines',
                name=f"item_{df_id} - {column}",
                visible=True
            ))

    # Configure the legend.
    fig.update_layout(
        legend=dict(
            title="Variables",
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        xaxis_title='Time',
        yaxis_title='Values'
    )
    return fig


def create_statistic(dfs: list[pd.DataFrame], ids: list[str], interval: list[int] | None = None) -> pd.DataFrame:
    """Compute summary statistics for the value columns of each DataFrame.

    For every DataFrame, all columns except the first are summarised with
    mean / std / max / min (computed on the ``interval`` slice when given).

    Args:
        dfs: DataFrames to summarise.
        ids: One identifier per DataFrame; rows are labelled ``{id}_{column}``.
        interval: Optional ``[start, stop]`` positional row slice used for the
            statistics.

    Returns:
        A DataFrame with columns ``Variables, mean, std, max, min,
        total_sample_num``. ``total_sample_num`` is the row count of the full
        DataFrame (before slicing). Python ``int``/``float`` cells are rounded
        to 2 decimals. An empty input yields an empty DataFrame with those
        columns.
    """
    stat_columns = ['Variables', 'mean', 'std', 'max', 'min', 'total_sample_num']
    stats_list = []
    for df, df_id in zip(dfs, ids):
        # Counted before slicing, so this reflects the whole dataset.
        total_rows = len(df)
        if interval:
            df = df.iloc[interval[0]:interval[1]]
        df_values = df.iloc[:, 1:]

        stats_list.append(pd.DataFrame({
            'Variables': [f"{df_id}_{col}" for col in df_values.columns],
            'mean': df_values.mean().values,
            'std': df_values.std().values,
            'max': df_values.max().values,
            'min': df_values.min().values,
            'total_sample_num': total_rows
        }))

    # pd.concat raises on an empty list; return an empty, correctly-shaped frame.
    if not stats_list:
        return pd.DataFrame(columns=stat_columns)

    combined_stats_df = pd.concat(stats_list, ignore_index=True)

    def _round_number(value):
        return round(value, 2) if isinstance(value, (int, float)) else value

    # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1 and the old
    # name is deprecated; use whichever the installed pandas provides.
    if hasattr(pd.DataFrame, 'map'):
        return combined_stats_df.map(_round_number)
    return combined_stats_df.applymap(_round_number)


def clean_up_df(df: pd.DataFrame, rows_to_include: list[int], name_mapping_map: dict | None = None) -> pd.DataFrame:
    """Add a timestamp column and expand the nested ``target`` column.

    Reads the ``start``, ``freq`` and ``target`` columns of each row. It then
    adds a ``timestamp`` column (a list of ``datetime`` objects per row) and one
    ``target_*`` column per entry of ``rows_to_include`` (see
    ``flatten_ndarray_column``). Finally it drops ``start``, ``freq``,
    ``target`` and, if present, ``past_feat_dynamic_real``.

    Note: ``df`` itself is modified in place; the same object is returned.

    Args:
        df: Source DataFrame with ``start``, ``freq`` and ``target`` columns.
        rows_to_include: Row indices of ``target`` to extract (sorted first).
        name_mapping_map: Optional mapping from row index to column suffix.

    Returns:
        The cleaned DataFrame (the same object as ``df``).
    """
    rows_to_include = sorted(rows_to_include)

    def _timestamps(row):
        target = row['target']
        first = target[0]
        # If the first element is itself an array, its length is the number of
        # periods; otherwise the length of target itself is used.
        periods = len(first) if isinstance(first, np.ndarray) else len(target)
        return pd.date_range(
            start=row['start'],
            periods=periods,
            freq=row['freq']
        ).to_pydatetime().tolist()

    df['timestamp'] = df.apply(_timestamps, axis=1)
    df = flatten_ndarray_column(df, 'target', rows_to_include, name_mapping_map)

    # Drop the raw start, freq and target columns (now represented by
    # 'timestamp' and the 'target_*' columns).
    df.drop(columns=['start', 'freq', 'target'], inplace=True)
    if 'past_feat_dynamic_real' in df.columns:
        df.drop(columns=['past_feat_dynamic_real'], inplace=True)
    return df


def get_question_info(df: pd.DataFrame, info_columns: list | None = None) -> pd.DataFrame:
    """Return the distinct combinations of question-metadata columns in ``df``.

    Args:
        df: Source DataFrame.
        info_columns: Columns to extract. Defaults to ``COLUMN_DOMAIN``,
            ``COLUMN_SOURCE``, ``COLUMN_QA_TYPE`` and ``COLUMN_TASK_TYPE``
            (presumably provided by the ``config`` star import).

    Returns:
        ``df[info_columns]`` with duplicate rows removed.
    """
    if info_columns is None:
        info_columns = [COLUMN_DOMAIN, COLUMN_SOURCE, COLUMN_QA_TYPE, COLUMN_TASK_TYPE]
    return df[info_columns].drop_duplicates()


if __name__ == '__main__':
    # Build test data.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30]
    }
    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25]
    }
    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)

    # Convert the time column to datetime.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])

    # Build the chart: create_plot takes a list of DataFrames and a parallel
    # list of ids, not two DataFrames.
    fig = create_plot([df1, df2], ["1", "2"])

    # Show the chart.
    fig.show()