Update app.py
Browse files
app.py
CHANGED
@@ -1,144 +1,125 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
-
import
|
4 |
import folium
|
|
|
5 |
from io import BytesIO
|
6 |
import base64
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
'
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
'18 hours': 'lr_18H_lat_lon.pkl',
|
17 |
-
'21 hours': 'lr_21H_lat_lon.pkl',
|
18 |
-
'24 hours': 'lr_24H_lat_lon.pkl',
|
19 |
-
'27 hours': 'lr_27H_lat_lon.pkl',
|
20 |
-
'30 hours': 'lr_30H_lat_lon.pkl',
|
21 |
-
'33 hours': 'lr_33H_lat_lon.pkl',
|
22 |
-
'36 hours': 'lr_36H_lat_lon.pkl'
|
23 |
-
}
|
24 |
-
}
|
25 |
-
|
26 |
-
# Define scaler paths
|
27 |
-
scaler_paths = {
|
28 |
-
'Path': {
|
29 |
-
'3 hours': 'lr_3H_lat_lon_scaler.pkl',
|
30 |
-
'6 hours': 'lr_6H_lat_lon_scaler.pkl',
|
31 |
-
'9 hours': 'lr_9H_lat_lon_scaler.pkl',
|
32 |
-
'12 hours': 'lr_12H_lat_lon_scaler.pkl',
|
33 |
-
'15 hours': 'lr_15H_lat_lon_scaler.pkl',
|
34 |
-
'18 hours': 'lr_18H_lat_lon_scaler.pkl',
|
35 |
-
'24 hours': 'lr_24H_lat_lon_scaler.pkl',
|
36 |
-
'27 hours': 'lr_27H_lat_lon_scaler.pkl',
|
37 |
-
'30 hours': 'lr_30H_lat_lon_scaler.pkl',
|
38 |
-
'33 hours': 'lr_33H_lat_lon_scaler.pkl',
|
39 |
-
'36 hours': 'lr_36H_lat_lon_scaler.pkl'
|
40 |
-
}
|
41 |
-
}
|
42 |
-
|
43 |
-
def process_input(input_data, scaler):
|
44 |
-
input_data = np.array(input_data).reshape(-1, 7)
|
45 |
-
processed_data = input_data[:2].reshape(1, -1)
|
46 |
-
processed_data = scaler.transform(processed_data)
|
47 |
-
return processed_data
|
48 |
-
|
49 |
-
def load_model_and_predict(prediction_type, time_interval, input_data):
|
50 |
-
try:
|
51 |
-
# Load the model and scaler based on user selection
|
52 |
-
model = joblib.load(model_paths[prediction_type][time_interval])
|
53 |
-
scaler = joblib.load(scaler_paths[prediction_type][time_interval])
|
54 |
-
|
55 |
-
# Process input and predict
|
56 |
-
processed_data = process_input(input_data, scaler)
|
57 |
-
prediction = model.predict(processed_data)
|
58 |
-
|
59 |
-
lat, lon = prediction[0][0], prediction[0][1]
|
60 |
-
|
61 |
-
# Create Folium map for predicted location
|
62 |
-
map_ = folium.Map(location=[lat, lon], zoom_start=6)
|
63 |
-
folium.Marker([lat, lon], popup=f"Predicted Location ({lat:.2f}, {lon:.2f})").add_to(map_)
|
64 |
-
|
65 |
-
# Save map as HTML and convert to base64
|
66 |
-
map_html = BytesIO()
|
67 |
-
map_.save(map_html) # removed 'format' argument
|
68 |
-
map_html.seek(0)
|
69 |
-
map_base64 = base64.b64encode(map_html.getvalue()).decode("utf-8")
|
70 |
-
|
71 |
-
return f"Predicted Path after {time_interval}: Latitude: {lat}, Longitude: {lon}", f'<iframe src="data:text/html;base64,{map_base64}" width="100%" height="400"></iframe>'
|
72 |
-
except Exception as e:
|
73 |
-
return str(e), None
|
74 |
-
|
75 |
-
# Gradio interface components
|
76 |
-
with gr.Blocks() as cyclone_predictor:
|
77 |
-
gr.Markdown("# Cyclone Path Prediction App")
|
78 |
|
79 |
-
#
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
#
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
previous_lat_lon = gr.Textbox(
|
94 |
-
placeholder="Enter previous 3-hour lat/lon (e.g., 15.54,90.64)",
|
95 |
-
label="Previous 3-hour Latitude/Longitude"
|
96 |
-
)
|
97 |
-
previous_speed = gr.Number(label="Previous 3-hour Speed")
|
98 |
-
previous_timestamp = gr.Textbox(
|
99 |
-
placeholder="Enter previous 3-hour timestamp (e.g., 2024,10,23,0)",
|
100 |
-
label="Previous 3-hour Timestamp (year, month, day, hour)"
|
101 |
-
)
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
#
|
118 |
-
|
119 |
-
|
120 |
-
# Parse inputs into required format
|
121 |
-
prev_lat, prev_lon = map(float, previous_lat_lon.split(','))
|
122 |
-
prev_time = list(map(int, previous_timestamp.split(',')))
|
123 |
-
previous_data = [prev_lat, prev_lon, previous_speed] + prev_time
|
124 |
-
|
125 |
-
present_lat, present_lon = map(float, present_lat_lon.split(','))
|
126 |
-
present_time = list(map(int, present_timestamp.split(',')))
|
127 |
-
present_data = [present_lat, present_lon, present_speed] + present_time
|
128 |
-
|
129 |
-
return [previous_data, present_data]
|
130 |
-
except Exception as e:
|
131 |
-
return str(e)
|
132 |
-
|
133 |
-
predict_button = gr.Button("Predict Path")
|
134 |
|
135 |
-
#
|
|
|
136 |
predict_button.click(
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
outputs=[prediction_output, map_output]
|
142 |
)
|
143 |
|
144 |
-
cyclone_predictor.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
import folium
|
5 |
+
import joblib
|
6 |
from io import BytesIO
|
7 |
import base64
|
8 |
|
9 |
+
# Load the saved model
|
10 |
+
def load_model():
|
11 |
+
return joblib.load('cyclone_model.pkl')
|
12 |
+
|
13 |
+
# Predict the next 6 rows based on input
|
14 |
+
def predict_next_6_rows(lat_present, lon_present, dist2land_present, storm_speed_present,
|
15 |
+
year_present, month_present, day_present, hour_present,
|
16 |
+
lat_prev, lon_prev, dist2land_prev, storm_speed_prev):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
# Construct the current and previous data as feature arrays
|
19 |
+
current_data = [lat_present, lon_present, dist2land_present, storm_speed_present, year_present, month_present, day_present, hour_present]
|
20 |
+
previous_data = [lat_prev, lon_prev, dist2land_prev, storm_speed_prev, year_present, month_present, day_present, hour_present - 3]
|
21 |
+
|
22 |
+
# Adjust for negative hours
|
23 |
+
if previous_data[-1] < 0:
|
24 |
+
previous_data[-1] += 24
|
25 |
+
previous_data[6] -= 1 # Adjust day
|
26 |
+
|
27 |
+
# Preprocess the input into the required shape (2 rows, 8 columns)
|
28 |
+
input_data = [previous_data, current_data]
|
29 |
+
input_data = np.array(input_data).reshape(1, 2, 8)
|
30 |
|
31 |
+
# Flatten the input to match the model's input shape
|
32 |
+
input_data_flat = input_data.reshape(1, -1)
|
33 |
+
|
34 |
+
# Load the model and make predictions
|
35 |
+
loaded_model = load_model()
|
36 |
+
predictions = loaded_model.predict(input_data_flat)
|
37 |
+
|
38 |
+
# Reshape the predictions back to (6, 4) format
|
39 |
+
predictions_reshaped = predictions.reshape(6, 4)
|
40 |
+
|
41 |
+
# Create a DataFrame for the predictions
|
42 |
+
columns = ['LAT', 'LON', 'DIST2LAND', 'STORM_SPEED']
|
43 |
+
df_predictions = pd.DataFrame(predictions_reshaped, columns=columns)
|
44 |
+
|
45 |
+
# Add the 'Hour' column, incrementing by 3 hours from the present time
|
46 |
+
df_predictions['Hour'] = [(hour_present + (i + 1) * 3) % 24 for i in range(6)] # Ensure the hour wraps around 24
|
47 |
+
|
48 |
+
return df_predictions
|
49 |
+
|
50 |
+
# Plot predictions on a folium map and return an HTML iframe
|
51 |
+
def plot_predictions_on_map(df_predictions):
|
52 |
+
# Extract LAT and LON from the predictions
|
53 |
+
latitudes = df_predictions['LAT'].tolist()
|
54 |
+
longitudes = df_predictions['LON'].tolist()
|
55 |
+
|
56 |
+
# Create a folium map centered at the first predicted point
|
57 |
+
m = folium.Map(location=[latitudes[0], longitudes[0]], zoom_start=6)
|
58 |
+
|
59 |
+
# Add the predicted points to the map and connect them with a polyline
|
60 |
+
locations = list(zip(latitudes, longitudes))
|
61 |
+
|
62 |
+
# Add the points to the map
|
63 |
+
for lat, lon in locations:
|
64 |
+
folium.Marker([lat, lon]).add_to(m)
|
65 |
+
|
66 |
+
# Add a polyline to connect the points
|
67 |
+
folium.PolyLine(locations, color='blue', weight=2.5, opacity=0.7).add_to(m)
|
68 |
+
|
69 |
+
# Save map as HTML in a BytesIO object
|
70 |
+
map_html = BytesIO()
|
71 |
+
m.save(map_html)
|
72 |
+
map_html.seek(0)
|
73 |
+
map_base64 = base64.b64encode(map_html.getvalue()).decode('utf-8')
|
74 |
|
75 |
+
return f'<iframe src="data:text/html;base64,{map_base64}" width="100%" height="400"></iframe>'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
+
# Gradio interface setup
|
78 |
+
def main_interface(lat_present, lon_present, dist2land_present, storm_speed_present,
|
79 |
+
year_present, month_present, day_present, hour_present,
|
80 |
+
lat_prev, lon_prev, dist2land_prev, storm_speed_prev):
|
81 |
+
|
82 |
+
# Get the DataFrame prediction
|
83 |
+
df_predictions = predict_next_6_rows(lat_present, lon_present, dist2land_present, storm_speed_present,
|
84 |
+
year_present, month_present, day_present, hour_present,
|
85 |
+
lat_prev, lon_prev, dist2land_prev, storm_speed_prev)
|
86 |
+
|
87 |
+
# Generate map
|
88 |
+
map_html = plot_predictions_on_map(df_predictions)
|
89 |
+
|
90 |
+
return df_predictions, map_html
|
91 |
|
92 |
+
# Gradio app
|
93 |
+
with gr.Blocks() as cyclone_predictor:
|
94 |
+
gr.Markdown("# Cyclone Path Prediction")
|
95 |
+
|
96 |
+
# Input fields
|
97 |
+
lat_present = gr.Number(label="Current Latitude")
|
98 |
+
lon_present = gr.Number(label="Current Longitude")
|
99 |
+
dist2land_present = gr.Number(label="Current DIST2LAND")
|
100 |
+
storm_speed_present = gr.Number(label="Current STORM_SPEED")
|
101 |
+
year_present = gr.Number(label="Current Year")
|
102 |
+
month_present = gr.Number(label="Current Month")
|
103 |
+
day_present = gr.Number(label="Current Day")
|
104 |
+
hour_present = gr.Number(label="Current Hour")
|
105 |
+
|
106 |
+
lat_prev = gr.Number(label="Previous Latitude")
|
107 |
+
lon_prev = gr.Number(label="Previous Longitude")
|
108 |
+
dist2land_prev = gr.Number(label="Previous DIST2LAND")
|
109 |
+
storm_speed_prev = gr.Number(label="Previous STORM_SPEED")
|
110 |
|
111 |
+
# Prediction and map output
|
112 |
+
prediction_output = gr.Dataframe(label="Predicted DataFrame")
|
113 |
+
map_output = gr.HTML(label="Predicted Path Map")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
+
# Button to trigger prediction
|
116 |
+
predict_button = gr.Button("Predict")
|
117 |
predict_button.click(
|
118 |
+
main_interface,
|
119 |
+
inputs=[lat_present, lon_present, dist2land_present, storm_speed_present,
|
120 |
+
year_present, month_present, day_present, hour_present,
|
121 |
+
lat_prev, lon_prev, dist2land_prev, storm_speed_prev],
|
122 |
outputs=[prediction_output, map_output]
|
123 |
)
|
124 |
|
125 |
+
cyclone_predictor.launch()
|