Spaces:
Sleeping
Sleeping
Upload 91 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- Dockerfile +14 -0
- README.md +207 -13
- assets/Pipeline.png +0 -0
- assets/classificationperfomance-2.png +0 -0
- assets/classificationperfomance-3.png +0 -0
- assets/classificationperformance.png +0 -0
- assets/customer.webp +0 -0
- assets/customer_churn_image.jpg +0 -0
- assets/datadrift-2.png +0 -0
- assets/datadrift-3.png +0 -0
- assets/datadrift.png +0 -0
- assets/datapipeline-2.png +0 -0
- assets/datapipeline-3.png +0 -0
- assets/datapipeline.png +0 -0
- assets/eda.webp +0 -0
- assets/model_metrics.csv +5 -0
- backend/__init__.py +0 -0
- backend/__pycache__/__init__.cpython-312.pyc +0 -0
- backend/__pycache__/fastapi_app.cpython-312.pyc +0 -0
- backend/artifacts/Decision Tree.pkl +3 -0
- backend/artifacts/Logistic Regression.pkl +3 -0
- backend/artifacts/Random Forest.pkl +3 -0
- backend/artifacts/XGBoost.pkl +3 -0
- backend/config/__pycache__/config.cpython-312.pyc +0 -0
- backend/config/config.py +13 -0
- backend/fastapi_app.py +36 -0
- backend/reports/model_report_1.html +0 -0
- backend/reports/model_report_2.html +0 -0
- backend/reports/report.html +0 -0
- backend/train_and_evaluate.py +87 -0
- data/customer_churn_dataset-training-master.csv.zip +3 -0
- docker-compose.yml +71 -0
- frontend/EDA.py +82 -0
- frontend/__pycache__/EDA.cpython-312.pyc +0 -0
- frontend/__pycache__/about.cpython-312.pyc +0 -0
- frontend/__pycache__/analysis.cpython-312.pyc +0 -0
- frontend/__pycache__/home.cpython-312.pyc +0 -0
- frontend/__pycache__/make_prediction.cpython-312.pyc +0 -0
- frontend/__pycache__/model.cpython-312.pyc +0 -0
- frontend/__pycache__/project.cpython-312.pyc +0 -0
- frontend/about.py +24 -0
- frontend/analysis.py +239 -0
- frontend/home.py +95 -0
- frontend/main.py +31 -0
- frontend/make_prediction.py +31 -0
- frontend/model.py +85 -0
- frontend/project.py +66 -0
- frontend/reports/model_report_1.html +0 -0
- frontend/reports/model_report_2.html +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
test_data/customer_churn_dataset-training-master.csv filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dockerfile for all services
|
2 |
+
FROM python:3.12-slim
|
3 |
+
|
4 |
+
WORKDIR /app
|
5 |
+
|
6 |
+
COPY requirements.txt .
|
7 |
+
RUN pip install -r requirements.txt
|
8 |
+
|
9 |
+
COPY . .
|
10 |
+
|
11 |
+
# Expose port
|
12 |
+
EXPOSE 8001
|
13 |
+
EXPOSE 8501
|
14 |
+
|
README.md
CHANGED
@@ -1,13 +1,207 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# End to End Customer Churn Prediction
|
2 |
+
|
3 |
+
## :page_facing_up: Problem Statement
|
4 |
+
Customer churn is a critical issue for businesses as it directly impacts profitability and growth. This application aims to predict whether a customer will leave a service or product based on historical data and behavioral patterns.
|
5 |
+
|
6 |
+
## :dart: Objective
|
7 |
+
The main objective of this application is to develop a machine learning model that accurately predicts customer churn. By identifying at-risk customers, businesses can take proactive measures to enhance customer retention and improve overall satisfaction.
|
8 |
+
|
9 |
+
## 🛠️ Technological Stack
|
10 |
+
- **Python** - The primary programming language used for development.
|
11 |
+
- **Machine Learning** - Algorithms to analyze customer data and predict churn.
|
12 |
+
- **MLOps** - Practices for deploying and maintaining machine learning models.
|
13 |
+
- **ZenML** - A tool to create reproducible ML pipelines.
|
14 |
+
- **MLflow** - For tracking experiments and managing model lifecycle.
|
15 |
+
- **Streamlit** - A user-friendly UI framework for creating interactive web applications.
|
16 |
+
- **FastAPI** - Back-end frameworks to build APIs for model prediction server.
|
17 |
+
- **Evidently Ai** - A tool for Model Monitering and data drift detection.
|
18 |
+
|
19 |
+
## 📝 Overview
|
20 |
+
This design document outlines the development of a web application for predicting customer churn using a dataset that includes customer Age, Support Calls , Usage Frequency, Last Interaction, Tenure, Contract Length. The application will allow users to input customer data and receive predictions on churn likelihood.
|
21 |
+
|
22 |
+
|
23 |
+
## :snake: Python Requirements
|
24 |
+
Let's jumpt into python packages you need. you need run below commands in python enivironmemt
|
25 |
+
|
26 |
+
```
|
27 |
+
mkdir customer_churn_prediction
|
28 |
+
|
29 |
+
cd customer_churn_prediction
|
30 |
+
|
31 |
+
git clone https://github.com/sarathkumar1304/End-to-End-Customer-Churn-Prediction
|
32 |
+
|
33 |
+
python3 -m venv venv
|
34 |
+
```
|
35 |
+
|
36 |
+
Activate your virtual environment
|
37 |
+
|
38 |
+
```pip install requirements.txt```
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
Starting with **ZenML** an open source MLOPs Library for machine learning engineer life cycle.ZenML has react base dashboard that allows you to observe your pipeline DAG's, stacks, stack components in dashboard interface.
|
43 |
+
|
44 |
+
To access this, you need to launch the ZenML Server and Dashboard locally, we need run run below commands
|
45 |
+
|
46 |
+
|
47 |
+
```
|
48 |
+
pip install zenml["server"]
|
49 |
+
|
50 |
+
zenml up
|
51 |
+
```
|
52 |
+
|
53 |
+
After running this commands we can visualize the dashboard locally in your browser.
|
54 |
+
|
55 |
+
|
56 |
+
Then , you run the pipeline first and then follow the steps.
|
57 |
+
|
58 |
+
```
|
59 |
+
python3 run_pipeline.py
|
60 |
+
```
|
61 |
+
|
62 |
+
run_pipeline.py has all the combined python scripts to run the pipeline at one place.Now we can visualize the pipeline in your web browser.
|
63 |
+
|
64 |
+
we can the dashboard like below
|
65 |
+

|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
Before the deployment we need run some commands, to intergrate with **MLFlow** for Experiment trackering and model registry.
|
71 |
+
|
72 |
+
```
|
73 |
+
zenml integration install mlflow -y
|
74 |
+
|
75 |
+
zenml experiment-tracker register mlflow_tracker --flavor=mlflow
|
76 |
+
|
77 |
+
zenml model-deployer register mlflow_customer_churn --flavor=mlflow
|
78 |
+
|
79 |
+
zenml stack register mlflow_stack_customer_churn -a default -o default -d mlflow -e mlflow_tracker_customer_churn --set
|
80 |
+
```
|
81 |
+
|
82 |
+

|
83 |
+
|
84 |
+

|
85 |
+
|
86 |
+
After running this command we can track the experiment in mlflow dashboard locally.
|
87 |
+
|
88 |
+
```
|
89 |
+
python3 run_deployment.py
|
90 |
+
```
|
91 |
+
|
92 |
+
|
93 |
+
After running this command we successfully deploy the model in mlflow server that returns this url for further prediction http://127.0.0.1:8000/invocations for further predicition.
|
94 |
+
|
95 |
+
Run the below command to launch the streamlit ui.
|
96 |
+
|
97 |
+
```
|
98 |
+
streamlit run frontend/main.py
|
99 |
+
```
|
100 |
+
|
101 |
+
or we can run whole project using **Docker** by running the below command
|
102 |
+
|
103 |
+
```
|
104 |
+
docker-compose up --build
|
105 |
+
```
|
106 |
+
It will run the whole project at once without any error.
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
## 💪 Motivation
|
111 |
+
Understanding and addressing customer churn can significantly enhance customer loyalty and reduce marketing costs associated with acquiring new customers. This application provides insights that help businesses to implement effective retention strategies.
|
112 |
+
|
113 |
+
## 📈 Success Metrics
|
114 |
+
The project's success will be measured using the following metrics:
|
115 |
+
- Precision, Recall, and F1 Score of the churn prediction model.
|
116 |
+
- Reduction in customer churn rates observed post-implementation.
|
117 |
+
|
118 |
+
## 📑 Requirements & Constraints
|
119 |
+
### Functional Requirements
|
120 |
+
- Users can input customer data to receive churn predictions.
|
121 |
+
- Users can view performance metrics of the machine learning models.
|
122 |
+
- The model should demonstrate high accuracy in predictions.
|
123 |
+
|
124 |
+
### 🚧 Constraints
|
125 |
+
- The application is built using FastAPI as backend and Streamlit as front end , with deployment on Streamlit cloud and Containerizer using Docker
|
126 |
+
|
127 |
+
## ⚙️ Methodology
|
128 |
+
- **Problem Statement**: Develop a model to predict customer churn based on various features.
|
129 |
+
- **Data**: Utilize a dataset containing customer-related features such as demographics and service usage.
|
130 |
+
- **Techniques**: Employ data ingestion, data preprocessing, feature engineering, model selection, training, evaluation and model deployment.
|
131 |
+
- **zenml :** for creating reproducible ML pipeline.
|
132 |
+
- **MLFlow:** for experiment tracking and model registry.
|
133 |
+
- **Docker :** for containerization the whole project.
|
134 |
+
|
135 |
+
## 🏛️ Architecture
|
136 |
+
The architecture of the web application consists of:
|
137 |
+
- A **frontend** built using Streamlit for user interaction.
|
138 |
+
- A **backend** server implemented with FastAPI for handling requests and serving predictions.
|
139 |
+
- A **machine learning model** for churn prediction.
|
140 |
+
- Utilization of **Docker** for containerization.
|
141 |
+
- Hosting on **Streamlit Cloud** with a CI/CD pipeline for automated deployment.
|
142 |
+
|
143 |
+
|
144 |
+
## 🖇️ Pipeline
|
145 |
+
|
146 |
+

|
147 |
+
|
148 |
+
The MLOps (Machine Learning Operations) pipeline project is designed to create an end-to-end workflow for developing and deploying a web application that performs data preprocessing, model training, model evaluation, and prediction. The pipeline leverages Docker containers for encapsulating code, artifacts, and both the frontend and backend components of the application. The application is deployed on a Streamlit to provide a cloud hosting solution.
|
149 |
+
|
150 |
+
The pipeline follows the following sequence of steps:
|
151 |
+
|
152 |
+
**Data Ingestion**: The pipeline starts with the input data, which is sourced from a specified location. It can be in the form of a CSV file.
|
153 |
+
|
154 |
+
**Preprocessing:** The data undergoes preprocessing steps to clean, transform, and prepare it for model training. This stage handles tasks such as missing value imputation, feature scaling, and categorical variable encoding.
|
155 |
+
|
156 |
+
**Model Training:** The preprocessed data is used to train machine learning models. The pipeline supports building multiple models, allowing for experimentation and comparison of different algorithms or hyperparameters.
|
157 |
+
|
158 |
+
**Model Evaluation:** The trained models are evaluated using appropriate evaluation metrics to assess their performance. This stage helps in selecting the best-performing model for deployment.
|
159 |
+
|
160 |
+
**Docker Container:** The pipeline utilizes Docker containers to package the application code, model artifacts, and both the frontend and backend components. This containerization ensures consistent deployment across different environments and simplifies the deployment process.
|
161 |
+
|
162 |
+
**Streamlit:** The Docker container, along with the required dependencies, is deployed on a droplet Streamlit. Streamlit provides a cloud hosting solution that allows for scalability, reliability, and easy management of the web application.
|
163 |
+
|
164 |
+
**Web App:** The web application is accessible via a web browser, providing a user-friendly interface for interacting with the prediction functionality. Users can input new data and obtain predictions from the deployed model.
|
165 |
+
|
166 |
+
**Prediction:** The deployed model uses the input data from the web application to generate predictions. These predictions are then displayed to the user via the web interface.
|
167 |
+
|
168 |
+
**Evidently AI :** A tools for model monitering and data drifting detection when new data comes in.
|
169 |
+
|
170 |
+
**CI/CD Pipeline:** The pipeline is automated using GitHub Actions, which allows for continuous integration and deployment of the application. This automation ensures that the application is always up-to-date and provides a consistent experience for users.
|
171 |
+
|
172 |
+
### 🕹️ Streamlit App
|
173 |
+
|
174 |
+
There is a live demo of this project using Streamlit which you can find [here](https://end-to-end-customer-churn-prediction-9en8mwgk9xqgy8z7envfmm.streamlit.app/)
|
175 |
+
|
176 |
+
## Video
|
177 |
+
|
178 |
+
## Project Video
|
179 |
+
|
180 |
+
Check out the demo of my project on YouTube: [Watch Video](https://youtu.be/Nj-ICiTPJYA?si=EUfhNkFQ3PF47V25) below
|
181 |
+
|
182 |
+
[](https://youtu.be/Nj-ICiTPJYA?si=EUfhNkFQ3PF47V25)
|
183 |
+
|
184 |
+
<if
|
185 |
+
|
186 |
+
|
187 |
+
## Data Report
|
188 |
+
|
189 |
+
### Data Drift Detection using Evidently ai
|
190 |
+
|
191 |
+

|
192 |
+
|
193 |
+

|
194 |
+
|
195 |
+
|
196 |
+

|
197 |
+
|
198 |
+
|
199 |
+
## Classification Perfomance using Evidently AI
|
200 |
+
|
201 |
+

|
202 |
+
|
203 |
+
|
204 |
+

|
205 |
+
|
206 |
+

|
207 |
+
|
assets/Pipeline.png
ADDED
![]() |
assets/classificationperfomance-2.png
ADDED
![]() |
assets/classificationperfomance-3.png
ADDED
![]() |
assets/classificationperformance.png
ADDED
![]() |
assets/customer.webp
ADDED
![]() |
assets/customer_churn_image.jpg
ADDED
![]() |
assets/datadrift-2.png
ADDED
![]() |
assets/datadrift-3.png
ADDED
![]() |
assets/datadrift.png
ADDED
![]() |
assets/datapipeline-2.png
ADDED
![]() |
assets/datapipeline-3.png
ADDED
![]() |
assets/datapipeline.png
ADDED
![]() |
assets/eda.webp
ADDED
![]() |
assets/model_metrics.csv
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model,Accuracy,Precision,Recall,F1 Score
|
2 |
+
Logistic Regression,0.8507372400756144,0.8520515900889458,0.8507372400756144,0.8510765294974902
|
3 |
+
Random Forest,0.9996219281663516,0.999622210747845,0.9996219281663516,0.9996219462806213
|
4 |
+
Decision Tree,0.9996975425330813,0.9996975550375765,0.9996975425330813,0.9996975450588826
|
5 |
+
XGBoost,0.9998714555765595,0.9998714937956894,0.9998714555765595,0.9998714578564856
|
backend/__init__.py
ADDED
File without changes
|
backend/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (121 Bytes). View file
|
|
backend/__pycache__/fastapi_app.cpython-312.pyc
ADDED
Binary file (1.7 kB). View file
|
|
backend/artifacts/Decision Tree.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:307adb46df9e5c3ed7530f66ad9094e55dfcac1bf66da4eeeddc044946525be5
|
3 |
+
size 33385
|
backend/artifacts/Logistic Regression.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9bec9e47636ed96dd8f8ef3d3a53ad908011747336d05e595c5ccff113eb7327
|
3 |
+
size 1391
|
backend/artifacts/Random Forest.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb77c5d8a4992429c7d7e6979cc426532bdb986ff7142e6d7eceafe23dc18406
|
3 |
+
size 22647065
|
backend/artifacts/XGBoost.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:603556769fca526993815db6f662c2760fe9b9d4a2e12dbe88637ca5e4a563ee
|
3 |
+
size 213872
|
backend/config/__pycache__/config.cpython-312.pyc
ADDED
Binary file (683 Bytes). View file
|
|
backend/config/config.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
|
4 |
+
BASE_DIR = Path(__file__).resolve().parent.parent
|
5 |
+
|
6 |
+
|
7 |
+
REPORTS_DIR = BASE_DIR / "reports"
|
8 |
+
ARTIFACTS_DIR = BASE_DIR / "artifacts"
|
9 |
+
DATA_DIR = BASE_DIR / "test_data"
|
10 |
+
|
11 |
+
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
|
12 |
+
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
13 |
+
|
backend/fastapi_app.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
import requests
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
app = FastAPI()
|
7 |
+
|
8 |
+
# Define the structure of the incoming data
|
9 |
+
class InputData(BaseModel):
|
10 |
+
dataframe_records: list[dict] # List of dictionaries (like rows in a DataFrame)
|
11 |
+
|
12 |
+
# Define endpoint to receive data and make a prediction
|
13 |
+
@app.post("/predict")
|
14 |
+
async def make_prediction(input_data: InputData):
|
15 |
+
# URL for MLflow's prediction server
|
16 |
+
mlflow_url = "http://127.0.0.1:8000/invocations"
|
17 |
+
headers = {"Content-Type": "application/json"}
|
18 |
+
|
19 |
+
# Prepare the JSON data to send to MLflow
|
20 |
+
json_data = {
|
21 |
+
"dataframe_records": input_data.dataframe_records
|
22 |
+
}
|
23 |
+
|
24 |
+
try:
|
25 |
+
# Send data to MLflow and get prediction
|
26 |
+
response = requests.post(mlflow_url, headers=headers, json=json_data)
|
27 |
+
response.raise_for_status() # Raise an error for a failed request
|
28 |
+
return response.json() # Return MLflow's prediction result
|
29 |
+
|
30 |
+
except requests.exceptions.HTTPError as err:
|
31 |
+
raise HTTPException(status_code=response.status_code, detail=str(err))
|
32 |
+
except requests.exceptions.RequestException as e:
|
33 |
+
raise HTTPException(status_code=500, detail=str(e))
|
34 |
+
|
35 |
+
|
36 |
+
# uvicorn backend.fastapi_app:app --port 8001
|
backend/reports/model_report_1.html
ADDED
The diff for this file is too large to render.
See raw diff
|
|
backend/reports/model_report_2.html
ADDED
The diff for this file is too large to render.
See raw diff
|
|
backend/reports/report.html
ADDED
The diff for this file is too large to render.
See raw diff
|
|
backend/train_and_evaluate.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from sklearn.model_selection import train_test_split
|
3 |
+
from sklearn.ensemble import RandomForestClassifier
|
4 |
+
from sklearn.linear_model import LogisticRegression
|
5 |
+
from sklearn.tree import DecisionTreeClassifier
|
6 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
7 |
+
from xgboost import XGBClassifier
|
8 |
+
import logging
|
9 |
+
import joblib
|
10 |
+
from config.config import REPORTS_DIR,ARTIFACTS_DIR
|
11 |
+
|
12 |
+
|
13 |
+
# Configure logging
|
14 |
+
logging.basicConfig(
|
15 |
+
filename='/home/sarath_kumar/customer_chrun_prediction/training_log.log',
|
16 |
+
level=logging.INFO,
|
17 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
18 |
+
)
|
19 |
+
|
20 |
+
logging.info("Starting training script...")
|
21 |
+
|
22 |
+
try:
|
23 |
+
|
24 |
+
data = pd.read_csv("/home/sarath_kumar/customer_chrun_prediction/processed_data/processed_data.csv")
|
25 |
+
logging.info("Dataset loaded successfully.")
|
26 |
+
|
27 |
+
|
28 |
+
X = data.drop('Churn', axis=1)
|
29 |
+
y = data['Churn']
|
30 |
+
logging.info("Data split into features and target.")
|
31 |
+
|
32 |
+
|
33 |
+
models = {
|
34 |
+
"Logistic Regression": LogisticRegression(max_iter=500,solver='saga'),
|
35 |
+
"Random Forest": RandomForestClassifier(),
|
36 |
+
"Decision Tree": DecisionTreeClassifier(),
|
37 |
+
"XGBoost": XGBClassifier(),
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
metrics_list = []
|
42 |
+
|
43 |
+
|
44 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
|
45 |
+
logging.info("Data split into training and testing sets.")
|
46 |
+
|
47 |
+
for model_name, model in models.items():
|
48 |
+
logging.info(f"Training {model_name}...")
|
49 |
+
model.fit(X_train, y_train)
|
50 |
+
logging.info(f"{model_name} training completed.")
|
51 |
+
|
52 |
+
y_pred = model.predict(X_test)
|
53 |
+
logging.info(f"{model_name} prediction completed.")
|
54 |
+
|
55 |
+
|
56 |
+
accuracy = accuracy_score(y_test, y_pred)
|
57 |
+
precision = precision_score(y_test, y_pred, average='weighted')
|
58 |
+
recall = recall_score(y_test, y_pred, average='weighted')
|
59 |
+
f1 = f1_score(y_test, y_pred, average='weighted')
|
60 |
+
|
61 |
+
logging.info(f"{model_name} evaluation metrics calculated.")
|
62 |
+
|
63 |
+
|
64 |
+
metrics_list.append({
|
65 |
+
"Model": model_name,
|
66 |
+
"Accuracy": accuracy,
|
67 |
+
"Precision": precision,
|
68 |
+
"Recall": recall,
|
69 |
+
"F1 Score": f1
|
70 |
+
})
|
71 |
+
|
72 |
+
metrics_df = pd.DataFrame(metrics_list)
|
73 |
+
logging.info("Metrics DataFrame created.")
|
74 |
+
metrics_df.to_csv(REPORTS_DIR / "model_metrics.csv", index=False)
|
75 |
+
logging.info("Metrics saved to CSV successfully.")
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
for model_name, model in models.items():
|
80 |
+
joblib.dump(model, ARTIFACTS_DIR/ f"{model_name}.pkl")
|
81 |
+
logging.info(f"{model_name} saved to file.")
|
82 |
+
|
83 |
+
logging.info("Training script completed successfully.")
|
84 |
+
|
85 |
+
except Exception as e:
|
86 |
+
logging.error(f"An error occurred: {e}")
|
87 |
+
raise
|
data/customer_churn_dataset-training-master.csv.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:60fde199f3fb3965060175d1711335cbc30cc8cce2c09394e031573997bc7ee1
|
3 |
+
size 6115326
|
docker-compose.yml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: "3.8"
|
2 |
+
|
3 |
+
services:
|
4 |
+
# Step 1: Run Deployment
|
5 |
+
deployment:
|
6 |
+
build:
|
7 |
+
context: .
|
8 |
+
dockerfile: Dockerfile
|
9 |
+
command: >
|
10 |
+
bash -c "
|
11 |
+
zenml init &&
|
12 |
+
zenml integration install mlflow -y &&
|
13 |
+
zenml experiment-tracker register mlflow_tracker_customer_churn_new --flavor=mlflow &&
|
14 |
+
zenml model-deployer register mlflow_customer_churn_new --flavor=mlflow &&
|
15 |
+
zenml stack register mlflow_stack_customer_churn_new -a default -o default -d mlflow -e mlflow_tracker_customer_churn_new --set &&
|
16 |
+
zenml stack set mlflow_stack_customer_churn_new &&
|
17 |
+
python3 run_pipeline.py&&
|
18 |
+
python3 run_deployment.py
|
19 |
+
"
|
20 |
+
volumes:
|
21 |
+
- .:/app
|
22 |
+
working_dir: /app
|
23 |
+
restart: on-failure
|
24 |
+
healthcheck:
|
25 |
+
test: ["CMD", "curl", "-f", "http://localhost:8000/health"] # Adjust URL for deployment health check
|
26 |
+
interval: 10s
|
27 |
+
retries: 3
|
28 |
+
start_period: 5s
|
29 |
+
timeout: 5s
|
30 |
+
|
31 |
+
# Step 2: Run FastAPI service after Deployment is completed
|
32 |
+
fastapi_service:
|
33 |
+
build:
|
34 |
+
context: .
|
35 |
+
dockerfile: Dockerfile
|
36 |
+
command: ["uvicorn", "backend.fastapi_app:app", "--host", "0.0.0.0", "--port", "8001"]
|
37 |
+
depends_on:
|
38 |
+
- deployment
|
39 |
+
volumes:
|
40 |
+
- .:/app
|
41 |
+
working_dir: /app
|
42 |
+
ports:
|
43 |
+
- "8001:8001"
|
44 |
+
restart: on-failure
|
45 |
+
healthcheck:
|
46 |
+
test: ["CMD", "curl", "-f", "http://localhost:8001/health"] # Adjust URL for FastAPI health check
|
47 |
+
interval: 10s
|
48 |
+
retries: 3
|
49 |
+
start_period: 5s
|
50 |
+
timeout: 5s
|
51 |
+
|
52 |
+
# Step 3: Run Streamlit UI after FastAPI service is up
|
53 |
+
streamlit:
|
54 |
+
build:
|
55 |
+
context: .
|
56 |
+
dockerfile: Dockerfile
|
57 |
+
command: ["streamlit", "run", "frontend/main.py"]
|
58 |
+
depends_on:
|
59 |
+
- fastapi_service
|
60 |
+
volumes:
|
61 |
+
- .:/app
|
62 |
+
working_dir: /app
|
63 |
+
ports:
|
64 |
+
- "8501:8501"
|
65 |
+
restart: on-failure
|
66 |
+
healthcheck:
|
67 |
+
test: ["CMD", "curl", "-f", "http://localhost:8501/health"] # Adjust URL for Streamlit health check
|
68 |
+
interval: 10s
|
69 |
+
retries: 3
|
70 |
+
start_period: 5s
|
71 |
+
timeout: 5s
|
frontend/EDA.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from analysis import univariate_analysis, BivariateAnalysis, multivariate_analysis
|
4 |
+
|
5 |
+
def eda():
|
6 |
+
st.image("/home/sarath_kumar/customer_churn_predict/assets/eda.webp",width=300)
|
7 |
+
st.title("Exploratory Data Analysis")
|
8 |
+
|
9 |
+
|
10 |
+
data = pd.read_csv("extracted/customer_churn_dataset-training-master.csv")
|
11 |
+
data.drop("CustomerID",axis = 1,inplace = True)
|
12 |
+
|
13 |
+
data.dropna(axis=0,inplace = True,how = "all")
|
14 |
+
st.header("Dataset Overview")
|
15 |
+
st.dataframe(data.head())
|
16 |
+
|
17 |
+
|
18 |
+
st.subheader("Select Analysis Type")
|
19 |
+
analysis_type = st.selectbox(
|
20 |
+
"Select Analysis Type",
|
21 |
+
["Univariate Analysis", "Bivariate Analysis", "Multivariate Analysis"]
|
22 |
+
)
|
23 |
+
|
24 |
+
if analysis_type == "Univariate Analysis":
|
25 |
+
st.subheader("Univariate Analysis")
|
26 |
+
column = st.selectbox("Select a column for univariate analysis", data.columns, key="uni")
|
27 |
+
plot_type = st.selectbox("Select plot type", ["Histogram", "Boxplot", "Pie Chart", "Bar Plot"], key="uni_plot")
|
28 |
+
|
29 |
+
|
30 |
+
if st.button("Generate Univariate Plot", key="uni_button"):
|
31 |
+
if column:
|
32 |
+
univariate_analysis(data, column, plot_type)
|
33 |
+
else:
|
34 |
+
st.warning("Please select a column for analysis.")
|
35 |
+
|
36 |
+
elif analysis_type == "Bivariate Analysis":
|
37 |
+
st.subheader("Bivariate Analysis")
|
38 |
+
column_x = st.selectbox("Select X-axis column", data.columns, key="bi_x")
|
39 |
+
column_y = st.selectbox("Select Y-axis column", data.columns, key="bi_y")
|
40 |
+
plot_type = st.selectbox("Select plot type", ["Scatter Plot", "Bar Plot", "Boxplot"], key="bi_plot")
|
41 |
+
|
42 |
+
if st.button("Generate Bivariate Plot", key="bi_button"):
|
43 |
+
|
44 |
+
if pd.api.types.is_numeric_dtype(data[column_x]) and pd.api.types.is_numeric_dtype(data[column_y]):
|
45 |
+
analysis = BivariateAnalysis()
|
46 |
+
analysis.numerical_vs_numerical(data, column_x, column_y, plot_type)
|
47 |
+
elif pd.api.types.is_categorical_dtype(data[column_x]) and pd.api.types.is_categorical_dtype(data[column_y]):
|
48 |
+
analysis = BivariateAnalysis()
|
49 |
+
analysis.numerical_vs_categorical(data, column_x, column_y, plot_type)
|
50 |
+
elif pd.api.types.is_numeric_dtype(data[column_x]) and pd.api.types.is_categorical_dtype(data[column_y]):
|
51 |
+
analysis = BivariateAnalysis()
|
52 |
+
analysis.numerical_vs_categorical(data, column_x, column_y, plot_type)
|
53 |
+
elif pd.api.types.is_categorical_dtype(data[column_x]) and pd.api.types.is_numeric_dtype(data[column_y]):
|
54 |
+
analysis = BivariateAnalysis()
|
55 |
+
analysis.numerical_vs_categorical(data, column_x, column_y, plot_type)
|
56 |
+
elif pd.api.types.is_numeric_dtype(data[column_x]) and pd.api.types.is_numeric_dtype(data[column_y]):
|
57 |
+
analysis = BivariateAnalysis()
|
58 |
+
analysis.numerical_vs_numerical(data, column_x, column_y, plot_type)
|
59 |
+
else:
|
60 |
+
st.warning("Please select numerical columns for analysis. Only numerical data types are supported for this plot.")
|
61 |
+
|
62 |
+
elif analysis_type == "Multivariate Analysis":
|
63 |
+
data = pd.read_csv("/home/sarath_kumar/customer_chrun_prediction/processed_data/processed_data.csv")
|
64 |
+
st.subheader("Multivariate Analysis")
|
65 |
+
columns = st.multiselect("Select columns for multivariate analysis", data.columns)
|
66 |
+
|
67 |
+
# Add an option for users to select the type of plot
|
68 |
+
plot_type = st.selectbox(
|
69 |
+
"Select plot type for multivariate analysis",
|
70 |
+
["Correlation Heatmap", "Scatter Matrix"]
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
if st.button("Generate Multivariate Plot", key="multi_button"):
|
75 |
+
if columns:
|
76 |
+
|
77 |
+
multivariate_analysis(data, columns, plot_type)
|
78 |
+
else:
|
79 |
+
st.warning("Please select columns for multivariate analysis.")
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
eda()
|
frontend/__pycache__/EDA.cpython-312.pyc
ADDED
Binary file (5.39 kB). View file
|
|
frontend/__pycache__/about.cpython-312.pyc
ADDED
Binary file (1.76 kB). View file
|
|
frontend/__pycache__/analysis.cpython-312.pyc
ADDED
Binary file (11.2 kB). View file
|
|
frontend/__pycache__/home.cpython-312.pyc
ADDED
Binary file (5.74 kB). View file
|
|
frontend/__pycache__/make_prediction.cpython-312.pyc
ADDED
Binary file (1.33 kB). View file
|
|
frontend/__pycache__/model.cpython-312.pyc
ADDED
Binary file (14.1 kB). View file
|
|
frontend/__pycache__/project.cpython-312.pyc
ADDED
Binary file (3.31 kB). View file
|
|
frontend/about.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
def about_me():
|
5 |
+
st.title("About Me 🙋♂️ ")
|
6 |
+
st.write("""
|
7 |
+
Hello! 👋 I'm **R. Sarath Kumar**, and I'm thrilled to have you here! I’m a dedicated and passionate professional in **Data Science** and **Machine Learning** 🤖.
|
8 |
+
With a strong foundation in statistics, machine learning, and MLOps, I love transforming data into valuable insights and building predictive models that solve real-world problems. My work spans across multiple domains, and I’m always excited to explore new tools and techniques to make data-driven decisions more effective and impactful.
|
9 |
+
|
10 |
+
Feel free to browse through my projects, where you’ll find some of the most exciting applications of AI and machine learning, and don't hesitate to reach out if you’d like to connect or discuss potential collaborations!
|
11 |
+
""")
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
st.subheader("Contact Information")
|
16 |
+
|
17 |
+
|
18 |
+
st.write("🔗LinkedIn: [LinkedIn](https://www.linkedin.com/in/r-sarath-kumar-666084257)")
|
19 |
+
st.write("🔗Github:[Github](https://www.github.com/sarathkumar1304)")
|
20 |
+
|
21 |
+
|
22 |
+
st.write("📧 Email: [[email protected]](mailto:[email protected])")
|
23 |
+
|
24 |
+
st.write("📞 Phone: 7780651312")
|
frontend/analysis.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.express as px
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import itertools
|
6 |
+
from scipy.stats import pearsonr, pointbiserialr
|
7 |
+
from sklearn.ensemble import RandomForestClassifier
|
8 |
+
import seaborn as sns
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
# def univariate_analysis(data, column, plot_type):
|
12 |
+
# if plot_type == "Histogram":
|
13 |
+
# if data[column].dtype=="int64" or data[column].dtype=="float64":
|
14 |
+
# fig = px.histogram(data, x=column, title=f'Histogram of {column}')
|
15 |
+
# st.plotly_chart(fig)
|
16 |
+
# else:
|
17 |
+
# st.warning("Histograms are only suitable for numerical columns.")
|
18 |
+
|
19 |
+
# elif plot_type == "Boxplot":
|
20 |
+
# if data[column].dtype=="int64" or data[column].dtype=="float64":
|
21 |
+
# fig = px.box(data, y=column, title=f'Boxplot of {column}')
|
22 |
+
# st.plotly_chart(fig)
|
23 |
+
# else:
|
24 |
+
# st.warning("Boxplots are only suitable for numerical columns.")
|
25 |
+
# elif plot_type == "Pie Chart":
|
26 |
+
# if data[column].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column]):
|
27 |
+
# fig = px.pie(data, names=column, title=f'Pie Chart of {column}')
|
28 |
+
# st.plotly_chart(fig)
|
29 |
+
# else:
|
30 |
+
# st.warning("Pie charts are only suitable for categorical columns.")
|
31 |
+
# elif plot_type == "Bar Plot":
|
32 |
+
# if data[column].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column]):
|
33 |
+
# fig = px.bar(data[column].value_counts().reset_index(), x='index', y=column, title=f'Bar Plot of {column}')
|
34 |
+
# st.plotly_chart(fig)
|
35 |
+
# else:
|
36 |
+
# st.warning("Bar plots are only suitable for categorical columns.")
|
37 |
+
|
38 |
+
import pandas as pd
|
39 |
+
import plotly.express as px
|
40 |
+
import streamlit as st
|
41 |
+
|
42 |
+
def univariate_analysis(data, column, plot_type):
|
43 |
+
if plot_type == "Histogram":
|
44 |
+
if data[column].dtype == "int64" or data[column].dtype == "float64":
|
45 |
+
fig = px.histogram(data, x=column, title=f'Histogram of {column}')
|
46 |
+
st.plotly_chart(fig)
|
47 |
+
else:
|
48 |
+
st.warning("Histograms are only suitable for numerical columns.")
|
49 |
+
|
50 |
+
elif plot_type == "Boxplot":
|
51 |
+
if data[column].dtype == "int64" or data[column].dtype == "float64":
|
52 |
+
fig = px.box(data, y=column, title=f'Boxplot of {column}')
|
53 |
+
st.plotly_chart(fig)
|
54 |
+
else:
|
55 |
+
st.warning("Boxplots are only suitable for numerical columns.")
|
56 |
+
|
57 |
+
elif plot_type == "Pie Chart":
|
58 |
+
if data[column].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column]):
|
59 |
+
fig = px.pie(data, names=column, title=f'Pie Chart of {column}')
|
60 |
+
st.plotly_chart(fig)
|
61 |
+
else:
|
62 |
+
st.warning("Pie charts are only suitable for categorical columns.")
|
63 |
+
|
64 |
+
elif plot_type == "Bar Plot":
|
65 |
+
if data[column].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column]):
|
66 |
+
# Get value counts and reset index, then rename columns for Plotly
|
67 |
+
data_count = data[column].value_counts().reset_index()
|
68 |
+
data_count.columns = ['index', column] # Renaming columns
|
69 |
+
|
70 |
+
fig = px.bar(data_count, x='index', y=column, title=f'Bar Plot of {column}')
|
71 |
+
st.plotly_chart(fig)
|
72 |
+
else:
|
73 |
+
st.warning("Bar plots are only suitable for categorical columns.")
|
74 |
+
|
75 |
+
|
76 |
+
# def multivariate_analysis(data, columns):
|
77 |
+
# fig = px.scatter_matrix(data, dimensions=columns, title=f'Multivariate Analysis')
|
78 |
+
# st.plotly_chart(fig)
|
79 |
+
|
80 |
+
|
81 |
+
def multivariate_analysis(data, columns, plot_type):
|
82 |
+
if plot_type == "Correlation Heatmap":
|
83 |
+
st.subheader("Correlation Heatmap")
|
84 |
+
if len(columns) > 1:
|
85 |
+
# Compute the correlation matrix
|
86 |
+
correlation_matrix = data[columns].corr()
|
87 |
+
|
88 |
+
# Create a heatmap using Seaborn and Matplotlib
|
89 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
90 |
+
sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", vmin=-1, vmax=1, ax=ax)
|
91 |
+
st.pyplot(fig)
|
92 |
+
else:
|
93 |
+
st.warning("Please select at least two columns for a correlation heatmap.")
|
94 |
+
|
95 |
+
elif plot_type == "Scatter Matrix":
|
96 |
+
st.subheader("Scatter Matrix Plot")
|
97 |
+
if len(columns) > 1:
|
98 |
+
fig = px.scatter_matrix(data, dimensions=columns, title='Scatter Matrix Plot')
|
99 |
+
st.plotly_chart(fig)
|
100 |
+
else:
|
101 |
+
st.warning("Please select at least two columns for a scatter plot matrix.")
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
class BivariateAnalysis:
|
106 |
+
def numerical_vs_numerical(self, data, column_x, column_y, plot_type):
|
107 |
+
plt.figure(figsize=(10, 6))
|
108 |
+
if plot_type == "Scatter Plot":
|
109 |
+
if data[column_x].dtype == 'int64' or data[column_x].dtype == 'float64' and data[column_y].dtype == 'int64' or data[column_y].dtype == 'float64':
|
110 |
+
sns.scatterplot(data=data, x=column_x, y=column_y)
|
111 |
+
plt.title(f'Scatter Plot of {column_x} vs {column_y}')
|
112 |
+
else:
|
113 |
+
st.warning("Scatter plots are only suitable for numerical columns.")
|
114 |
+
|
115 |
+
elif plot_type == "Bar Plot":
|
116 |
+
if data[column_x].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column_x]) and data[column_y].dtype == 'object' or pd.api.types.is_categorical_dtype(data[column_y]):
|
117 |
+
sns.barplot(data=data, x=column_x, y=column_y)
|
118 |
+
plt.title(f'Bar Plot of {column_x} vs {column_y}')
|
119 |
+
else:
|
120 |
+
st.warning("Bar plots are only suitable for categorical columns.")
|
121 |
+
elif plot_type == "Boxplot":
|
122 |
+
if data[column_x].dtype == 'int64' or data[column_x].dtype == 'float64' and data[column_y].dtype == 'int64' or data[column_y].dtype == 'float64':
|
123 |
+
sns.boxplot(data=data, x=column_x, y=column_y)
|
124 |
+
plt.title(f'Boxplot of {column_x} vs {column_y}')
|
125 |
+
else:
|
126 |
+
st.warning("Boxplots are only suitable for numerical columns.")
|
127 |
+
st.pyplot(plt.gcf())
|
128 |
+
plt.clf()
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
def numerical_vs_categorical(df, categorical_feature='Churn'):
|
133 |
+
numerical_features = df.select_dtypes(include=[float, int]).columns
|
134 |
+
if df[categorical_feature].nunique() != 2:
|
135 |
+
print(f"The categorical feature '{categorical_feature}' is not binary. Skipping correlation calculation.")
|
136 |
+
for feature in numerical_features:
|
137 |
+
fig = px.box(
|
138 |
+
df, x=categorical_feature, y=feature, color=categorical_feature,
|
139 |
+
title=f"Box Plot of {feature} by {categorical_feature}",
|
140 |
+
labels={categorical_feature: categorical_feature, feature: feature}
|
141 |
+
)
|
142 |
+
fig.update_layout(
|
143 |
+
xaxis_title=categorical_feature,
|
144 |
+
yaxis_title=feature,
|
145 |
+
hovermode="x unified"
|
146 |
+
)
|
147 |
+
fig.show()
|
148 |
+
return
|
149 |
+
|
150 |
+
df[categorical_feature] = pd.factorize(df[categorical_feature])[0]
|
151 |
+
for feature in numerical_features:
|
152 |
+
valid_data = df[[feature, categorical_feature]].dropna()
|
153 |
+
valid_data[feature] = pd.to_numeric(valid_data[feature], errors='coerce').dropna()
|
154 |
+
correlation, _ = pointbiserialr(valid_data[feature], valid_data[categorical_feature])
|
155 |
+
title = f"Box Plot of {feature} by {categorical_feature} (Correlation: {correlation:.2f})"
|
156 |
+
fig = px.box(
|
157 |
+
valid_data, x=categorical_feature, y=feature, color=categorical_feature,
|
158 |
+
title=title,
|
159 |
+
labels={categorical_feature: categorical_feature, feature: feature}
|
160 |
+
)
|
161 |
+
fig.update_layout(
|
162 |
+
xaxis_title=categorical_feature,
|
163 |
+
yaxis_title=feature,
|
164 |
+
hovermode="x unified"
|
165 |
+
)
|
166 |
+
fig.show()
|
167 |
+
|
168 |
+
|
169 |
+
def numerical_vs_target(df, target='Churn'):
|
170 |
+
numerical_features = df.select_dtypes(include=[float, int]).columns
|
171 |
+
for feature in numerical_features:
|
172 |
+
fig = px.box(
|
173 |
+
df,
|
174 |
+
x=target,
|
175 |
+
y=feature,
|
176 |
+
color=target,
|
177 |
+
title=f"Distribution of {feature} by {target} Status",
|
178 |
+
labels={target: f"{target} Status", feature: feature}
|
179 |
+
)
|
180 |
+
fig.update_layout(
|
181 |
+
xaxis_title=f"{target} Status",
|
182 |
+
yaxis_title=feature,
|
183 |
+
legend_title=target,
|
184 |
+
hovermode="x unified"
|
185 |
+
)
|
186 |
+
fig.show()
|
187 |
+
|
188 |
+
|
189 |
+
def categorical_vs_target(df, target='Churn'):
|
190 |
+
categorical_features = df.select_dtypes(include=[object]).columns
|
191 |
+
for feature in categorical_features:
|
192 |
+
crosstab_data = pd.crosstab(df[feature], df[target])
|
193 |
+
crosstab_df = crosstab_data.reset_index().melt(id_vars=feature, value_name="Count")
|
194 |
+
fig = px.bar(
|
195 |
+
crosstab_df,
|
196 |
+
x=feature,
|
197 |
+
y="Count",
|
198 |
+
color=target,
|
199 |
+
title=f"{target} by {feature}",
|
200 |
+
labels={feature: feature, "Count": "Count", target: f"{target} Status"},
|
201 |
+
text="Count",
|
202 |
+
barmode="group"
|
203 |
+
)
|
204 |
+
fig.update_layout(
|
205 |
+
xaxis_title=feature,
|
206 |
+
yaxis_title="Count",
|
207 |
+
legend_title=target,
|
208 |
+
hovermode="x unified"
|
209 |
+
)
|
210 |
+
fig.show()
|
211 |
+
|
212 |
+
def feature_importance(df, target_column):
|
213 |
+
X = df.drop(columns=[target_column])
|
214 |
+
y = df[target_column]
|
215 |
+
|
216 |
+
model = RandomForestClassifier(random_state=0)
|
217 |
+
model.fit(X.select_dtypes(include=[np.number]), y)
|
218 |
+
|
219 |
+
importance_df = pd.DataFrame({
|
220 |
+
"Feature": X.select_dtypes(include=[np.number]).columns,
|
221 |
+
"Importance": model.feature_importances_
|
222 |
+
}).sort_values(by="Importance", ascending=True)
|
223 |
+
|
224 |
+
fig_importance = px.bar(
|
225 |
+
importance_df,
|
226 |
+
x="Importance",
|
227 |
+
y="Feature",
|
228 |
+
title="Feature Importance",
|
229 |
+
orientation="h",
|
230 |
+
color="Importance",
|
231 |
+
color_continuous_scale="Viridis",
|
232 |
+
)
|
233 |
+
fig_importance.update_layout(
|
234 |
+
title_font=dict(size=20),
|
235 |
+
xaxis_title="Importance Score",
|
236 |
+
yaxis_title="Features",
|
237 |
+
font=dict(size=12),
|
238 |
+
)
|
239 |
+
fig_importance.show()
|
frontend/home.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def home_page():
|
4 |
+
st.image("assets/customer.webp")
|
5 |
+
|
6 |
+
st.title("Customer Churn Prediction Application 🌐 ")
|
7 |
+
|
8 |
+
st.header("📝 Problem Statement")
|
9 |
+
st.write("""
|
10 |
+
Customer churn is a critical issue for businesses as it directly impacts profitability and growth. This application aims to predict whether a customer will leave a service or product based on historical data and behavioral patterns.
|
11 |
+
""")
|
12 |
+
|
13 |
+
|
14 |
+
st.header(" 🎯 Objective")
|
15 |
+
st.write("""
|
16 |
+
The main objective of this application is to develop a machine learning model that accurately predicts customer churn. By identifying at-risk customers, businesses can take proactive measures to enhance customer retention and improve overall satisfaction.
|
17 |
+
""")
|
18 |
+
|
19 |
+
st.header("🛠 Technological Stack")
|
20 |
+
st.write("""
|
21 |
+
The application is built using the following technologies:
|
22 |
+
- **Python**: The primary programming language used for development.
|
23 |
+
- **Machine Learning**: Algorithms to analyze customer data and predict churn.
|
24 |
+
- **MLOps**: Practices for deploying and maintaining machine learning models.
|
25 |
+
- **ZenML**: A tool to create reproducible ML pipelines.
|
26 |
+
- **MLflow**: For tracking experiments and managing model lifecycle.
|
27 |
+
- **Streamlit**: A user-friendly UI framework for creating interactive web applications.
|
28 |
+
- **FastAPI**: Back-end frameworks to build APIs for model interactions.
|
29 |
+
- **Evidently Ai** : A tool for model monitering and data drift detection
|
30 |
+
""")
|
31 |
+
|
32 |
+
|
33 |
+
st.header(" 📝 Overview")
|
34 |
+
st.write("""
|
35 |
+
This design document outlines the development of a web application for predicting customer churn using a dataset that includes customer Usage Frequency, Tenure and historical behaviors. The application will allow users to input customer data manually and receive predictions on churn likelihood and suggested retention strategies.
|
36 |
+
""")
|
37 |
+
|
38 |
+
|
39 |
+
st.header(" 💪 Motivation")
|
40 |
+
st.write("""
|
41 |
+
Understanding and addressing customer churn can significantly enhance customer loyalty and reduce marketing costs associated with acquiring new customers. This application provides insights that help businesses to implement effective retention strategies.
|
42 |
+
""")
|
43 |
+
|
44 |
+
|
45 |
+
st.header(" 📈 Success Metrics")
|
46 |
+
st.write("""
|
47 |
+
The project's success will be measured using the following metrics:
|
48 |
+
- Precision, Recall, and F1 Score of the churn prediction model.
|
49 |
+
- User engagement and satisfaction with the application interface.
|
50 |
+
- Reduction in customer churn rates observed post-implementation.
|
51 |
+
""")
|
52 |
+
|
53 |
+
|
54 |
+
st.header(" ✍ Requirements & Constraints")
|
55 |
+
st.subheader(" ⚙️ Functional Requirements")
|
56 |
+
st.write("""
|
57 |
+
- Users can input customer data to receive churn predictions.
|
58 |
+
- Users can view performance metrics of the machine learning models.
|
59 |
+
- Users can visualize customer behavior data to derive insights.
|
60 |
+
""")
|
61 |
+
|
62 |
+
st.subheader(" 👨🏻💻 Non-Functional Requirements")
|
63 |
+
st.write("""
|
64 |
+
- The model should demonstrate high accuracy in predictions.
|
65 |
+
- The application should be responsive and user-friendly.
|
66 |
+
- User data must be handled securely.
|
67 |
+
""")
|
68 |
+
|
69 |
+
st.subheader("✅️ Constraints")
|
70 |
+
st.write("""
|
71 |
+
- The application must be built using FastAPI and Streamlit, with deployment on Docker and streamlit.
|
72 |
+
""")
|
73 |
+
|
74 |
+
st.header(" 🔄 Methodology")
|
75 |
+
st.write("""
|
76 |
+
- **Problem Statement**: Develop a model to predict customer churn based on various features.
|
77 |
+
- **Data**: Utilize a dataset containing customer-related features such as demographics and service usage.
|
78 |
+
- **Techniques**: Employ data preprocessing, feature engineering, model selection, training, and evaluation.
|
79 |
+
""")
|
80 |
+
|
81 |
+
|
82 |
+
st.header(" 🏛️ Architecture")
|
83 |
+
st.write("""
|
84 |
+
The architecture of the web application consists of:
|
85 |
+
- A **frontend** built using Streamlit for user interaction.
|
86 |
+
- A **backend** server implemented with FastAPI for handling requests and serving predictions.
|
87 |
+
- A **machine learning model** for churn prediction.
|
88 |
+
- Utilization of **Docker** for containerization.
|
89 |
+
|
90 |
+
""")
|
91 |
+
st.header(" 🖇️ Pipeline")
|
92 |
+
st.image("assets/Pipeline.png", caption="Pipeline", use_column_width=True)
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
home_page()
|
frontend/main.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from streamlit_option_menu import option_menu
|
4 |
+
from about import about_me
|
5 |
+
from project import project_ui
|
6 |
+
from home import home_page
|
7 |
+
from EDA import eda
|
8 |
+
from model import metrics_ui
|
9 |
+
|
10 |
+
# Sidebar for navigation
|
11 |
+
with st.sidebar:
|
12 |
+
selected = option_menu(
|
13 |
+
menu_title="Main Menu",
|
14 |
+
options=["Home", "Project", "EDA", "Model","About Me"],
|
15 |
+
icons=["house", "app-indicator", "bar-chart","person-video" ,"person-video3"],
|
16 |
+
menu_icon="cast",
|
17 |
+
default_index=1,
|
18 |
+
)
|
19 |
+
if selected == "Project":
|
20 |
+
project_ui()
|
21 |
+
if selected == "Home":
|
22 |
+
home_page()
|
23 |
+
|
24 |
+
if selected == "EDA":
|
25 |
+
eda()
|
26 |
+
|
27 |
+
if selected == "Model":
|
28 |
+
metrics_ui()
|
29 |
+
|
30 |
+
if selected == "About Me":
|
31 |
+
about_me()
|
frontend/make_prediction.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
|
5 |
+
# Helper function to send data to FastAPI for prediction
|
6 |
+
def get_prediction(input_data):
|
7 |
+
"""
|
8 |
+
Sends the input data to the FastAPI backend to get a prediction.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
input_data (pd.DataFrame): Input data to send to the FastAPI backend
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
dict or None: The JSON response from the FastAPI backend, or None if the request failed
|
15 |
+
"""
|
16 |
+
|
17 |
+
|
18 |
+
url = "http://127.0.0.1:8001/predict" # URL of the FastAPI backend
|
19 |
+
headers = {"Content-Type": "application/json"}
|
20 |
+
|
21 |
+
json_data = {
|
22 |
+
"dataframe_records": input_data.to_dict(orient="records")
|
23 |
+
}
|
24 |
+
|
25 |
+
try:
|
26 |
+
response = requests.post(url, headers=headers, json=json_data)
|
27 |
+
response.raise_for_status()
|
28 |
+
return response.json() # Return JSON response from FastAPI
|
29 |
+
except requests.exceptions.RequestException as e:
|
30 |
+
st.error(f"Request failed: {e}")
|
31 |
+
return None
|
frontend/model.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import streamlit.components.v1 as components
|
4 |
+
|
5 |
+
|
6 |
+
metrics_path = "assets/model_metrics.csv"
|
7 |
+
metrics_df = pd.read_csv(metrics_path)
|
8 |
+
|
9 |
+
def metrics_ui():
|
10 |
+
st.image("",width=500)
|
11 |
+
|
12 |
+
st.title("Model Evaluation Metrics")
|
13 |
+
st.subheader("Performance Metrics of Trained Models")
|
14 |
+
|
15 |
+
|
16 |
+
st.dataframe(metrics_df)
|
17 |
+
|
18 |
+
best_model_name = metrics_df.loc[metrics_df['F1 Score'].idxmax(), 'Model']
|
19 |
+
st.write(f"Best Model: {best_model_name} with an F1 Score of {metrics_df['F1 Score'].max():.2f}")
|
20 |
+
|
21 |
+
st.header("What is data drift ?")
|
22 |
+
st.write(
|
23 |
+
"""
|
24 |
+
Data Drift is a name for data change that can affect the Machine Learning model performance. There are different types
|
25 |
+
of drift.There can be change in target distribution.
|
26 |
+
|
27 |
+
For example you have a model that predicts house price based on the property description (no.of rooms,location,etc.,)
|
28 |
+
but there is a change in the market and all properities prices go up.
|
29 |
+
|
30 |
+
If you don't detect the drift in the target distribution and don't update the model with new targets then your model will predict too low pricies.
|
31 |
+
|
32 |
+
The chnage in the target distribution is so-called **target distribution.**
|
33 |
+
|
34 |
+
The next type of drift is **covariate drift**.it is a change in the input data distribution. For example, you have
|
35 |
+
a categorical feature that will start to have new category.
|
36 |
+
"""
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
st.header("what to do after data drift detection ?")
|
41 |
+
st.write(
|
42 |
+
"""
|
43 |
+
When data drift is detected that ML model should be updated. There are many ways in which it can be done - all depends on the
|
44 |
+
data. The most striaght forward way is to use all avaiable data samples to train a new model. The other approach might be to use just
|
45 |
+
a new data samples to train the model.
|
46 |
+
|
47 |
+
There might be also approaches with sample weighting - giving higher weight for fresh data samples and lower weights for old samples.
|
48 |
+
it depends on the data.
|
49 |
+
"""
|
50 |
+
)
|
51 |
+
|
52 |
+
st.header("How to detect data drift ?")
|
53 |
+
st.write("""
|
54 |
+
The data drift can be detected in different ways. The simplest approach is to use statistical tests that compare
|
55 |
+
the distribution of the trainig data and live data (production data). If the differnec between two distribution is significantly then a drift occured.
|
56 |
+
|
57 |
+
The most popular test are **two-sample, kolmogorov-Smirnov test,Chi square test, jensen-shannon divergence, Wasserstein distance.**
|
58 |
+
The alternative approach might be use Machine Learning model to monitar the data quality. There can be also hybrid approaches.
|
59 |
+
""")
|
60 |
+
|
61 |
+
|
62 |
+
st.subheader("Dataset Drift")
|
63 |
+
html = "frontend/reports/report.html"
|
64 |
+
with open(html,'r') as f:
|
65 |
+
html_data= f.read()
|
66 |
+
|
67 |
+
st.components.v1.html(html_data,scrolling = True,height=700,width= 800)
|
68 |
+
|
69 |
+
|
70 |
+
st.subheader("Decison Tree Model Report")
|
71 |
+
html = "frontend/reports/model_report_1.html"
|
72 |
+
with open(html,'r') as f:
|
73 |
+
html_data= f.read()
|
74 |
+
|
75 |
+
st.components.v1.html(html_data,scrolling = True,height=700,width= 800)
|
76 |
+
|
77 |
+
st.subheader("RandomForest Model Drift")
|
78 |
+
html = "frontend/reports/model_report_2.html"
|
79 |
+
with open(html,'r') as f:
|
80 |
+
html_data= f.read()
|
81 |
+
|
82 |
+
st.components.v1.html(html_data,scrolling = True,height=700,width= 800)
|
83 |
+
|
84 |
+
|
85 |
+
|
frontend/project.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from make_prediction import get_prediction
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
|
7 |
+
def project_ui():
|
8 |
+
st.image("assets/customer_churn_image.jpg",width=600)
|
9 |
+
st.title("Customer Churn Prediction")
|
10 |
+
|
11 |
+
|
12 |
+
age = st.number_input("**Age**", min_value=18, max_value=100, step=1)
|
13 |
+
gender = st.selectbox("**Gender**", options=["Male", "Female"])
|
14 |
+
gender_encoded = 1 if gender == "Male" else 0
|
15 |
+
|
16 |
+
tenure = st.number_input("**Tenure (months)**", min_value=0, step=1)
|
17 |
+
usage_frequency = st.number_input("**Usage Frequency**", min_value=0, step=1)
|
18 |
+
support_calls = st.number_input("**Support Calls**", min_value=0, step=1)
|
19 |
+
payment_delay = st.number_input("**Payment Delay**", min_value=0, step=1)
|
20 |
+
|
21 |
+
subscription_type = st.selectbox("**Subscription Type**", options=["Standard", "Basic", "Premium"])
|
22 |
+
subscription_type_encoded = {"Standard": 2, "Basic": 0, "Premium": 1}[subscription_type]
|
23 |
+
|
24 |
+
contract_length = st.selectbox("**Contract Length**", options=["Annual", "Monthly", "Quarterly"])
|
25 |
+
contract_length_encoded = {"Annual": 0, "Monthly": 1, "Quarterly": 2}[contract_length]
|
26 |
+
|
27 |
+
total_spend = st.number_input("**Total Spend**", min_value=0.0, step=1.0)
|
28 |
+
last_interaction = st.number_input("Last Interaction (days ago)", min_value=0, step=1)
|
29 |
+
|
30 |
+
# Create DataFrame of input data for the prediction
|
31 |
+
input_data = pd.DataFrame({
|
32 |
+
"Age": [age],
|
33 |
+
"Gender": [gender_encoded],
|
34 |
+
"Tenure": [tenure],
|
35 |
+
"Usage Frequency": [usage_frequency],
|
36 |
+
"Support Calls": [support_calls],
|
37 |
+
"Payment Delay": [payment_delay],
|
38 |
+
"Subscription Type": [subscription_type_encoded],
|
39 |
+
"Contract Length": [contract_length_encoded],
|
40 |
+
"Total Spend": [total_spend],
|
41 |
+
"Last Interaction": [last_interaction],
|
42 |
+
})
|
43 |
+
|
44 |
+
|
45 |
+
if st.button("Predict Churn"):
|
46 |
+
prediction = get_prediction(input_data)
|
47 |
+
|
48 |
+
|
49 |
+
if prediction is not None:
|
50 |
+
churn_value = int(prediction['predictions'][0])
|
51 |
+
churn_prediction = "Will Churn" if churn_value == 1 else "Won't Churn"
|
52 |
+
st.success(f"Prediction: {churn_prediction}")
|
53 |
+
else:
|
54 |
+
st.write("Prediction request failed. We are using local model ")
|
55 |
+
with open("backend/artifacts/XGBoost.pkl","rb") as file:
|
56 |
+
model= pickle.load(file)
|
57 |
+
result = model.predict(input_data)
|
58 |
+
churn_prediction = "Will Churn" if result ==1 else "Won't Churn"
|
59 |
+
st.success(f"Prediction : {churn_prediction}")
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
project_ui()
|
64 |
+
|
65 |
+
|
66 |
+
|
frontend/reports/model_report_1.html
ADDED
The diff for this file is too large to render.
See raw diff
|
|
frontend/reports/model_report_2.html
ADDED
The diff for this file is too large to render.
See raw diff
|
|