alpertml's picture
Upload 16 files
fa10c3d verified
raw
history blame contribute delete
600 Bytes
from catboost import CatBoostRegressor
def train_model(
    train,
    model_params,
    model_type,
    cat_features,
    valid=None,
):
    """Build and fit a regression model on the given training data.

    Args:
        train: Pair ``(X_train, y_train)``; unpacked into features and target.
        model_params: Mapping of keyword arguments forwarded to the model
            constructor.
        model_type: Name of the model to build. Only ``'CATBOOST'`` is
            supported (exact, case-sensitive match).
        cat_features: Forwarded to ``CatBoostRegressor`` as ``cat_features``.
        valid: Optional pair ``(X_valid, y_valid)``. When provided (truthy),
            it is passed to ``fit`` as the single entry of ``eval_set``.

    Returns:
        The fitted model instance.

    Raises:
        ValueError: If ``model_type`` is not a supported model name.
    """
    X_train, y_train = train

    # Fail early with a clear message; previously an unknown model type only
    # surfaced later as an UnboundLocalError on ``model``.
    if model_type != 'CATBOOST':
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    model = CatBoostRegressor(**model_params, cat_features=cat_features)

    # Default to no evaluation set so that calling without ``valid`` works
    # instead of hitting an unbound ``eval_set`` at fit time.
    eval_set = None
    if valid:
        X_valid, y_valid = valid
        eval_set = [(X_valid, y_valid)]

    model.fit(
        X_train,
        y_train,
        eval_set=eval_set,
        verbose=200,
    )
    return model