File size: 984 Bytes
4570f48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pandas as pd
import numpy as np

def flatten_ndarray_column(df, column_name):
    """Split the (possibly nested) arrays in ``df[column_name]`` into scalar columns.

    Each cell of ``column_name`` is recursively flattened into a 1-D sequence of
    scalars:

    * non-object ndarrays of any rank are raveled in C order;
    * object ndarrays, lists and tuples are flattened element by element, so
      ragged sub-arrays are supported;
    * ``None`` contributes no scalars;
    * any other value counts as a single scalar.

    The i-th scalar of every row is written to a new column
    ``f"{column_name}_{i}"``. Rows with fewer scalars than the longest row are
    padded with ``np.nan``.

    Args:
        df: DataFrame to modify. It is mutated in place (new columns are added)
            and also returned for convenience.
        column_name: Name of the column whose cells should be split.

    Returns:
        The same ``df`` object with the new columns added. An empty frame, or a
        column whose cells all flatten to nothing, gets no new columns.
    """

    def flatten_value(value):
        """Return ``value`` flattened into a flat Python list of scalars."""
        if value is None:
            # A missing cell contributes nothing, so its row is NaN-padded.
            return []
        if isinstance(value, np.ndarray) and value.dtype != object:
            # Homogeneous array: ravel handles every ndim, including 0-d.
            return list(value.ravel())
        if isinstance(value, (np.ndarray, list, tuple)):
            # Object arrays / Python sequences may hold ragged sub-arrays, so
            # recurse into each element instead of concatenating along an axis.
            items = value.ravel() if isinstance(value, np.ndarray) else value
            return [scalar for item in items for scalar in flatten_value(item)]
        # Any other value (number, string, ...) is a single entry.
        return [value]

    flattened = [flatten_value(cell) for cell in df[column_name]]
    # default=0 keeps an empty DataFrame from raising in max().
    max_length = max((len(values) for values in flattened), default=0)

    for i in range(max_length):
        df[f'{column_name}_{i}'] = [
            values[i] if i < len(values) else np.nan for values in flattened
        ]

    return df

# Example usage
if __name__ == "__main__":
    # Sample cells: the first becomes a 2-D int array (equal-length rows),
    # the second is a plain 1-D array.
    nested_cell = np.array([np.array([1, 2]), np.array([3, 4])])
    flat_cell = np.array([5, 6, 7])
    sample = pd.DataFrame({'target': [nested_cell, flat_cell]})

    # Split the ndarrays held in the 'target' column into separate columns.
    result = flatten_ndarray_column(sample, 'target')
    print(result)