Spaces:
Sleeping
Sleeping
File size: 6,206 Bytes
a822fa8 42a186f 362698c a822fa8 68f838a a822fa8 362698c 68f838a 362698c 68f838a 362698c 42a186f 362698c 42a186f 362698c 42a186f 362698c a822fa8 362698c 42a186f 68f838a 362698c 42a186f 362698c 42a186f 362698c a822fa8 42a186f 362698c 42a186f 362698c 42a186f 362698c 42a186f a822fa8 362698c a822fa8 362698c 42a186f 362698c 42a186f a822fa8 362698c a822fa8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import io # Keep io, though not strictly used in this version, it's harmless.
# --- Data Loading and Preprocessing ---
# Load the training data. 'titanic.csv' must be in the working directory,
# e.g. uploaded alongside this script to the Hugging Face Space.
try:
    df = pd.read_csv('titanic.csv')
except FileNotFoundError as err:
    # Re-raise with a setup hint, chaining the original error so its
    # details (e.g. the path that was tried) stay in the traceback.
    raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.") from err

# Drop the identifier and free-text columns, which are not used as features.
# errors='ignore' tolerates CSV variants that omit any of them.
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1, errors='ignore')

# Impute missing values: median for 'Age' and 'Fare', mode for 'Embarked'.
# The result is assigned back rather than calling fillna(inplace=True) on a
# column selection: that chained in-place form warns in pandas 2.2 and no
# longer modifies `df` at all under Copy-on-Write (the default in pandas 3.0).
df['Age'] = df['Age'].fillna(df['Age'].median())
df['Fare'] = df['Fare'].fillna(df['Fare'].median())
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])

# One-hot encode the categorical columns. drop_first=True drops the first
# (sorted) level of each column, i.e. 'Sex_female' and 'Embarked_C' for the
# standard data, so the dummy columns are not perfectly collinear.
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)

# Features are every remaining column; the target is 'Survived'.
X = df.drop('Survived', axis=1)
y = df['Survived']

# Hold out 20% of the rows for evaluation; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a RandomForestClassifier model (fixed seed for reproducible trees).
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# --- Model Evaluation (for display in app) ---
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
# --- Prediction Function for Gradio ---
def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
    """Predict whether a single passenger survives.

    Args:
        pclass: Passenger class as offered by the UI (1, 2 or 3).
        sex: 'male' or 'female'.
        age: Age in years.
        sibsp: Number of siblings/spouses aboard.
        parch: Number of parents/children aboard.
        fare: Passenger fare.
        embarked: Port code, 'C', 'Q' or 'S'.

    Returns:
        A ``(message, color)`` tuple. ``message`` states the prediction and
        the model's probability for the predicted class; ``color`` is
        ``"green"`` when survival is predicted and ``"red"`` otherwise, and
        the UI uses it to style the output box.
    """
    # Map the raw inputs onto the column names get_dummies produced at
    # training time ('Sex_male', 'Embarked_Q', 'Embarked_S').
    input_dict = {
        'Pclass': pclass,
        'Age': age,
        'SibSp': sibsp,
        'Parch': parch,
        'Fare': fare,
        'Sex_male': 1 if sex == 'male' else 0,
        'Embarked_Q': 1 if embarked == 'Q' else 0,
        'Embarked_S': 1 if embarked == 'S' else 0,
    }

    # Align with the training feature matrix: any training column not set
    # above is added as 0, columns the model never saw are dropped, and the
    # order matches X.columns.
    input_data = pd.DataFrame([input_dict]).reindex(columns=X.columns, fill_value=0)

    # A single forest pass: RandomForestClassifier.predict() is the argmax
    # of predict_proba(), so the label and its probability both come from
    # the same result.
    proba = model.predict_proba(input_data)[0]
    best = proba.argmax()
    prediction = model.classes_[best]
    confidence = proba[best]

    if prediction == 1:
        return f"Prediction: Survived ({confidence:.2%} confidence)", "green"
    return f"Prediction: Did Not Survive ({confidence:.2%} confidence)", "red"
# --- Gradio Interface ---
# Background tints for the prediction textbox, keyed by the colour name that
# predict_survival returns ("green" / "red"). The class is attached to the
# component wrapper through elem_classes; the `textarea` rules also tint the
# text field itself, which otherwise keeps the theme's input background.
OUTPUT_CSS = (
    ".green {background-color: #e6ffe6 !important;}"
    ".red {background-color: #ffe6e6 !important;}"
    ".green textarea {background-color: #e6ffe6 !important;}"
    ".red textarea {background-color: #ffe6e6 !important;}"
)

with gr.Blocks(css=OUTPUT_CSS) as demo:
    gr.Markdown(
        """
        # Titanic Survival Predictor
        Enter passenger details to predict their survival on the Titanic.
        """
    )
    gr.Markdown(f"### Model Performance: {accuracy_message}")

    with gr.Row():
        pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
        sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
        age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
    with gr.Row():
        sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
        parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
        fare_input = gr.Number(label="Fare", value=30.0)
    with gr.Row():
        embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')

    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)
    # Per-session holder for the colour name returned by predict_survival.
    # gr.State carries it to the follow-up event without rendering a
    # component (previously a hidden gr.Label was used for this).
    output_color_indicator = gr.State()

    def update_output_style(text, color):
        """Re-render the prediction textbox with the CSS class for *color*.

        ``color`` is expected to be ``"green"`` or ``"red"``. Any other value
        returns the textbox without setting ``elem_classes``.
        """
        style = {"elem_classes": color} if color in ("green", "red") else {}
        return gr.Textbox(value=text, label="Survival Prediction", interactive=False, **style)

    # First compute the prediction, then restyle the textbox with the colour
    # produced by that first step.
    predict_btn.click(
        fn=predict_survival,
        inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
        outputs=[output_text, output_color_indicator],
    ).then(
        fn=update_output_style,
        inputs=[output_text, output_color_indicator],
        outputs=output_text,
    )

    gr.Markdown(
        """
        ---
        **Feature Definitions:**

        * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
        * **Sex:** Sex (male/female)
        * **Age:** Age in years
        * **SibSp:** Number of siblings/spouses aboard the Titanic
        * **Parch:** Number of parents/children aboard the Titanic
        * **Fare:** Passenger fare
        * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)

        *Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
        """
    )

demo.launch()