File size: 6,206 Bytes
a822fa8
 
42a186f
362698c
 
 
a822fa8
68f838a
a822fa8
362698c
 
68f838a
362698c
 
 
68f838a
362698c
 
 
42a186f
362698c
 
42a186f
362698c
 
42a186f
362698c
 
a822fa8
362698c
 
 
42a186f
68f838a
362698c
 
 
42a186f
362698c
 
 
 
 
 
 
 
 
42a186f
 
 
 
362698c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a822fa8
42a186f
362698c
 
42a186f
362698c
 
 
 
42a186f
 
362698c
42a186f
a822fa8
 
362698c
 
a822fa8
 
362698c
 
 
 
 
 
 
 
 
 
 
 
42a186f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362698c
 
42a186f
 
 
 
 
 
a822fa8
 
 
 
362698c
 
 
 
 
 
 
 
 
 
a822fa8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import io # Keep io, though not strictly used in this version, it's harmless.

# --- Data Loading and Preprocessing ---

# Load the dataset named 'titanic.csv'.
# The file must sit next to this script (e.g. uploaded to the Hugging Face Space).
try:
    df = pd.read_csv('titanic.csv')
except FileNotFoundError as err:
    # Re-raise with an actionable message while keeping the original error as the cause.
    raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.") from err

# Drop identifier / free-text columns that are not used as model features.
df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)

# Impute missing values by assigning the filled column back to the frame.
# NOTE: `df[col].fillna(..., inplace=True)` is chained assignment; it warns on
# pandas 2.x and does not modify `df` at all under Copy-on-Write, so the
# explicit assignment form is used instead.
df['Age'] = df['Age'].fillna(df['Age'].median())              # numeric -> median
df['Fare'] = df['Fare'].fillna(df['Fare'].median())           # numeric -> median
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])  # categorical -> most frequent

# One-hot encode the categorical features. `drop_first=True` drops the first
# level of *each* encoded column, leaving one fewer indicator per feature.
# predict_survival() builds 'Sex_male', 'Embarked_Q' and 'Embarked_S' by hand
# and relies on these column names.
df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)


# Target vector (y) and feature matrix (X): everything except 'Survived' is a feature.
y = df['Survived']
X = df.drop(columns=['Survived'])

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)

# Build and fit the random forest in one step (fit() returns the fitted estimator).
model = RandomForestClassifier(random_state=42, n_estimators=100).fit(X_train, y_train)

# --- Model Evaluation (for display in app) ---
# Score the held-out rows once at startup; the message is rendered in the UI header.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"

# --- Prediction Function for Gradio ---
def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
    """Predict survival for one passenger and report the model's confidence.

    Args:
        pclass: Passenger class, used as the 'Pclass' feature.
        sex: Passenger sex; 'male' sets ``Sex_male`` to 1, any other value to 0.
        age: Age in years.
        sibsp: Number of siblings/spouses aboard.
        parch: Number of parents/children aboard.
        fare: Passenger fare.
        embarked: Port of embarkation; 'Q' and 'S' set their indicator column
            to 1, any other value leaves both at 0.

    Returns:
        A ``(message, color)`` tuple. ``message`` is the human-readable
        prediction with the predicted class probability; ``color`` is
        ``"green"`` for a survival prediction and ``"red"`` otherwise.

    Raises:
        ValueError: If any argument is ``None`` (for example a numeric field
            that was cleared in the UI).
    """
    # Reject missing inputs up front instead of letting None/NaN reach the model.
    supplied = {
        'Pclass': pclass,
        'Sex': sex,
        'Age': age,
        'SibSp': sibsp,
        'Parch': parch,
        'Fare': fare,
        'Embarked': embarked,
    }
    missing = [name for name, value in supplied.items() if value is None]
    if missing:
        raise ValueError(f"Missing value for: {', '.join(missing)}")

    # Encode the raw inputs the same way the training frame was encoded.
    input_dict = {
        'Pclass': pclass,
        'Age': age,
        'SibSp': sibsp,
        'Parch': parch,
        'Fare': fare,
        # These match the one-hot encoded columns created during training
        'Sex_male': 1 if sex == 'male' else 0,
        'Embarked_Q': 1 if embarked == 'Q' else 0,
        'Embarked_S': 1 if embarked == 'S' else 0,
    }

    # Align to the training columns in one step: the result has exactly
    # X.columns in training order, with any column absent from the input set to 0.
    input_data = pd.DataFrame([input_dict]).reindex(columns=X.columns, fill_value=0)

    # Run the forest once; the predicted class is the one with the highest
    # mean probability, which is what model.predict() would return.
    prediction_proba = model.predict_proba(input_data)[0]
    prediction = model.classes_[prediction_proba.argmax()]

    if prediction == 1:
        return f"Prediction: Survived ({prediction_proba[1]:.2%} confidence)", "green"
    else:
        return f"Prediction: Did Not Survive ({prediction_proba[0]:.2%} confidence)", "red"

# --- Gradio Interface ---
# Two CSS classes tint the prediction textbox: "green" and "red".
_RESULT_CSS = ".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}"

with gr.Blocks(css=_RESULT_CSS) as demo:
    # Page title and intro, followed by the accuracy message computed at startup.
    gr.Markdown(
        """
        # Titanic Survival Predictor
        Enter passenger details to predict their survival on the Titanic.
        """
    )
    gr.Markdown(f"### Model Performance: {accuracy_message}")

    # Passenger inputs, laid out over three rows.
    with gr.Row():
        pclass_input = gr.Radio(label="Pclass", choices=[1, 2, 3], value=3)
        sex_input = gr.Radio(label="Sex", choices=['male', 'female'], value='male')
        age_input = gr.Slider(label="Age", minimum=0.5, maximum=80, step=0.5, value=30)
    with gr.Row():
        sibsp_input = gr.Number(value=0, label="SibSp (Siblings/Spouses Aboard)")
        parch_input = gr.Number(value=0, label="Parch (Parents/Children Aboard)")
        fare_input = gr.Number(value=30.0, label="Fare")
    with gr.Row():
        embarked_input = gr.Radio(label="Embarked (Port of Embarkation)", choices=['C', 'Q', 'S'], value='S')

    predict_btn = gr.Button("Predict Survival")
    output_text = gr.Textbox(label="Survival Prediction", interactive=False)
    # Hidden component that carries the color token from the prediction step
    # to the styling step; it is never shown to the user.
    output_color_indicator = gr.Label(visible=False)

    def update_output_style(text, color):
        """Return the result textbox, tagged with the CSS class named by `color`.

        `color` values "green" and "red" are applied as the textbox's
        `elem_classes`; for any other value no class is passed.
        """
        textbox_kwargs = {
            "value": text,
            "label": "Survival Prediction",
            "interactive": False,
        }
        if color in ("green", "red"):
            textbox_kwargs["elem_classes"] = color
        return gr.Textbox(**textbox_kwargs)

    # Step 1: run the model and fill the textbox plus the hidden color holder.
    # Step 2: re-emit the textbox with the matching CSS class.
    feature_inputs = [
        pclass_input,
        sex_input,
        age_input,
        sibsp_input,
        parch_input,
        fare_input,
        embarked_input,
    ]
    prediction_event = predict_btn.click(
        fn=predict_survival,
        inputs=feature_inputs,
        outputs=[output_text, output_color_indicator],
    )
    prediction_event.then(
        fn=update_output_style,
        inputs=[output_text, output_color_indicator],
        outputs=output_text,
    )

    # Static reference section describing each input field.
    gr.Markdown(
        """
        ---
        **Feature Definitions:**
        * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
        * **Sex:** Sex (male/female)
        * **Age:** Age in years
        * **SibSp:** Number of siblings/spouses aboard the Titanic
        * **Parch:** Number of parents/children aboard the Titanic
        * **Fare:** Passenger fare
        * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)

        *Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
        """
    )

# Start the Gradio server.
demo.launch()