resolverkatla commited on
Commit
362698c
·
verified ·
1 Parent(s): 64f61d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -78
app.py CHANGED
@@ -1,93 +1,111 @@
1
  import gradio as gr
2
  import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
- from sklearn.feature_extraction.text import CountVectorizer
5
- from sklearn.naive_bayes import MultinomialNB
6
- from sklearn.metrics import accuracy_score, classification_report
7
- from datasets import load_dataset # To load the dataset from Hugging Face
8
 
9
  # --- Data Loading and Preprocessing ---
10
 
11
- # Load the julien-c/titanic-survival dataset from Hugging Face
 
12
  try:
13
- # This dataset typically contains 'text' (description) and 'label' (0=died, 1=survived)
14
- dataset = load_dataset("julien-c/titanic-survival", split="train")
15
- df = pd.DataFrame(dataset) # Convert the dataset to a pandas DataFrame
16
- except Exception as e:
17
- gr.Warning(f"Failed to load dataset: {e}. Please check your internet connection or dataset name.")
18
- # Provide a minimal fallback for local testing if HF dataset loading fails
19
- df = pd.DataFrame({
20
- 'text': [
21
- "A young boy, probably a steerage passenger. He doesn't look like he survived.",
22
- "A first-class lady with a child. She likely survived due to priority.",
23
- "Male, 30s, middle class. Probably didn't make it.",
24
- "Female, 20s, dressed finely. Looks like she got on a lifeboat.",
25
- "An elderly man, alone, traveling steerage."
26
- ],
27
- 'label': ["0", "1", "0", "1", "0"] # 0 for died, 1 for survived
28
- })
29
-
30
-
31
- # Ensure 'label' column is numeric
32
- df['label'] = pd.to_numeric(df['label'])
33
 
34
- # Define features (X) and target (y)
35
- X = df['text']
36
- y = df['label']
37
 
38
- # Split data into training and testing sets
39
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
40
 
41
- # Initialize CountVectorizer
42
- # This converts text documents to a matrix of token counts
43
- vectorizer = CountVectorizer(stop_words='english', lowercase=True)
44
 
45
- # Fit the vectorizer on the training data and transform both training and test data
46
- X_train_vectorized = vectorizer.fit_transform(X_train)
47
- X_test_vectorized = vectorizer.transform(X_test)
48
 
49
- # --- Model Training ---
 
 
50
 
51
- # Initialize and train the Multinomial Naive Bayes model
52
- model = MultinomialNB()
53
- model.fit(X_train_vectorized, y_train)
54
 
55
- # --- Model Evaluation (for display in app) ---
 
 
56
 
57
- y_pred = model.predict(X_test_vectorized)
 
 
 
 
 
 
 
 
58
  accuracy = accuracy_score(y_test, y_pred)
59
  accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
60
 
61
  # --- Prediction Function for Gradio ---
62
- def predict_survival_from_text(passenger_description):
63
- # Transform the input description using the *trained* vectorizer
64
- message_vectorized = vectorizer.transform([passenger_description])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Make prediction
67
- prediction = model.predict(message_vectorized)[0]
68
- prediction_proba = model.predict_proba(message_vectorized)[0]
69
 
70
- if prediction == 1: # 1 corresponds to 'survived'
71
- return f"Prediction: SURVIVED ({prediction_proba[1]:.2%} confidence)", "green"
72
- else: # 0 corresponds to 'died'
73
- return f"Prediction: DID NOT SURVIVE ({prediction_proba[0]:.2%} confidence)", "red"
74
 
75
  # --- Gradio Interface ---
 
76
  with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
77
  gr.Markdown(
78
  """
79
- # Titanic Survival Predictor (Text-based)
80
- Enter a textual description of a passenger to predict their survival on the Titanic.
81
- This model uses text classification techniques.
82
  """
83
  )
84
- gr.Markdown(f"### {accuracy_message}")
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- description_input = gr.Textbox(
87
- label="Enter Passenger Description",
88
- lines=5,
89
- placeholder="e.g., 'A young woman from first class, traveling alone.'"
90
- )
91
  predict_btn = gr.Button("Predict Survival")
92
  output_text = gr.Textbox(label="Survival Prediction", interactive=False)
93
  # This label is used internally to get the color, its content is not directly shown
@@ -103,8 +121,8 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
103
  return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
104
 
105
  predict_btn.click(
106
- fn=predict_survival_from_text,
107
- inputs=description_input,
108
  outputs=[output_text, output_color_indicator]
109
  ).then(
110
  fn=update_output_style,
@@ -112,25 +130,19 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
112
  outputs=output_text
113
  )
114
 
115
- gr.Examples(
116
- examples=[
117
- "A wealthy first-class woman with her child. She was probably on a lifeboat.",
118
- "An old man, alone, traveling in third class. He likely did not survive.",
119
- "A young male crew member.",
120
- "A small child from steerage."
121
- ],
122
- inputs=description_input,
123
- outputs=[output_text, output_color_indicator],
124
- fn=predict_survival_from_text,
125
- cache_examples=True # Caches the output for examples for faster loading
126
- )
127
-
128
  gr.Markdown(
129
  """
130
  ---
131
- *This model uses a Multinomial Naive Bayes classifier. It is trained on the 'text' descriptions from the
132
- [julien-c/titanic-survival](https://huggingface.co/datasets/julien-c/titanic-survival) dataset
133
- to predict survival ('0' for died, '1' for survived). Text is preprocessed using CountVectorizer.*
 
 
 
 
 
 
 
134
  """
135
  )
136
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
+ from sklearn.ensemble import RandomForestClassifier
5
+ from sklearn.metrics import accuracy_score
6
+ import io # Keep io, though not strictly used in this version, it's harmless.
 
7
 
8
  # --- Data Loading and Preprocessing ---
9
 
10
+ # Load the dataset named 'titanic.csv'
11
+ # Make sure 'titanic.csv' is uploaded to your Hugging Face Space or is in the same directory
12
  try:
13
+ df = pd.read_csv('titanic.csv')
14
+ except FileNotFoundError:
15
+ raise FileNotFoundError("titanic.csv not found. Please ensure it's downloaded and named 'titanic.csv', then uploaded to your Hugging Face Space.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Drop irrelevant columns and 'PassengerId' which is not a feature
18
+ # These columns are typically present in a standard Titanic dataset.
19
+ df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
20
 
21
+ # Handle missing 'Age' with median imputation
22
+ df['Age'].fillna(df['Age'].median(), inplace=True)
23
 
24
+ # Handle missing 'Fare' with median imputation (Fare can also have missing values sometimes)
25
+ df['Fare'].fillna(df['Fare'].median(), inplace=True)
 
26
 
27
+ # Handle missing 'Embarked' with mode imputation
28
+ df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
 
29
 
30
+ # Convert categorical features to numerical using one-hot encoding
31
+ # We drop 'Embarked_C' to avoid multicollinearity (as per common practice)
32
+ df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)
33
 
 
 
 
34
 
35
+ # Define features (X) and target (y)
36
+ X = df.drop('Survived', axis=1)
37
+ y = df['Survived']
38
 
39
+ # Split data into training and testing sets
40
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
41
+
42
+ # Train a RandomForestClassifier model
43
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
44
+ model.fit(X_train, y_train)
45
+
46
+ # --- Model Evaluation (for display in app) ---
47
+ y_pred = model.predict(X_test)
48
  accuracy = accuracy_score(y_test, y_pred)
49
  accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
50
 
51
  # --- Prediction Function for Gradio ---
52
+ def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
53
+ # Create a dictionary for the input values
54
+ input_dict = {
55
+ 'Pclass': pclass,
56
+ 'Age': age,
57
+ 'SibSp': sibsp,
58
+ 'Parch': parch,
59
+ 'Fare': fare,
60
+ # These match the one-hot encoded columns created during training
61
+ 'Sex_male': 1 if sex == 'male' else 0,
62
+ 'Embarked_Q': 1 if embarked == 'Q' else 0, # Assuming 'Q' is 'Embarked_Q'
63
+ 'Embarked_S': 1 if embarked == 'S' else 0 # Assuming 'S' is 'Embarked_S'
64
+ }
65
+
66
+ # Create a DataFrame from the input values
67
+ input_data = pd.DataFrame([input_dict])
68
+
69
+ # Ensure all columns expected by the model are present in the input_data, even if 0
70
+ # This handles cases where a category might not be present in a single input but was in training
71
+ for col in X.columns:
72
+ if col not in input_data.columns:
73
+ input_data[col] = 0
74
+
75
+ # Reorder columns to match the training data's column order
76
+ input_data = input_data[X.columns]
77
 
78
  # Make prediction
79
+ prediction = model.predict(input_data)[0]
80
+ prediction_proba = model.predict_proba(input_data)[0]
81
 
82
+ if prediction == 1:
83
+ return f"Prediction: Survived ({prediction_proba[1]:.2%} confidence)", "green"
84
+ else:
85
+ return f"Prediction: Did Not Survive ({prediction_proba[0]:.2%} confidence)", "red"
86
 
87
  # --- Gradio Interface ---
88
+ # CSS to style the output textbox background
89
  with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
90
  gr.Markdown(
91
  """
92
+ # Titanic Survival Predictor
93
+ Enter passenger details to predict their survival on the Titanic.
 
94
  """
95
  )
96
+ gr.Markdown(f"### Model Performance: {accuracy_message}")
97
+
98
+ with gr.Row():
99
+ pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
100
+ sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
101
+ age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
102
+ with gr.Row():
103
+ sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
104
+ parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
105
+ fare_input = gr.Number(label="Fare", value=30.0)
106
+ with gr.Row():
107
+ embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')
108
 
 
 
 
 
 
109
  predict_btn = gr.Button("Predict Survival")
110
  output_text = gr.Textbox(label="Survival Prediction", interactive=False)
111
  # This label is used internally to get the color, its content is not directly shown
 
121
  return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
122
 
123
  predict_btn.click(
124
+ fn=predict_survival,
125
+ inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
126
  outputs=[output_text, output_color_indicator]
127
  ).then(
128
  fn=update_output_style,
 
130
  outputs=output_text
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  gr.Markdown(
134
  """
135
  ---
136
+ **Feature Definitions:**
137
+ * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
138
+ * **Sex:** Sex (male/female)
139
+ * **Age:** Age in years
140
+ * **SibSp:** Number of siblings/spouses aboard the Titanic
141
+ * **Parch:** Number of parents/children aboard the Titanic
142
+ * **Fare:** Passenger fare
143
+ * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
144
+
145
+ *Note: This app expects a `titanic.csv` file. Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
146
  """
147
  )
148