Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,149 +1,47 @@
|
|
1 |
-
import gradio as gr
|
2 |
import pandas as pd
|
3 |
-
|
|
|
4 |
from sklearn.ensemble import RandomForestClassifier
|
5 |
-
from sklearn.
|
6 |
-
|
7 |
-
|
8 |
-
# --- Data Loading and Preprocessing ---
|
9 |
-
|
10 |
-
# Load the dataset named 'titanic.csv'
|
11 |
-
# Make sure 'titanic.csv' is uploaded to your Hugging Face Space or is in the same directory
|
12 |
-
try:
|
13 |
-
df = pd.read_csv('titanic.csv')
|
14 |
-
except FileNotFoundError:
|
15 |
-
raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.")
|
16 |
-
|
17 |
-
# Drop irrelevant columns and 'PassengerId' which is not a feature
|
18 |
-
# These columns are typically present in a standard Titanic dataset.
|
19 |
-
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
|
20 |
-
|
21 |
-
# Handle missing 'Age' with median imputation
|
22 |
-
df['Age'].fillna(df['Age'].median(), inplace=True)
|
23 |
-
|
24 |
-
# Handle missing 'Fare' with median imputation (Fare can also have missing values sometimes)
|
25 |
-
df['Fare'].fillna(df['Fare'].median(), inplace=True)
|
26 |
-
|
27 |
-
# Handle missing 'Embarked' with mode imputation
|
28 |
-
df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
|
29 |
-
|
30 |
-
# Convert categorical features to numerical using one-hot encoding
|
31 |
-
# drop_first=True drops the first level of each encoded column ('Sex_female' and 'Embarked_C') to avoid multicollinearity
|
32 |
-
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)
|
33 |
|
|
|
|
|
|
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
y = df['Survived']
|
38 |
|
39 |
-
#
|
40 |
-
|
|
|
41 |
|
42 |
-
# Train
|
43 |
model = RandomForestClassifier(n_estimators=100, random_state=42)
|
44 |
-
model.fit(
|
45 |
-
|
46 |
-
# --- Model Evaluation (for display in app) ---
|
47 |
-
y_pred = model.predict(X_test)
|
48 |
-
accuracy = accuracy_score(y_test, y_pred)
|
49 |
-
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
|
50 |
-
|
51 |
-
# --- Prediction Function for Gradio ---
|
52 |
-
def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
|
53 |
-
# Create a dictionary for the input values
|
54 |
-
input_dict = {
|
55 |
-
'Pclass': pclass,
|
56 |
-
'Age': age,
|
57 |
-
'SibSp': sibsp,
|
58 |
-
'Parch': parch,
|
59 |
-
'Fare': fare,
|
60 |
-
# These match the one-hot encoded columns created during training
|
61 |
-
'Sex_male': 1 if sex == 'male' else 0,
|
62 |
-
'Embarked_Q': 1 if embarked == 'Q' else 0, # Assuming 'Q' is 'Embarked_Q'
|
63 |
-
'Embarked_S': 1 if embarked == 'S' else 0 # Assuming 'S' is 'Embarked_S'
|
64 |
-
}
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
# This handles cases where a category might not be present in a single input but was in training
|
71 |
-
for col in X.columns:
|
72 |
-
if col not in input_data.columns:
|
73 |
-
input_data[col] = 0
|
74 |
-
|
75 |
-
# Reorder columns to match the training data's column order
|
76 |
-
input_data = input_data[X.columns]
|
77 |
-
|
78 |
-
# Make prediction
|
79 |
prediction = model.predict(input_data)[0]
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
""
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
|
101 |
-
age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
|
102 |
-
with gr.Row():
|
103 |
-
sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
|
104 |
-
parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
|
105 |
-
fare_input = gr.Number(label="Fare", value=30.0)
|
106 |
-
with gr.Row():
|
107 |
-
embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')
|
108 |
-
|
109 |
-
predict_btn = gr.Button("Predict Survival")
|
110 |
-
output_text = gr.Textbox(label="Survival Prediction", interactive=False)
|
111 |
-
# This label is used internally to get the color, its content is not directly shown
|
112 |
-
output_color_indicator = gr.Label(visible=False)
|
113 |
-
|
114 |
-
# Function to update the textbox styling based on prediction
|
115 |
-
def update_output_style(text, color):
|
116 |
-
if color == "green":
|
117 |
-
return gr.Textbox(value=text, label="Survival Prediction", interactive=False, elem_classes="green")
|
118 |
-
elif color == "red":
|
119 |
-
return gr.Textbox(value=text, label="Survival Prediction", interactive=False, elem_classes="red")
|
120 |
-
else:
|
121 |
-
return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
|
122 |
-
|
123 |
-
predict_btn.click(
|
124 |
-
fn=predict_survival,
|
125 |
-
inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
|
126 |
-
outputs=[output_text, output_color_indicator]
|
127 |
-
).then(
|
128 |
-
fn=update_output_style,
|
129 |
-
inputs=[output_text, output_color_indicator],
|
130 |
-
outputs=output_text
|
131 |
-
)
|
132 |
-
|
133 |
-
gr.Markdown(
|
134 |
-
"""
|
135 |
-
---
|
136 |
-
**Feature Definitions:**
|
137 |
-
* **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
|
138 |
-
* **Sex:** Sex (male/female)
|
139 |
-
* **Age:** Age in years
|
140 |
-
* **SibSp:** Number of siblings/spouses aboard the Titanic
|
141 |
-
* **Parch:** Number of parents/children aboard the Titanic
|
142 |
-
* **Fare:** Passenger fare
|
143 |
-
* **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
|
144 |
-
|
145 |
-
*Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
|
146 |
-
"""
|
147 |
-
)
|
148 |
-
|
149 |
-
demo.launch()
|
|
|
|
|
1 |
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
from sklearn.ensemble import RandomForestClassifier
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
# --- Data preparation and model training (executed once, at import) ---

# Read the raw CSV and keep only the columns used for modelling; any row
# with a missing value in one of these columns is discarded.
df = pd.read_csv("titanic.csv")
df = df.loc[:, ["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Survived"]].dropna()

# Turn the textual 'Sex' column into integer codes (male=1, female=0,
# assuming the column holds exactly the labels 'female' and 'male').
_sex_encoder = LabelEncoder()
df["Sex"] = _sex_encoder.fit_transform(df["Sex"])

# Split into feature matrix and target vector.
y = df["Survived"]
X = df.drop(columns="Survived")

# Fit the classifier on the full cleaned dataset; the fixed random_state
# keeps the trained forest reproducible between runs.
model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X, y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Prediction function
|
24 |
+
def predict_survival(pclass, sex, age, sibsp, parch, fare):
    """Predict whether a passenger with the given details survived.

    Args:
        pclass: Passenger class (1, 2 or 3).
        sex: Either "male" or "female".
        age: Age in years.
        sibsp: Number of siblings/spouses aboard.
        parch: Number of parents/children aboard.
        fare: Passenger fare.

    Returns:
        A display string: "✅ Survived" or "❌ Did not survive" for a valid
        input, otherwise a short message asking the user to correct the
        input.
    """
    # A UI component left unselected delivers None. Without this guard a
    # missing 'sex' would silently be encoded as female, and any other
    # missing field would make model.predict() raise.
    if any(value is None for value in (pclass, sex, age, sibsp, parch, fare)):
        return "⚠️ Please fill in every field before predicting."
    if sex not in ("male", "female"):
        return "⚠️ Sex must be 'male' or 'female'."

    # Same encoding as used for the training data: male=1, female=0.
    sex_encoded = 1 if sex == "male" else 0

    # The model is fitted on a DataFrame, so predict with a DataFrame that
    # carries the same column names in the same order as the training
    # features; a bare array triggers a feature-name mismatch warning.
    input_data = pd.DataFrame(
        [[pclass, sex_encoded, age, sibsp, parch, fare]],
        columns=["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare"],
    )
    prediction = model.predict(input_data)[0]
    return "✅ Survived" if prediction == 1 else "❌ Did not survive"
|
29 |
+
|
30 |
+
# Gradio interface
# One input widget per predict_survival argument, listed in the same order
# as the function's parameters.
_passenger_inputs = [
    gr.Dropdown(choices=[1, 2, 3], label="Passenger Class"),
    gr.Radio(choices=["male", "female"], label="Sex"),
    gr.Slider(minimum=0, maximum=80, step=1, label="Age"),
    gr.Slider(minimum=0, maximum=10, step=1, label="Siblings/Spouses Aboard"),
    gr.Slider(minimum=0, maximum=10, step=1, label="Parents/Children Aboard"),
    gr.Slider(minimum=0, maximum=500, step=1, label="Fare"),
]

# The prediction string returned by predict_survival is shown as plain text.
iface = gr.Interface(
    fn=predict_survival,
    inputs=_passenger_inputs,
    outputs="text",
    title="🚢 Titanic Survival Predictor",
    description="Enter passenger details to predict their survival on the Titanic.",
)

# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|