Surbhi commited on
Commit
9bedd59
·
1 Parent(s): 295702b

Fix n_neighbors

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -59,8 +59,19 @@ dataset_mapping = {
59
  "Fraud Detection": "datasets/fraud_detection.csv",
60
  "Customer Segmentation": "datasets/customer_segmentation.csv",
61
  "Loan Approval": "datasets/loan_approval.csv",
 
 
 
62
  "House Price Prediction": "datasets/house_price_prediction.csv",
 
63
  "Sales Forecasting": "datasets/sales_forecasting.csv",
 
 
 
 
 
 
 
64
  }
65
 
66
  dataset_path = dataset_mapping.get(problem, "datasets/spam_detection.csv")
@@ -69,7 +80,7 @@ df = pd.read_csv(dataset_path)
69
 
70
  # Model Initialization
71
  model_mapping = {
72
- "KNN": KNeighborsClassifier() if task == "Classification" else KNeighborsRegressor(),
73
  "SVM": SVC() if task == "Classification" else SVR(),
74
  "Random Forest": RandomForestClassifier() if task == "Classification" else RandomForestRegressor(),
75
  "Decision Tree": DecisionTreeClassifier() if task == "Classification" else DecisionTreeRegressor(),
 
59
  "Fraud Detection": "datasets/fraud_detection.csv",
60
  "Customer Segmentation": "datasets/customer_segmentation.csv",
61
  "Loan Approval": "datasets/loan_approval.csv",
62
+ "Churn Prediction": "datasets/churn_prediction.csv",
63
+ "Handwritten Digit Recognition": "datasets/handwritten_digit_recognition.csv",
64
+ "Sentiment Analysis": "datasets/sentiment_analysis.csv",
65
  "House Price Prediction": "datasets/house_price_prediction.csv",
66
+ "Stock Prediction": "datasets/stock_prediction.csv",
67
  "Sales Forecasting": "datasets/sales_forecasting.csv",
68
+ "Stock Market Trends": "datasets/stock_market_trends.csv",
69
+ "Energy Consumption": "datasets/energy_consumption.csv",
70
+ "Patient Survival Prediction": "datasets/patient_survival_prediction.csv",
71
+ "House Price Estimation": "datasets/house_price_estimation.csv",
72
+ "Revenue Prediction": "datasets/revenue_prediction.csv",
73
+ "Weather Forecasting": "datasets/weather_forecasting.csv",
74
+ "Traffic Flow Prediction": "datasets/traffic_flow_prediction.csv"
75
  }
76
 
77
  dataset_path = dataset_mapping.get(problem, "datasets/spam_detection.csv")
 
80
 
81
  # Model Initialization
82
  model_mapping = {
83
+ "KNN": KNeighborsClassifier(n_neighbors=min(5, len(y_train))) if task == "Classification" else KNeighborsRegressor(),
84
  "SVM": SVC() if task == "Classification" else SVR(),
85
  "Random Forest": RandomForestClassifier() if task == "Classification" else RandomForestRegressor(),
86
  "Decision Tree": DecisionTreeClassifier() if task == "Classification" else DecisionTreeRegressor(),