resolverkatla's picture
Update app.py
362698c verified
raw
history blame
6.21 kB
import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import io # Keep io, though not strictly used in this version, it's harmless.
# --- Data Loading and Preprocessing ---
# Load the Titanic dataset from 'titanic.csv'. The path is relative, so it is
# resolved against the current working directory; on a Hugging Face Space the
# file must be uploaded alongside this script.
try:
    df = pd.read_csv('titanic.csv')
except FileNotFoundError as err:
    # Re-raise with an actionable message while keeping the original error
    # attached as the explicit cause.
    raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.") from err
# Drop identifier / free-text columns that are not used as model features.
# These columns are typically present in a standard Titanic dataset; drop()
# raises KeyError if any of them is missing.
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)

# Impute missing values. The result is assigned back to the column instead of
# calling fillna(..., inplace=True) on df['col']: that chained-assignment form
# emits a FutureWarning in pandas 2.x and silently has no effect under
# Copy-on-Write (the default from pandas 3.0), which would leave NaNs behind.
df['Age'] = df['Age'].fillna(df['Age'].median())      # median age
df['Fare'] = df['Fare'].fillna(df['Fare'].median())   # median fare
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])  # most common port

# One-hot encode the categorical columns. drop_first=True drops the first
# (alphabetically sorted) level of each column to avoid perfect collinearity;
# assuming the usual values ('female'/'male', 'C'/'Q'/'S'), this yields
# Sex_male, Embarked_Q and Embarked_S.
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)
# --- Model Training ---
# Feature matrix is every remaining column except the 'Survived' label.
X = df.drop(columns='Survived')
y = df['Survived']

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# reproducible between restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.2
)

# 100-tree random forest with a fixed seed. fit() returns the fitted
# estimator itself, so construction and training are chained.
model = RandomForestClassifier(random_state=42, n_estimators=100).fit(X_train, y_train)
# --- Model Evaluation (shown in the app header) ---
# Score the held-out split once at start-up so the interface can report it.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
accuracy_message = "Model Accuracy on Test Set: {:.2f}".format(accuracy)
# --- Prediction Function for Gradio ---
def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
    """Predict survival for one passenger with the module-level ``model``.

    Builds a one-row DataFrame laid out exactly like the training matrix
    ``X`` and classifies it.

    Args:
        pclass: Passenger class (1, 2 or 3).
        sex: 'male' sets the Sex_male indicator; any other value leaves it 0.
        age: Age in years.
        sibsp: Number of siblings/spouses aboard.
        parch: Number of parents/children aboard.
        fare: Passenger fare.
        embarked: Port code. 'Q' and 'S' set their indicator columns; any
            other value (e.g. 'C', the dropped baseline) leaves both at 0.

    Returns:
        A ``(message, color)`` tuple. The message states the predicted outcome
        and the model's probability for it. ``color`` is "green" for a
        survival prediction and "red" otherwise.
    """
    # Numeric inputs pass straight through. Categorical inputs become the 0/1
    # indicator columns that pd.get_dummies(..., drop_first=True) produced.
    features = {
        'Pclass': pclass,
        'Age': age,
        'SibSp': sibsp,
        'Parch': parch,
        'Fare': fare,
        'Sex_male': int(sex == 'male'),
        'Embarked_Q': int(embarked == 'Q'),
        'Embarked_S': int(embarked == 'S'),
    }

    # Align with the training layout: columns the model expects but we did not
    # build are filled with 0, unexpected ones are dropped, and the column
    # order matches X.columns.
    row = pd.DataFrame([features]).reindex(columns=X.columns, fill_value=0)

    label = model.predict(row)[0]
    probabilities = model.predict_proba(row)[0]

    if label == 1:
        return f"Prediction: Survived ({probabilities[1]:.2%} confidence)", "green"
    return f"Prediction: Did Not Survive ({probabilities[0]:.2%} confidence)", "red"
# --- Gradio Interface ---
# CSS classes used to tint the output textbox background by predicted outcome.
with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
    gr.Markdown(
        """
# Titanic Survival Predictor
Enter passenger details to predict their survival on the Titanic.
"""
    )
    gr.Markdown(f"### Model Performance: {accuracy_message}")

    with gr.Row():
        pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
        sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
        age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
    with gr.Row():
        sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
        parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
        fare_input = gr.Number(label="Fare", value=30.0)
    with gr.Row():
        embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')

    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)

    # Holds the colour name ("green"/"red") returned by predict_survival so the
    # follow-up event can style the textbox. gr.State passes the raw Python
    # value to the next event unchanged. The previous hidden gr.Label
    # serialised it as Label data, and how that came back as an input depended
    # on the Gradio version.
    output_color_indicator = gr.State()

    def update_output_style(text, color):
        """Re-render the result textbox, adding a CSS class for ``color``.

        ``color`` is "green" or "red" (as returned by predict_survival). Any
        other value re-renders the textbox without passing ``elem_classes``.
        """
        style = {"elem_classes": color} if color in ("green", "red") else {}
        return gr.Textbox(value=text, label="Survival Prediction", interactive=False, **style)

    # First compute the prediction (text + colour), then restyle the textbox.
    predict_btn.click(
        fn=predict_survival,
        inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
        outputs=[output_text, output_color_indicator],
    ).then(
        fn=update_output_style,
        inputs=[output_text, output_color_indicator],
        outputs=output_text,
    )

    gr.Markdown(
        """
---
**Feature Definitions:**
* **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
* **Sex:** Sex (male/female)
* **Age:** Age in years
* **SibSp:** Number of siblings/spouses aboard the Titanic
* **Parch:** Number of parents/children aboard the Titanic
* **Fare:** Passenger fare
* **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
*Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
"""
    )

# Start the web app (executed at module level; there is no __main__ guard).
demo.launch()