Update app.py

app.py CHANGED
@@ -1,93 +1,111 @@
import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split
-from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.naive_bayes import MultinomialNB
-
-from datasets import load_dataset  # To load the dataset from Hugging Face

# --- Data Loading and Preprocessing ---

-# Load the
try:
-
-
-
-except Exception as e:
-    gr.Warning(f"Failed to load dataset: {e}. Please check your internet connection or dataset name.")
-    # Provide a minimal fallback for local testing if HF dataset loading fails
-    df = pd.DataFrame({
-        'text': [
-            "A young boy, probably a steerage passenger. He doesn't look like he survived.",
-            "A first-class lady with a child. She likely survived due to priority.",
-            "Male, 30s, middle class. Probably didn't make it.",
-            "Female, 20s, dressed finely. Looks like she got on a lifeboat.",
-            "An elderly man, alone, traveling steerage."
-        ],
-        'label': ["0", "1", "0", "1", "0"]  # 0 for died, 1 for survived
-    })
-
-
-    # Ensure 'label' column is numeric
-    df['label'] = pd.to_numeric(df['label'])

-#
-
-

-#
-

-#
-
-vectorizer = CountVectorizer(stop_words='english', lowercase=True)

-#
-
-X_test_vectorized = vectorizer.transform(X_test)

-#

-# Initialize and train the Multinomial Naive Bayes model
-model = MultinomialNB()
-model.fit(X_train_vectorized, y_train)

-#

-
accuracy = accuracy_score(y_test, y_pred)
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"

# --- Prediction Function for Gradio ---
-def
-    #
-

    # Make prediction
-    prediction = model.predict(
-    prediction_proba = model.predict_proba(

-    if prediction == 1:
-        return f"Prediction:
-    else:
-        return f"Prediction:

# --- Gradio Interface ---
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
    gr.Markdown(
        """
-        # Titanic Survival Predictor
-        Enter
-        This model uses text classification techniques.
        """
    )
-    gr.Markdown(f"### {accuracy_message}")

-    description_input = gr.Textbox(
-        label="Enter Passenger Description",
-        lines=5,
-        placeholder="e.g., 'A young woman from first class, traveling alone.'"
-    )
    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)
    # This label is used internally to get the color, its content is not directly shown
@@ -103,8 +121,8 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
        return gr.Textbox(value=text, label="Survival Prediction", interactive=False)

    predict_btn.click(
-        fn=
-        inputs=
        outputs=[output_text, output_color_indicator]
    ).then(
        fn=update_output_style,
@@ -112,25 +130,19 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
        outputs=output_text
    )

-    gr.Examples(
-        examples=[
-            "A wealthy first-class woman with her child. She was probably on a lifeboat.",
-            "An old man, alone, traveling in third class. He likely did not survive.",
-            "A young male crew member.",
-            "A small child from steerage."
-        ],
-        inputs=description_input,
-        outputs=[output_text, output_color_indicator],
-        fn=predict_survival_from_text,
-        cache_examples=True  # Caches the output for examples for faster loading
-    )
-
    gr.Markdown(
        """
        ---
-
-
-
        """
    )
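Most of the deleted lines above are truncated in this view. For context, here is a minimal sketch of the text-classification pipeline they appear to implement, assuming standard CountVectorizer / MultinomialNB usage over the 'text'/'label' fallback DataFrame; the exact dataset-loading call, function signature, and return values are not recoverable from the diff. The full replacement version of app.py follows after the sketch.

```python
# Minimal sketch (not the original code): a bag-of-words Naive Bayes classifier over
# free-text passenger descriptions, matching the names that survive in the deleted
# lines (vectorizer, X_test_vectorized, MultinomialNB, predict_survival_from_text).
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score

# Fallback data from the deleted block; the real app loaded a dataset via load_dataset().
df = pd.DataFrame({
    'text': [
        "A young boy, probably a steerage passenger. He doesn't look like he survived.",
        "A first-class lady with a child. She likely survived due to priority.",
        "Male, 30s, middle class. Probably didn't make it.",
        "Female, 20s, dressed finely. Looks like she got on a lifeboat.",
        "An elderly man, alone, traveling steerage."
    ],
    'label': [0, 1, 0, 1, 0]  # 0 = died, 1 = survived
})

X_train, X_test, y_train, y_test = train_test_split(df['text'], df['label'], test_size=0.2, random_state=42)

# Bag-of-words features, then a Multinomial Naive Bayes model.
vectorizer = CountVectorizer(stop_words='english', lowercase=True)
X_train_vectorized = vectorizer.fit_transform(X_train)
X_test_vectorized = vectorizer.transform(X_test)

model = MultinomialNB()
model.fit(X_train_vectorized, y_train)

y_pred = model.predict(X_test_vectorized)
accuracy_message = f"Model Accuracy on Test Set: {accuracy_score(y_test, y_pred):.2f}"

def predict_survival_from_text(description):
    # Vectorize the free-text description and report the predicted class with confidence.
    vec = vectorizer.transform([description])
    prediction = model.predict(vec)[0]
    proba = model.predict_proba(vec)[0]
    label = "Survived" if prediction == 1 else "Did Not Survive"
    return f"Prediction: {label} ({proba[prediction]:.2%} confidence)"
```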
import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import accuracy_score
+import io  # Keep io, though not strictly used in this version, it's harmless.

# --- Data Loading and Preprocessing ---

+# Load the dataset named 'titanic.csv'
+# Make sure 'titanic.csv' is uploaded to your Hugging Face Space or is in the same directory
try:
+    df = pd.read_csv('titanic.csv')
+except FileNotFoundError:
+    raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.")

+# Drop irrelevant columns and 'PassengerId' which is not a feature
+# These columns are typically present in a standard Titanic dataset.
+df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)

+# Handle missing 'Age' with median imputation
+df['Age'].fillna(df['Age'].median(), inplace=True)

+# Handle missing 'Fare' with median imputation (Fare can also have missing values sometimes)
+df['Fare'].fillna(df['Fare'].median(), inplace=True)

+# Handle missing 'Embarked' with mode imputation
+df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)

+# Convert categorical features to numerical using one-hot encoding
+# We drop 'Embarked_C' to avoid multicollinearity (as per common practice)
+df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)


+# Define features (X) and target (y)
+X = df.drop('Survived', axis=1)
+y = df['Survived']

+# Split data into training and testing sets
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+# Train a RandomForestClassifier model
+model = RandomForestClassifier(n_estimators=100, random_state=42)
+model.fit(X_train, y_train)
+
+# --- Model Evaluation (for display in app) ---
+y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"

# --- Prediction Function for Gradio ---
+def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
+    # Create a dictionary for the input values
+    input_dict = {
+        'Pclass': pclass,
+        'Age': age,
+        'SibSp': sibsp,
+        'Parch': parch,
+        'Fare': fare,
+        # These match the one-hot encoded columns created during training
+        'Sex_male': 1 if sex == 'male' else 0,
+        'Embarked_Q': 1 if embarked == 'Q' else 0,  # Assuming 'Q' is 'Embarked_Q'
+        'Embarked_S': 1 if embarked == 'S' else 0   # Assuming 'S' is 'Embarked_S'
+    }
+
+    # Create a DataFrame from the input values
+    input_data = pd.DataFrame([input_dict])
+
+    # Ensure all columns expected by the model are present in the input_data, even if 0
+    # This handles cases where a category might not be present in a single input but was in training
+    for col in X.columns:
+        if col not in input_data.columns:
+            input_data[col] = 0
+
+    # Reorder columns to match the training data's column order
+    input_data = input_data[X.columns]

    # Make prediction
+    prediction = model.predict(input_data)[0]
+    prediction_proba = model.predict_proba(input_data)[0]

+    if prediction == 1:
+        return f"Prediction: Survived ({prediction_proba[1]:.2%} confidence)", "green"
+    else:
+        return f"Prediction: Did Not Survive ({prediction_proba[0]:.2%} confidence)", "red"

# --- Gradio Interface ---
+# CSS to style the output textbox background
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
    gr.Markdown(
        """
+        # Titanic Survival Predictor
+        Enter passenger details to predict their survival on the Titanic.
        """
    )
+    gr.Markdown(f"### Model Performance: {accuracy_message}")
+
+    with gr.Row():
+        pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
+        sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
+        age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
+    with gr.Row():
+        sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
+        parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
+        fare_input = gr.Number(label="Fare", value=30.0)
+    with gr.Row():
+        embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')

    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)
    # This label is used internally to get the color, its content is not directly shown
@@ -103,8 +121,8 @@
        return gr.Textbox(value=text, label="Survival Prediction", interactive=False)

    predict_btn.click(
+        fn=predict_survival,
+        inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
        outputs=[output_text, output_color_indicator]
    ).then(
        fn=update_output_style,
@@ -112,25 +130,19 @@
        outputs=output_text
    )

    gr.Markdown(
        """
        ---
+        **Feature Definitions:**
+        * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
+        * **Sex:** Sex (male/female)
+        * **Age:** Age in years
+        * **SibSp:** Number of siblings/spouses aboard the Titanic
+        * **Parch:** Number of parents/children aboard the Titanic
+        * **Fare:** Passenger fare
+        * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
+
+        *Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
        """
    )
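The diff hunks above omit the output_color_indicator component, the update_output_style helper, and any demo.launch() call, so those are left as shown. As a quick check of the preprocessing the new version relies on, the sketch below runs the same drop/impute/one-hot steps on a few illustrative rows in the standard Kaggle Titanic layout (the exact column set is an assumption; the committed app only requires that titanic.csv contain the columns it references) and prints the feature columns that predict_survival must reproduce.

```python
# Sketch (assumed titanic.csv layout, not part of the commit): verify that the
# app's preprocessing yields the feature columns predict_survival fills in.
import pandas as pd

# Three illustrative rows in the usual Kaggle Titanic schema.
df = pd.DataFrame({
    'PassengerId': [1, 2, 3],
    'Survived': [0, 1, 0],
    'Pclass': [3, 1, 3],
    'Name': ['Braund, Mr. Owen Harris', 'Cumings, Mrs. John Bradley', 'Moran, Mr. James'],
    'Sex': ['male', 'female', 'male'],
    'Age': [22.0, 38.0, None],  # one missing Age to exercise imputation
    'SibSp': [1, 1, 0],
    'Parch': [0, 0, 0],
    'Ticket': ['A/5 21171', 'PC 17599', '330877'],
    'Fare': [7.25, 71.2833, 8.4583],
    'Cabin': [None, 'C85', None],
    'Embarked': ['S', 'C', 'Q'],
})

# Same steps as the new app.py: drop identifiers, impute, one-hot encode.
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
df['Age'] = df['Age'].fillna(df['Age'].median())
df['Fare'] = df['Fare'].fillna(df['Fare'].median())
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)

print(df.columns.tolist())
# ['Survived', 'Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Sex_male', 'Embarked_Q', 'Embarked_S']
# Dropping 'Survived' leaves exactly the keys predict_survival builds:
# Pclass, Age, SibSp, Parch, Fare, Sex_male, Embarked_Q, Embarked_S.
```

Because drop_first=True removes the Embarked_C indicator during training, a passenger who embarked at Cherbourg is encoded as Embarked_Q = 0 and Embarked_S = 0, which is exactly what the function's two indicator entries produce.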