|
import pandas as pd |
|
from zenml import step |
|
from src.data_splitting import DataSplitter |
|
from typing_extensions import Tuple |
|
@step
def data_splitter_step(
    df: pd.DataFrame, target_column: str, test_size: float = 0.2
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """ZenML step that splits a dataframe into train/test features and targets.

    The split itself is delegated to ``DataSplitter``; this step only checks
    its inputs and forwards the four resulting parts.

    Parameters
    ----------
    df : pd.DataFrame
        The input dataframe to be split.
    target_column : str
        Name of the column in ``df`` that holds the target variable.
    test_size : float, optional
        Proportion of the dataset to include in the test split. Default is 0.2.

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]
        ``(X_train, X_test, y_train, y_test)`` as produced by
        ``DataSplitter.split_data``.

    Raises
    ------
    KeyError
        If ``target_column`` is not one of ``df.columns``.
    """
    # Fail fast with an actionable message instead of relying on DataSplitter,
    # whose handling of a missing target column is not visible from here.
    # KeyError matches what pandas itself raises for a missing column label.
    if target_column not in df.columns:
        raise KeyError(
            f"Target column {target_column!r} not found in dataframe; "
            f"available columns: {list(df.columns)}"
        )

    splitter = DataSplitter(df, target_column, test_size)

    # Unpacking (rather than returning split_data() directly) enforces that the
    # splitter yields exactly four parts, matching the annotated return type.
    X_train, X_test, y_train, y_test = splitter.split_data()
    return X_train, X_test, y_train, y_test
|
|
|
|
|
if __name__ == "__main__":
    # This module only defines a ZenML step; running it directly performs
    # no work, so the script entry point is intentionally empty.
    pass