Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.metrics import accuracy_score | |
import io # Keep io, though not strictly used in this version, it's harmless. | |
# --- Data Loading and Preprocessing ---
# Load the dataset named 'titanic.csv'.
# Make sure 'titanic.csv' is uploaded to your Hugging Face Space or is in the same directory.
try:
    df = pd.read_csv('titanic.csv')
except FileNotFoundError as err:
    # Re-raise with a friendlier message, keeping the original error as the cause.
    raise FileNotFoundError(
        "titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space."
    ) from err

# Drop identifier / free-text columns that are not used as model features.
# These columns are typically present in a standard Titanic dataset.
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)

# Impute missing values: median for the numeric 'Age' and 'Fare', mode for the
# categorical 'Embarked'. The fill values are computed before `df` is rebound.
# The result is assigned back rather than calling fillna(inplace=True) on a
# column selection: that chained-assignment pattern warns in pandas 2.x and
# silently leaves `df` unchanged under Copy-on-Write (the pandas 3.0 default).
df = df.fillna({
    'Age': df['Age'].median(),
    'Fare': df['Fare'].median(),
    'Embarked': df['Embarked'].mode()[0],
})

# One-hot encode the categorical features. drop_first=True drops the first
# (alphabetical) category of each column, e.g. Sex_female and Embarked_C for the
# standard Titanic values, to avoid perfectly collinear dummy columns. The
# remaining Sex_male / Embarked_Q / Embarked_S columns are what predict_survival
# builds. dtype=int keeps them 0/1 integers, matching the prediction-time input.
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True, dtype=int)

# Define features (X) and target (y).
X = df.drop('Survived', axis=1)
y = df['Survived']

# Hold out 20% of the rows for evaluation; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a RandomForestClassifier model.
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# --- Model Evaluation (for display in app) ---
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
# --- Prediction Function for Gradio ---
def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
    """Predict whether a passenger with the given details survived.

    Args:
        pclass: Passenger class (the UI offers 1, 2 or 3).
        sex: 'male' or 'female'; anything other than 'male' is encoded as female.
        age: Age in years.
        sibsp: Number of siblings/spouses aboard.
        parch: Number of parents/children aboard.
        fare: Passenger fare.
        embarked: Port of embarkation; 'Q' and 'S' set their dummy column,
            anything else (e.g. 'C') leaves both at 0.

    Returns:
        A ``(message, color)`` tuple. ``message`` states the prediction and the
        model's probability for the predicted class. ``color`` is ``"green"``
        when the predicted class is 1 (survived), otherwise ``"red"``.
    """
    # Encode the raw inputs like the training data: one-hot dummy columns
    # with the first category of each feature dropped.
    input_row = {
        'Pclass': pclass,
        'Age': age,
        'SibSp': sibsp,
        'Parch': parch,
        'Fare': fare,
        'Sex_male': 1 if sex == 'male' else 0,
        'Embarked_Q': 1 if embarked == 'Q' else 0,
        'Embarked_S': 1 if embarked == 'S' else 0,
    }
    # Align with the training features: any column the model expects but the
    # row lacks is filled with 0, and the column order matches X.columns.
    input_data = pd.DataFrame([input_row]).reindex(columns=X.columns, fill_value=0)

    # A single predict_proba call yields both the label and its confidence.
    # A forest classifier's predict() is the argmax over these same
    # probabilities, mapped through classes_, so this matches model.predict().
    probabilities = model.predict_proba(input_data)[0]
    best = probabilities.argmax()
    confidence = probabilities[best]

    if model.classes_[best] == 1:
        return f"Prediction: Survived ({confidence:.2%} confidence)", "green"
    return f"Prediction: Did Not Survive ({confidence:.2%} confidence)", "red"
# --- Gradio Interface ---
# The CSS defines the .green / .red classes applied to the prediction textbox.
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
    gr.Markdown(
        """
        # Titanic Survival Predictor
        Enter passenger details to predict their survival on the Titanic.
        """
    )
    gr.Markdown(f"### Model Performance: {accuracy_message}")

    with gr.Row():
        pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
        sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
        age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
    with gr.Row():
        sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
        parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
        fare_input = gr.Number(label="Fare", value=30.0)
    with gr.Row():
        embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')

    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)
    # Non-rendered holder for the color string returned by predict_survival.
    # gr.State hands the raw Python value to the chained .then() handler,
    # with no display component involved.
    output_color_indicator = gr.State()

    def update_output_style(text, color):
        """Return the prediction textbox re-rendered with ``text``.

        ``"green"`` / ``"red"`` become the CSS class of the textbox (defined in
        the Blocks css above); any other value leaves elem_classes unset.
        """
        css_class = color if color in ("green", "red") else None
        return gr.Textbox(value=text, label="Survival Prediction", interactive=False, elem_classes=css_class)

    # First compute the prediction (text + color), then restyle the textbox.
    predict_btn.click(
        fn=predict_survival,
        inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
        outputs=[output_text, output_color_indicator],
    ).then(
        fn=update_output_style,
        inputs=[output_text, output_color_indicator],
        outputs=output_text,
    )

    gr.Markdown(
        """
        ---
        **Feature Definitions:**
        * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
        * **Sex:** Sex (male/female)
        * **Age:** Age in years
        * **SibSp:** Number of siblings/spouses aboard the Titanic
        * **Parch:** Number of parents/children aboard the Titanic
        * **Fare:** Passenger fare
        * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
        *Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
        """
    )

demo.launch()