Spaces:
Build error
Build error
import base64
import functools
from datetime import datetime, timedelta
from io import BytesIO

import folium
import gradio as gr
import joblib
import numpy as np
import pandas as pd
# Load the saved model
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the trained cyclone-track model from ``cyclone_model.pkl``.

    The result is cached, so the file is read and deserialized only once per
    process instead of on every prediction request.

    NOTE: ``joblib.load`` unpickles the file, which can execute arbitrary
    code. Only load a model file that ships with this app and is trusted.

    Returns:
        The deserialized model object (whatever was dumped with joblib).
    """
    return joblib.load('cyclone_model.pkl')
# Predict the next 6 rows based on input
def predict_next_6_rows(lat_present, lon_present, dist2land_present, storm_speed_present,
                        year_present, month_present, day_present, hour_present,
                        lat_prev, lon_prev, dist2land_prev, storm_speed_prev,
                        model=None):
    """Predict the storm's next six positions, 3 hours apart.

    The model is fed one flattened sample of 16 features: the previous
    observation followed by the present one. Each is laid out as
    ``[LAT, LON, DIST2LAND, STORM_SPEED, year, month, day, hour]``.
    The previous observation is taken to be exactly 3 hours before the
    present one, so its timestamp is derived from the present timestamp.

    Args:
        lat_present, lon_present, dist2land_present, storm_speed_present:
            Present observation.
        year_present, month_present, day_present, hour_present: Present
            timestamp. Must form a valid calendar date with an hour in
            0-23. Float inputs such as ``2024.0`` are accepted and truncated
            to int when deriving the previous timestamp.
        lat_prev, lon_prev, dist2land_prev, storm_speed_prev: Observation
            3 hours before the present one.
        model: Optional object with a ``predict`` method. When omitted,
            ``load_model()`` is used.

    Returns:
        pandas.DataFrame with 6 rows and columns ``LAT``, ``LON``,
        ``DIST2LAND``, ``STORM_SPEED`` and ``Hour``. ``Hour`` is the hour of
        day of each predicted step, computed modulo 24.

    Raises:
        ValueError: If the present year/month/day/hour is not a valid
            timestamp.
    """
    step_hours = 3   # spacing between consecutive observations / predictions
    n_steps = 6      # number of future rows the model emits
    n_targets = 4    # LAT, LON, DIST2LAND, STORM_SPEED per predicted row

    try:
        present_ts = datetime(int(year_present), int(month_present),
                              int(day_present), int(hour_present))
    except ValueError as err:
        raise ValueError(
            f"Invalid present date/time: year={year_present}, month={month_present}, "
            f"day={day_present}, hour={hour_present}"
        ) from err
    # timedelta arithmetic rolls day, month and year back together across
    # boundaries (e.g. 2024-01-01 00h -> 2023-12-31 21h). The previous row
    # therefore never gets an impossible date such as day 0.
    prev_ts = present_ts - timedelta(hours=step_hours)

    previous_row = [lat_prev, lon_prev, dist2land_prev, storm_speed_prev,
                    prev_ts.year, prev_ts.month, prev_ts.day, prev_ts.hour]
    current_row = [lat_present, lon_present, dist2land_present, storm_speed_present,
                   year_present, month_present, day_present, hour_present]
    # Stack (previous, current) as a 2 x 8 grid and flatten it into the single
    # 16-feature sample the model is called with.
    features = np.array([previous_row, current_row], dtype=float).reshape(1, -1)

    if model is None:
        model = load_model()
    predictions = np.asarray(model.predict(features)).reshape(n_steps, n_targets)

    columns = ['LAT', 'LON', 'DIST2LAND', 'STORM_SPEED']
    df_predictions = pd.DataFrame(predictions, columns=columns)
    # Hour of day for each future step, wrapping past midnight.
    df_predictions['Hour'] = [(hour_present + (i + 1) * step_hours) % 24 for i in range(n_steps)]
    return df_predictions
# Plot predictions on a folium map and return an HTML iframe
def plot_predictions_on_map(df_predictions):
    """Render the predicted track as an embeddable HTML ``<iframe>``.

    Each predicted (LAT, LON) point gets a marker. The points are joined in
    row order by a blue polyline, and the map is centred on the first
    predicted point at zoom level 6. The rendered page is base64-encoded into
    a ``data:`` URI so it can be embedded without writing a file.

    Args:
        df_predictions: DataFrame with ``LAT`` and ``LON`` columns.

    Returns:
        str: An ``<iframe>`` element whose ``src`` holds the map page.

    Raises:
        ValueError: If ``df_predictions`` has no rows to plot.
    """
    if df_predictions.empty:
        raise ValueError("df_predictions has no rows to plot")

    latitudes = df_predictions['LAT'].tolist()
    longitudes = df_predictions['LON'].tolist()
    locations = list(zip(latitudes, longitudes))

    # Centre the map on the first predicted point.
    m = folium.Map(location=[latitudes[0], longitudes[0]], zoom_start=6)
    for lat, lon in locations:
        folium.Marker([lat, lon]).add_to(m)
    # Connect the points in prediction order.
    folium.PolyLine(locations, color='blue', weight=2.5, opacity=0.7).add_to(m)

    # Render the page in memory. Map.save(BytesIO()) is avoided because
    # branca's save() closes the file object by default (close_file=True).
    # After that, seek()/getvalue() raise "I/O operation on closed file".
    html = m.get_root().render()
    map_base64 = base64.b64encode(html.encode('utf-8')).decode('utf-8')
    return f'<iframe src="data:text/html;base64,{map_base64}" width="100%" height="400"></iframe>'
# Gradio interface setup
def main_interface(lat_present, lon_present, dist2land_present, storm_speed_present,
                   year_present, month_present, day_present, hour_present,
                   lat_prev, lon_prev, dist2land_prev, storm_speed_prev):
    """Click handler: predict the next track points and draw them on a map.

    All twelve arguments are forwarded unchanged, in the same order, to
    ``predict_next_6_rows``.

    Returns:
        tuple: ``(predictions DataFrame, map <iframe> HTML string)``, matching
        the two outputs wired to the Predict button.
    """
    present_obs = (lat_present, lon_present, dist2land_present, storm_speed_present,
                   year_present, month_present, day_present, hour_present)
    previous_obs = (lat_prev, lon_prev, dist2land_prev, storm_speed_prev)

    table = predict_next_6_rows(*present_obs, *previous_obs)
    rendered_map = plot_predictions_on_map(table)
    return table, rendered_map
# Gradio app: one numeric input per feature, a results table and a map.
with gr.Blocks() as cyclone_predictor:
    gr.Markdown("# Cyclone Path Prediction")

    # Present observation and timestamp
    lat_present = gr.Number(label="Current Latitude")
    lon_present = gr.Number(label="Current Longitude")
    dist2land_present = gr.Number(label="Current DIST2LAND")
    storm_speed_present = gr.Number(label="Current STORM_SPEED")
    year_present = gr.Number(label="Current Year")
    month_present = gr.Number(label="Current Month")
    day_present = gr.Number(label="Current Day")
    hour_present = gr.Number(label="Current Hour")

    # Previous observation (its timestamp is derived from the present one)
    lat_prev = gr.Number(label="Previous Latitude")
    lon_prev = gr.Number(label="Previous Longitude")
    dist2land_prev = gr.Number(label="Previous DIST2LAND")
    storm_speed_prev = gr.Number(label="Previous STORM_SPEED")

    # Prediction and map output
    prediction_output = gr.Dataframe(label="Predicted DataFrame")
    map_output = gr.HTML(label="Predicted Path Map")

    # Order must match main_interface's positional parameters.
    prediction_inputs = [
        lat_present, lon_present, dist2land_present, storm_speed_present,
        year_present, month_present, day_present, hour_present,
        lat_prev, lon_prev, dist2land_prev, storm_speed_prev,
    ]

    predict_button = gr.Button("Predict")
    predict_button.click(
        main_interface,
        inputs=prediction_inputs,
        outputs=[prediction_output, map_output],
    )

# Start the (blocking) server only when run as a script, e.g. `python app.py`,
# so importing this module does not launch it as a side effect.
if __name__ == "__main__":
    cyclone_predictor.launch()