# Spaces: Runtime error
import pandas as pd | |
from zenml import step | |
from src.data_splitting import DataSplitter | |
from typing_extensions import Tuple | |
def data_splitter_step(
    df: pd.DataFrame, target_column: str, test_size: float = 0.2
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """
    Split a dataframe into train and test parts by delegating to DataSplitter.

    Parameters:
    df : pd.DataFrame
        The dataframe to split; forwarded unchanged to DataSplitter.
    target_column : str
        The name of the target column; forwarded unchanged to DataSplitter.
    test_size : float, optional
        The proportion of the dataset to place in the test split.
        Default is 0.2. Forwarded unchanged to DataSplitter.

    Returns:
    Tuple of X_train, X_test, y_train, y_test, in the order that
    DataSplitter.split_data() yields them.
    """
    # NOTE(review): `step` is imported from zenml at the top of the file, but
    # no `@step` decorator is visible on this function -- confirm whether that
    # is intentional.

    # Build the splitter and run the split in one expression. Unpacking into
    # exactly four names keeps the "must yield four items" check and the
    # returned value is always a plain 4-tuple.
    train_features, test_features, train_target, test_target = DataSplitter(
        df, target_column, test_size
    ).split_data()
    return train_features, test_features, train_target, test_target
# Example usage (kept as comments; running this module directly does nothing).
if __name__ == "__main__":
    # To try the splitter by hand, uncomment the lines below:
    #
    #     df = pd.read_csv('data.csv')
    #     target_column = 'target'
    #
    #     X_train, X_test, y_train, y_test = data_splitter_step(df, target_column)
    #
    #     print("X_train:\n", X_train)
    #     print("X_test:\n", X_test)
    #     print("y_train:\n", y_train)
    #     print("y_test:\n", y_test)
    pass