# lotsa_explorer / flatten_ndarray.py
# Author: Liu Yiwen
# Release: lotsa viewer v0.1.0 (commit 4570f48)
import pandas as pd
import numpy as np
def flatten_ndarray_column(df, column_name):
    """Split a column of (possibly nested) arrays into one column per row.

    Each cell of ``df[column_name]`` is first normalised into a sequence of
    rows:

    * an object-dtype ndarray (an array whose elements are themselves arrays)
      is flattened recursively and the pieces are stacked along axis 0 with
      ``np.concatenate``;
    * a 1-D ndarray becomes a single row (shape ``(1, n)``);
    * anything else (e.g. an already 2-D numeric ndarray, or a list) is used
      as-is.

    Row ``i`` of every cell is then written to a new column
    ``f"{column_name}_{i}"``. Cells with fewer than ``i + 1`` rows, and
    missing cells (``None`` / NaN / ``pd.NA``), get ``np.nan`` in that column.

    Args:
        df: DataFrame to split. It is mutated in place: the new columns are
            added to ``df`` itself, and existing columns with the same names
            are overwritten. ``df[column_name]`` is left unchanged.
        column_name: Name of the column holding the arrays.

    Returns:
        The same ``df`` object, with the new columns added. An empty
        DataFrame gets no new columns.
    """

    def flatten_ndarray(ndarray):
        # Nested (object-dtype) array: flatten each element, then stack the
        # resulting rows along axis 0.
        if isinstance(ndarray, np.ndarray) and ndarray.dtype == object:
            parts = [flatten_ndarray(subarray) for subarray in ndarray]
            if not parts:
                # np.concatenate([]) raises ValueError; an empty container
                # simply has no rows.
                return np.empty((0, 0))
            return np.concatenate(parts)
        # A plain 1-D array is a single row. Add a leading axis so it can be
        # concatenated with sibling rows and indexed as x[0].
        if isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
            return np.expand_dims(ndarray, axis=0)
        return ndarray

    def row_count(cell):
        # Missing cells have no rows. Without this check, len() raises
        # TypeError on None / NaN.
        if pd.api.types.is_scalar(cell) and pd.isna(cell):
            return 0
        return len(cell)

    flattened_data = df[column_name].apply(flatten_ndarray)
    # default=0: an empty column adds no columns instead of max() raising.
    max_length = max(flattened_data.apply(row_count), default=0)
    for i in range(max_length):
        # i=i binds the current index explicitly. The lambda is only used
        # within this iteration, but this keeps it safe if that ever changes.
        df[f'{column_name}_{i}'] = flattened_data.apply(
            lambda x, i=i: x[i] if i < row_count(x) else np.nan
        )
    return df
# Example usage
if __name__ == "__main__":
    # Build a genuinely nested cell: an object-dtype array whose elements are
    # 1-D arrays. Note that np.array([np.array([1, 2]), np.array([3, 4])])
    # would NOT be nested, because NumPy stacks equal-length arrays into a
    # plain 2-D int array. That would bypass the object-dtype flattening path.
    nested = np.empty(2, dtype=object)
    nested[0] = np.array([1, 2])
    nested[1] = np.array([3, 4])

    # Create the example DataFrame: one nested cell and one plain 1-D cell.
    data = {
        'target': [nested, np.array([5, 6, 7])]
    }
    df = pd.DataFrame(data)
    # Split the (nested) ndarrays in the target column into target_<i> columns.
    df = flatten_ndarray_column(df, 'target')
    print(df)