resolverkatla committed
Commit 64f61d8 · verified · 1 Parent(s): 755c66d

Update app.py

Files changed (1)
  1. app.py +76 -93
app.py CHANGED
@@ -1,116 +1,93 @@
 import gradio as gr
 import pandas as pd
 from sklearn.model_selection import train_test_split
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.metrics import accuracy_score
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.naive_bayes import MultinomialNB
+from sklearn.metrics import accuracy_score, classification_report
 from datasets import load_dataset # To load the dataset from Hugging Face
 
 # --- Data Loading and Preprocessing ---
 
-# Load the Titanic dataset from Hugging Face
-# This dataset is commonly available on HF and mirrors the Kaggle structure.
+# Load the julien-c/titanic-survival dataset from Hugging Face
 try:
-    # We load the 'train' split of the dataset
-    dataset = load_dataset("AbubakarJ/titanic", split="train")
-    df = pd.DataFrame(dataset) # Convert to pandas DataFrame
+    # This dataset typically contains 'text' (description) and 'label' (0=died, 1=survived)
+    dataset = load_dataset("julien-c/titanic-survival", split="train")
+    df = pd.DataFrame(dataset) # Convert the dataset to a pandas DataFrame
 except Exception as e:
-    # If the dataset cannot be loaded (e.g., internet issue on Space startup, or dataset changed)
-    # This will raise a RuntimeError which typically stops the Gradio app from launching
-    raise RuntimeError(f"Could not load Titanic dataset from Hugging Face: {e}. "
-                       "Please check the dataset name or your connection.")
+    gr.Warning(f"Failed to load dataset: {e}. Please check your internet connection or dataset name.")
+    # Provide a minimal fallback for local testing if HF dataset loading fails
+    df = pd.DataFrame({
+        'text': [
+            "A young boy, probably a steerage passenger. He doesn't look like he survived.",
+            "A first-class lady with a child. She likely survived due to priority.",
+            "Male, 30s, middle class. Probably didn't make it.",
+            "Female, 20s, dressed finely. Looks like she got on a lifeboat.",
+            "An elderly man, alone, traveling steerage."
+        ],
+        'label': ["0", "1", "0", "1", "0"] # 0 for died, 1 for survived
+    })
+
+
+# Ensure 'label' column is numeric
+df['label'] = pd.to_numeric(df['label'])
 
-# Drop irrelevant columns and 'PassengerId' which is not a feature
-# These columns are typically present in the full Kaggle Titanic dataset.
-df = df.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
+# Define features (X) and target (y)
+X = df['text']
+y = df['label']
 
-# Handle missing 'Age' with median imputation
-df['Age'].fillna(df['Age'].median(), inplace=True)
+# Split data into training and testing sets
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
-# Handle missing 'Fare' with median imputation (Fare can also have missing values sometimes)
-df['Fare'].fillna(df['Fare'].median(), inplace=True)
+# Initialize CountVectorizer
+# This converts text documents to a matrix of token counts
+vectorizer = CountVectorizer(stop_words='english', lowercase=True)
 
-# Handle missing 'Embarked' with mode imputation
-df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
+# Fit the vectorizer on the training data and transform both training and test data
+X_train_vectorized = vectorizer.fit_transform(X_train)
+X_test_vectorized = vectorizer.transform(X_test)
 
-# Convert categorical features to numerical using one-hot encoding
-# We drop 'Embarked_C' to avoid multicollinearity (as per common practice)
-df = pd.get_dummies(df, columns=['Sex', 'Embarked'], drop_first=True)
+# --- Model Training ---
 
+# Initialize and train the Multinomial Naive Bayes model
+model = MultinomialNB()
+model.fit(X_train_vectorized, y_train)
 
-# Define features (X) and target (y)
-X = df.drop('Survived', axis=1)
-y = df['Survived']
+# --- Model Evaluation (for display in app) ---
 
-# Split data into training and testing sets
-X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-
-# Train a RandomForestClassifier model
-model = RandomForestClassifier(n_estimators=100, random_state=42)
-model.fit(X_train, y_train)
-
-# Evaluate the model (for display purposes)
-y_pred = model.predict(X_test)
+y_pred = model.predict(X_test_vectorized)
 accuracy = accuracy_score(y_test, y_pred)
 accuracy_message = f"Model Accuracy on Test Set: {accuracy:.2f}"
 
 # --- Prediction Function for Gradio ---
-def predict_survival(pclass, sex, age, sibsp, parch, fare, embarked):
-    # Create a dictionary for the input values
-    input_dict = {
-        'Pclass': pclass,
-        'Age': age,
-        'SibSp': sibsp,
-        'Parch': parch,
-        'Fare': fare,
-        # These match the one-hot encoded columns created during training
-        'Sex_male': 1 if sex == 'male' else 0,
-        'Embarked_Q': 1 if embarked == 'Q' else 0, # Assuming 'Q' is 'Embarked_Q'
-        'Embarked_S': 1 if embarked == 'S' else 0 # Assuming 'S' is 'Embarked_S'
-    }
-
-    # Create a DataFrame from the input values
-    input_data = pd.DataFrame([input_dict])
-
-    # Ensure all columns expected by the model are present in the input_data, even if 0
-    # This handles cases where a category might not be present in a single input but was in training
-    for col in X.columns:
-        if col not in input_data.columns:
-            input_data[col] = 0
-
-    # Reorder columns to match the training data's column order
-    input_data = input_data[X.columns]
+def predict_survival_from_text(passenger_description):
+    # Transform the input description using the *trained* vectorizer
+    message_vectorized = vectorizer.transform([passenger_description])
 
     # Make prediction
-    prediction = model.predict(input_data)[0]
-    prediction_proba = model.predict_proba(input_data)[0]
+    prediction = model.predict(message_vectorized)[0]
+    prediction_proba = model.predict_proba(message_vectorized)[0]
 
-    if prediction == 1:
-        return f"Prediction: Survived ({prediction_proba[1]:.2%} confidence)", "green"
-    else:
-        return f"Prediction: Did Not Survive ({prediction_proba[0]:.2%} confidence)", "red"
+    if prediction == 1: # 1 corresponds to 'survived'
+        return f"Prediction: SURVIVED ({prediction_proba[1]:.2%} confidence)", "green"
+    else: # 0 corresponds to 'died'
+        return f"Prediction: DID NOT SURVIVE ({prediction_proba[0]:.2%} confidence)", "red"
 
 # --- Gradio Interface ---
-# CSS to style the output textbox background
 with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {background-color: #ffe6e6 !important;}") as demo:
     gr.Markdown(
         """
-        # Titanic Survival Predictor
-        Enter passenger details to predict their survival on the Titanic.
+        # Titanic Survival Predictor (Text-based)
+        Enter a textual description of a passenger to predict their survival on the Titanic.
+        This model uses text classification techniques.
        """
     )
-    gr.Markdown(f"### Model Performance: {accuracy_message}")
-
-    with gr.Row():
-        pclass_input = gr.Radio(choices=[1, 2, 3], label="Pclass", value=3)
-        sex_input = gr.Radio(choices=['male', 'female'], label="Sex", value='male')
-        age_input = gr.Slider(minimum=0.5, maximum=80, value=30, label="Age", step=0.5)
-    with gr.Row():
-        sibsp_input = gr.Number(label="SibSp (Siblings/Spouses Aboard)", value=0)
-        parch_input = gr.Number(label="Parch (Parents/Children Aboard)", value=0)
-        fare_input = gr.Number(label="Fare", value=30.0)
-    with gr.Row():
-        embarked_input = gr.Radio(choices=['C', 'Q', 'S'], label="Embarked (Port of Embarkation)", value='S')
+    gr.Markdown(f"### {accuracy_message}")
 
+    description_input = gr.Textbox(
+        label="Enter Passenger Description",
+        lines=5,
+        placeholder="e.g., 'A young woman from first class, traveling alone.'"
+    )
     predict_btn = gr.Button("Predict Survival")
     output_text = gr.Textbox(label="Survival Prediction", interactive=False)
     # This label is used internally to get the color, its content is not directly shown
@@ -126,8 +103,8 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
         return gr.Textbox(value=text, label="Survival Prediction", interactive=False)
 
     predict_btn.click(
-        fn=predict_survival,
-        inputs=[pclass_input, sex_input, age_input, sibsp_input, parch_input, fare_input, embarked_input],
+        fn=predict_survival_from_text,
+        inputs=description_input,
        outputs=[output_text, output_color_indicator]
     ).then(
         fn=update_output_style,
@@ -135,19 +112,25 @@ with gr.Blocks(css=".green {background-color: #e6ffe6 !important;}.red {backgrou
         outputs=output_text
     )
 
+    gr.Examples(
+        examples=[
+            "A wealthy first-class woman with her child. She was probably on a lifeboat.",
+            "An old man, alone, traveling in third class. He likely did not survive.",
+            "A young male crew member.",
+            "A small child from steerage."
+        ],
+        inputs=description_input,
+        outputs=[output_text, output_color_indicator],
+        fn=predict_survival_from_text,
+        cache_examples=True # Caches the output for examples for faster loading
+    )
+
     gr.Markdown(
         """
         ---
-        **Feature Definitions:**
-        * **Pclass:** Passenger Class (1 = 1st, 2 = 2nd, 3 = 3rd)
-        * **Sex:** Sex (male/female)
-        * **Age:** Age in years
-        * **SibSp:** Number of siblings/spouses aboard the Titanic
-        * **Parch:** Number of parents/children aboard the Titanic
-        * **Fare:** Passenger fare
-        * **Embarked:** Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
-
-        *Note: The dataset is loaded directly from Hugging Face's `datasets` library ([AbubakarJ/titanic](https://huggingface.co/datasets/AbubakarJ/titanic)). Missing 'Age', 'Fare', and 'Embarked' values are imputed. Categorical features are one-hot encoded.*
+        *This model uses a Multinomial Naive Bayes classifier. It is trained on the 'text' descriptions from the
+        [julien-c/titanic-survival](https://huggingface.co/datasets/julien-c/titanic-survival) dataset
+        to predict survival ('0' for died, '1' for survived). Text is preprocessed using CountVectorizer.*
         """
     )
 
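A quick way to sanity-check the new data dependency before relying on the Space build: a minimal sketch, assuming (as the comments in this commit do) that the julien-c/titanic-survival train split exposes 'text' and 'label' columns.

from datasets import load_dataset
import pandas as pd

# Pre-flight check: confirm the dataset loads and carries the columns app.py expects.
ds = load_dataset("julien-c/titanic-survival", split="train")
df = pd.DataFrame(ds)
print(df.columns.tolist())         # expected to include 'text' and 'label'
print(df['label'].value_counts())  # both classes (0 = died, 1 = survived) should appear

If loading fails, the except branch falls back to the five-row DataFrame defined in this commit, which keeps the app running but means the reported test-set accuracy is computed on a single held-out example.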