Abrar20 commited on
Commit
cfbe71a
·
verified ·
1 Parent(s): eb92633

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -7
app.py CHANGED
@@ -17,6 +17,11 @@ model_paths = {
17
  '30 hours': 'lr_30H_lat_lon.pkl',
18
  '33 hours': 'lr_33H_lat_lon.pkl',
19
  '36 hours': 'lr_36H_lat_lon.pkl'
 
 
 
 
 
20
  }
21
  }
22
 
@@ -29,17 +34,34 @@ scaler_paths = {
29
  '12 hours': 'lr_12H_lat_lon_scaler.pkl',
30
  '15 hours': 'lr_15H_lat_lon_scaler.pkl',
31
  '18 hours': 'lr_18H_lat_lon_scaler.pkl',
 
32
  '24 hours': 'lr_24H_lat_lon_scaler.pkl',
33
  '27 hours': 'lr_27H_lat_lon_scaler.pkl',
34
  '30 hours': 'lr_30H_lat_lon_scaler.pkl',
35
  '33 hours': 'lr_33H_lat_lon_scaler.pkl',
36
  '36 hours': 'lr_36H_lat_lon_scaler.pkl'
 
 
 
 
 
37
  }
38
  }
39
 
40
- def process_input(input_data, scaler):
 
 
 
 
 
 
41
  input_data = np.array(input_data).reshape(-1, 7)
42
- processed_data = input_data[:2].reshape(1, -1)
 
 
 
 
 
43
  processed_data = scaler.transform(processed_data)
44
  return processed_data
45
 
@@ -50,31 +72,44 @@ def load_model_and_predict(prediction_type, time_interval, input_data):
50
  scaler = joblib.load(scaler_paths[prediction_type][time_interval])
51
 
52
  # Process input and predict
53
- processed_data = process_input(input_data, scaler)
54
  prediction = model.predict(processed_data)
55
 
56
  if prediction_type == 'Path':
57
  return f"Predicted Path after {time_interval}: Latitude: {prediction[0][0]}, Longitude: {prediction[0][1]}"
 
 
58
  except Exception as e:
59
  return str(e)
60
 
61
  # Gradio interface components
62
  with gr.Blocks() as cyclone_predictor:
63
- gr.Markdown("# Cyclone Path Prediction App")
64
 
65
  # Dropdown for Prediction Type
66
  prediction_type = gr.Dropdown(
67
- choices=['Path'],
68
  value='Path',
69
  label="Select Prediction Type"
70
  )
71
 
72
  # Dropdown for Time Interval
73
  time_interval = gr.Dropdown(
74
- choices=['3 hours', '6 hours', '9 hours', '12 hours', '15 hours', '18 hours', '21 hours', '24 hours', '27 hours', '30 hours', '33 hours', '36 hours'],
75
  label="Select Time Interval"
76
  )
77
 
 
 
 
 
 
 
 
 
 
 
 
78
  # Input fields for user data
79
  previous_lat_lon = gr.Textbox(
80
  placeholder="Enter previous 3-hour lat/lon (e.g., 15.54,90.64)",
@@ -115,7 +150,7 @@ with gr.Blocks() as cyclone_predictor:
115
  except Exception as e:
116
  return str(e)
117
 
118
- predict_button = gr.Button("Predict Path")
119
 
120
  # Linking function to UI elements
121
  predict_button.click(
 
17
  '30 hours': 'lr_30H_lat_lon.pkl',
18
  '33 hours': 'lr_33H_lat_lon.pkl',
19
  '36 hours': 'lr_36H_lat_lon.pkl'
20
+ },
21
+ 'Speed': {
22
+ '3 hours': 'Igbm_3H_speed.pkl',
23
+ '15 hours': 'Igbm_15H_speed.pkl',
24
+ '27 hours': 'Igbm_27H_speed.pkl'
25
  }
26
  }
27
 
 
34
  '12 hours': 'lr_12H_lat_lon_scaler.pkl',
35
  '15 hours': 'lr_15H_lat_lon_scaler.pkl',
36
  '18 hours': 'lr_18H_lat_lon_scaler.pkl',
37
+ '21 hours': 'lr_21H_lat_lon_scaler.pkl',
38
  '24 hours': 'lr_24H_lat_lon_scaler.pkl',
39
  '27 hours': 'lr_27H_lat_lon_scaler.pkl',
40
  '30 hours': 'lr_30H_lat_lon_scaler.pkl',
41
  '33 hours': 'lr_33H_lat_lon_scaler.pkl',
42
  '36 hours': 'lr_36H_lat_lon_scaler.pkl'
43
+ },
44
+ 'Speed': {
45
+ '3 hours': 'Igbm_speed_scaler_3H.pkl',
46
+ '15 hours': 'Igbm_speed_scaler_15H.pkl',
47
+ '27 hours': 'Igbm_speed_scaler_27H.pkl'
48
  }
49
  }
50
 
51
+ # Define time intervals for each prediction type
52
+ time_intervals = {
53
+ 'Path': ['3 hours', '6 hours', '9 hours', '12 hours', '15 hours', '18 hours', '21 hours', '24 hours', '27 hours', '30 hours', '33 hours', '36 hours'],
54
+ 'Speed': ['3 hours', '15 hours', '27 hours']
55
+ }
56
+
57
+ def process_input(input_data, scaler, prediction_type):
58
  input_data = np.array(input_data).reshape(-1, 7)
59
+ if prediction_type == 'Speed':
60
+ # For speed prediction, reshape accordingly
61
+ input_data = input_data[:2].reshape(1, 2, 7)
62
+ processed_data = input_data.reshape(-1, 14)
63
+ else: # Path
64
+ processed_data = input_data[:2].reshape(1, -1)
65
  processed_data = scaler.transform(processed_data)
66
  return processed_data
67
 
 
72
  scaler = joblib.load(scaler_paths[prediction_type][time_interval])
73
 
74
  # Process input and predict
75
+ processed_data = process_input(input_data, scaler, prediction_type)
76
  prediction = model.predict(processed_data)
77
 
78
  if prediction_type == 'Path':
79
  return f"Predicted Path after {time_interval}: Latitude: {prediction[0][0]}, Longitude: {prediction[0][1]}"
80
+ elif prediction_type == 'Speed':
81
+ return f"Predicted Speed after {time_interval}: {prediction[0]}"
82
  except Exception as e:
83
  return str(e)
84
 
85
  # Gradio interface components
86
  with gr.Blocks() as cyclone_predictor:
87
+ gr.Markdown("# Cyclone Path and Speed Prediction App")
88
 
89
  # Dropdown for Prediction Type
90
  prediction_type = gr.Dropdown(
91
+ choices=['Path', 'Speed'],
92
  value='Path',
93
  label="Select Prediction Type"
94
  )
95
 
96
  # Dropdown for Time Interval
97
  time_interval = gr.Dropdown(
98
+ choices=time_intervals['Path'],
99
  label="Select Time Interval"
100
  )
101
 
102
+ # Function to update time intervals based on prediction type
103
+ def update_time_intervals(prediction_type):
104
+ return gr.Dropdown.update(choices=time_intervals[prediction_type])
105
+
106
+ # Update time intervals when prediction type changes
107
+ prediction_type.change(
108
+ fn=update_time_intervals,
109
+ inputs=prediction_type,
110
+ outputs=time_interval
111
+ )
112
+
113
  # Input fields for user data
114
  previous_lat_lon = gr.Textbox(
115
  placeholder="Enter previous 3-hour lat/lon (e.g., 15.54,90.64)",
 
150
  except Exception as e:
151
  return str(e)
152
 
153
+ predict_button = gr.Button("Predict")
154
 
155
  # Linking function to UI elements
156
  predict_button.click(