Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,116 +1,93 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
from sklearn.model_selection import train_test_split
|
4 |
-
from sklearn.
|
5 |
-
from sklearn.
|
|
|
6 |
from datasets import load_dataset # To load the dataset from Hugging Face
|
7 |
|
8 |
# --- Data Loading and Preprocessing ---
|
9 |
|
10 |
-
# Load the
|
11 |
-
# This dataset is commonly available on HF and mirrors the Kaggle structure.
|
12 |
try:
|
13 |
-
#
|
14 |
-
dataset = load_dataset("
|
15 |
-
df = pd.DataFrame(dataset) # Convert to pandas DataFrame
|
16 |
except Exception as e:
|
17 |
-
|
18 |
-
#
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
|
26 |
-
#
|
27 |
-
|
28 |
|
29 |
-
#
|
30 |
-
|
|
|
31 |
|
32 |
-
#
|
33 |
-
|
|
|
34 |
|
35 |
-
#
|
36 |
-
# We drop 'Embarked_C' to avoid multicollinearity (as per common practice)
|
37 |
-
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)
|
38 |
|
|
|
|
|
|
|
39 |
|
40 |
-
#
|
41 |
-
X = df.drop('Survived', axis=1)
|
42 |
-
y = df['Survived']
|
43 |
|
44 |
-
|
45 |
-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
46 |
-
|
47 |
-
# Train a RandomForestClassifier model
|
48 |
-
model = RandomForestClassifier(n_estimators=100, random_state=42)
|
49 |
-
model.fit(X_train, y_train)
|
50 |
-
|
51 |
-
# Evaluate the model (for display purposes)
|
52 |
-
y_pred = model.predict(X_test)
|
53 |
accuracy = accuracy_score(y_test, y_pred)
|
54 |
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
|
55 |
|
56 |
# --- Prediction Function for Gradio ---
|
57 |
-
def
|
58 |
-
#
|
59 |
-
|
60 |
-
'Pclass': pclass,
|
61 |
-
'Age': age,
|
62 |
-
'SibSp': sibsp,
|
63 |
-
'Parch': parch,
|
64 |
-
'Fare': fare,
|
65 |
-
# These match the one-hot encoded columns created during training
|
66 |
-
'Sex_male': 1 if sex == 'male' else 0,
|
67 |
-
'Embarked_Q': 1 if embarked == 'Q' else 0, # Assuming 'Q' is 'Embarked_Q'
|
68 |
-
'Embarked_S': 1 if embarked == 'S' else 0 # Assuming 'S' is 'Embarked_S'
|
69 |
-
}
|
70 |
-
|
71 |
-
# Create a DataFrame from the input values
|
72 |
-
input_data = pd.DataFrame([input_dict])
|
73 |
-
|
74 |
-
# Ensure all columns expected by the model are present in the input_data, even if 0
|
75 |
-
# This handles cases where a category might not be present in a single input but was in training
|
76 |
-
for col in X.columns:
|
77 |
-
if col not in input_data.columns:
|
78 |
-
input_data[col] = 0
|
79 |
-
|
80 |
-
# Reorder columns to match the training data's column order
|
81 |
-
input_data = input_data[X.columns]
|
82 |
|
83 |
# Make prediction
|
84 |
-
prediction = model.predict(
|
85 |
-
prediction_proba = model.predict_proba(
|
86 |
|
87 |
-
if prediction == 1:
|
88 |
-
return f"Prediction:
|
89 |
-
else:
|
90 |
-
return f"Prediction:
|
91 |
|
92 |
# --- Gradio Interface ---
|
93 |
-
# CSS to style the output textbox background
|
94 |
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
|
95 |
gr.Markdown(
|
96 |
"""
|
97 |
-
# Titanic Survival Predictor
|
98 |
-
Enter passenger
|
|
|
99 |
"""
|
100 |
)
|
101 |
-
gr.Markdown(f"###
|
102 |
-
|
103 |
-
with gr.Row():
|
104 |
-
pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
|
105 |
-
sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
|
106 |
-
age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
|
107 |
-
with gr.Row():
|
108 |
-
sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
|
109 |
-
parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
|
110 |
-
fare_input = gr.Number(label="Fare", value=30.0)
|
111 |
-
with gr.Row():
|
112 |
-
embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')
|
113 |
|
|
|
|
|
|
|
|
|
|
|
114 |
predict_btn = gr.Button("Predict Survival")
|
115 |
output_text = gr.Textbox(label="Survival Prediction", interactive=False)
|
116 |
# This label is used internally to get the color, its content is not directly shown
|
@@ -126,8 +103,8 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
|
|
126 |
return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
|
127 |
|
128 |
predict_btn.click(
|
129 |
-
fn=
|
130 |
-
inputs=
|
131 |
outputs=[output_text, output_color_indicator]
|
132 |
).then(
|
133 |
fn=update_output_style,
|
@@ -135,19 +112,25 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
|
|
135 |
outputs=output_text
|
136 |
)
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
gr.Markdown(
|
139 |
"""
|
140 |
---
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
* **Age:** Age in years
|
145 |
-
* **SibSp:** Number of siblings/spouses aboard the Titanic
|
146 |
-
* **Parch:** Number of parents/children aboard the Titanic
|
147 |
-
* **Fare:** Passenger fare
|
148 |
-
* **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
|
149 |
-
|
150 |
-
*Note: The dataset is loaded directly from Hugging Face's `datasets` library ([AbubakarJ/titanic](https://huggingface.co/datasets/AbubakarJ/titanic)). Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
|
151 |
"""
|
152 |
)
|
153 |
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
from sklearn.model_selection import train_test_split
|
4 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
5 |
+
from sklearn.naive_bayes import MultinomialNB
|
6 |
+
from sklearn.metrics import accuracy_score, classification_report
|
7 |
from datasets import load_dataset # To load the dataset from Hugging Face
|
8 |
|
9 |
# --- Data Loading and Preprocessing ---
|
10 |
|
11 |
+
# Load the julien-c/titanic-survival dataset from Hugging Face
|
|
|
12 |
try:
|
13 |
+
# This dataset typically contains 'text' (description) and 'label' (0=died, 1=survived)
|
14 |
+
dataset = load_dataset("julien-c/titanic-survival", split="train")
|
15 |
+
df = pd.DataFrame(dataset) # Convert the dataset to a pandas DataFrame
|
16 |
except Exception as e:
|
17 |
+
gr.Warning(f"Failed to load dataset: {e}. Please check your internet connection or dataset name.")
|
18 |
+
# Provide a minimal fallback for local testing if HF dataset loading fails
|
19 |
+
df = pd.DataFrame({
|
20 |
+
'text': [
|
21 |
+
"A young boy, probably a steerage passenger. He doesn't look like he survived.",
|
22 |
+
"A first-class lady with a child. She likely survived due to priority.",
|
23 |
+
"Male, 30s, middle class. Probably didn't make it.",
|
24 |
+
"Female, 20s, dressed finely. Looks like she got on a lifeboat.",
|
25 |
+
"An elderly man, alone, traveling steerage."
|
26 |
+
],
|
27 |
+
'label': ["0", "1", "0", "1", "0"] # 0 for died, 1 for survived
|
28 |
+
})
|
29 |
+
|
30 |
+
|
31 |
+
# Ensure 'label' column is numeric
|
32 |
+
df['label'] = pd.to_numeric(df['label'])
|
33 |
|
34 |
+
# Define features (X) and target (y)
|
35 |
+
X = df['text']
|
36 |
+
y = df['label']
|
37 |
|
38 |
+
# Split data into training and testing sets
|
39 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
40 |
|
41 |
+
# Initialize CountVectorizer
|
42 |
+
# This converts text documents to a matrix of token counts
|
43 |
+
vectorizer = CountVectorizer(stop_words='english', lowercase=True)
|
44 |
|
45 |
+
# Fit the vectorizer on the training data and transform both training and test data
|
46 |
+
X_train_vectorized = vectorizer.fit_transform(X_train)
|
47 |
+
X_test_vectorized = vectorizer.transform(X_test)
|
48 |
|
49 |
+
# --- Model Training ---
|
|
|
|
|
50 |
|
51 |
+
# Initialize and train the Multinomial Naive Bayes model
|
52 |
+
model = MultinomialNB()
|
53 |
+
model.fit(X_train_vectorized, y_train)
|
54 |
|
55 |
+
# --- Model Evaluation (for display in app) ---
|
|
|
|
|
56 |
|
57 |
+
y_pred = model.predict(X_test_vectorized)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
accuracy = accuracy_score(y_test, y_pred)
|
59 |
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
|
60 |
|
61 |
# --- Prediction Function for Gradio ---
|
62 |
+
def predict_survival_from_text(passenger_description):
|
63 |
+
# Transform the input description using the *trained* vectorizer
|
64 |
+
message_vectorized = vectorizer.transform([passenger_description])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# Make prediction
|
67 |
+
prediction = model.predict(message_vectorized)[0]
|
68 |
+
prediction_proba = model.predict_proba(message_vectorized)[0]
|
69 |
|
70 |
+
if prediction == 1: # 1 corresponds to 'survived'
|
71 |
+
return f"Prediction: SURVIVED ({prediction_proba[1]:.2%} confidence)", "green"
|
72 |
+
else: # 0 corresponds to 'died'
|
73 |
+
return f"Prediction: DID NOT SURVIVE ({prediction_proba[0]:.2%} confidence)", "red"
|
74 |
|
75 |
# --- Gradio Interface ---
|
|
|
76 |
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
|
77 |
gr.Markdown(
|
78 |
"""
|
79 |
+
# Titanic Survival Predictor (Text-based)
|
80 |
+
Enter a textual description of a passenger to predict their survival on the Titanic.
|
81 |
+
This model uses text classification techniques.
|
82 |
"""
|
83 |
)
|
84 |
+
gr.Markdown(f"### {accuracy_message}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
description_input = gr.Textbox(
|
87 |
+
label="Enter Passenger Description",
|
88 |
+
lines=5,
|
89 |
+
placeholder="e.g., 'A young woman from first class, traveling alone.'"
|
90 |
+
)
|
91 |
predict_btn = gr.Button("Predict Survival")
|
92 |
output_text = gr.Textbox(label="Survival Prediction", interactive=False)
|
93 |
# This label is used internally to get the color, its content is not directly shown
|
|
|
103 |
return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
|
104 |
|
105 |
predict_btn.click(
|
106 |
+
fn=predict_survival_from_text,
|
107 |
+
inputs=description_input,
|
108 |
outputs=[output_text, output_color_indicator]
|
109 |
).then(
|
110 |
fn=update_output_style,
|
|
|
112 |
outputs=output_text
|
113 |
)
|
114 |
|
115 |
+
gr.Examples(
|
116 |
+
examples=[
|
117 |
+
"A wealthy first-class woman with her child. She was probably on a lifeboat.",
|
118 |
+
"An old man, alone, traveling in third class. He likely did not survive.",
|
119 |
+
"A young male crew member.",
|
120 |
+
"A small child from steerage."
|
121 |
+
],
|
122 |
+
inputs=description_input,
|
123 |
+
outputs=[output_text, output_color_indicator],
|
124 |
+
fn=predict_survival_from_text,
|
125 |
+
cache_examples=True # Caches the output for examples for faster loading
|
126 |
+
)
|
127 |
+
|
128 |
gr.Markdown(
|
129 |
"""
|
130 |
---
|
131 |
+
*This model uses a Multinomial Naive Bayes classifier. It is trained on the 'text' descriptions from the
|
132 |
+
[julien-c/titanic-survival](https://huggingface.co/datasets/julien-c/titanic-survival) dataset
|
133 |
+
to predict survival ('0' for died, '1' for survived). Text is preprocessed using CountVectorizer.*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
"""
|
135 |
)
|
136 |
|