Abrar20's picture
Update app.py
ab317bb verified
raw
history blame
5.04 kB
import gradio as gr
import numpy as np
import pandas as pd
import folium
import joblib
from io import BytesIO
import base64
# Load the saved model
# Path of the serialized estimator; a relative path, so it is resolved against
# the process's current working directory.
_MODEL_PATH = 'cyclone_model.pkl'

# Holds the deserialized model after the first successful load so that later
# predictions do not re-read and re-unpickle the file.
_model_cache = {}


def load_model():
    """Return the trained cyclone model, loading it from disk only once.

    The first call deserializes ``cyclone_model.pkl`` with ``joblib.load`` and
    stores the result; later calls return the same object.

    Returns:
        The object stored in the pickle file (used by callers through its
        ``predict`` method).

    Raises:
        FileNotFoundError: If the model file is missing. Nothing is cached in
            that case, so a later call retries the load.
    """
    # NOTE: joblib.load is pickle-based and can execute arbitrary code; the
    # model file must come from a trusted source.
    if _MODEL_PATH not in _model_cache:
        _model_cache[_MODEL_PATH] = joblib.load(_MODEL_PATH)
    return _model_cache[_MODEL_PATH]
# Predict the next 6 rows based on input
def _previous_time_fields(year, month, day, hour, step_hours=3):
    """Return ``[year, month, day, hour]`` of the observation ``step_hours`` earlier."""
    from datetime import datetime, timedelta

    previous_hour = hour - step_hours
    if previous_hour < 0:
        # The earlier observation falls on the preceding calendar day. Use real
        # date arithmetic so that day 1 rolls back into the previous month
        # (and 1 January into the previous year) instead of becoming day 0.
        try:
            earlier = datetime(int(year), int(month), int(day)) + timedelta(hours=previous_hour)
        except (ValueError, OverflowError):
            # The inputs do not form a valid calendar date; keep the simple
            # "one day back" arithmetic rather than failing the prediction.
            return [year, month, day - 1, previous_hour + 24]
        return [earlier.year, earlier.month, earlier.day, previous_hour % 24]
    return [year, month, day, previous_hour]


def predict_next_6_rows(lat_present, lon_present, dist2land_present, storm_speed_present,
                        year_present, month_present, day_present, hour_present,
                        lat_prev, lon_prev, dist2land_prev, storm_speed_prev):
    """Predict the next six 3-hourly track rows from two consecutive observations.

    The previous observation is assumed to lie 3 hours before the present one;
    its date/time fields are derived from the present timestamp.

    Args:
        lat_present, lon_present, dist2land_present, storm_speed_present:
            Values of the present observation.
        year_present, month_present, day_present, hour_present:
            Timestamp of the present observation.
        lat_prev, lon_prev, dist2land_prev, storm_speed_prev:
            Values of the observation 3 hours earlier.

    Returns:
        pandas.DataFrame: Six rows with columns ``LAT``, ``LON``,
        ``DIST2LAND``, ``STORM_SPEED`` and ``Hour``, where ``Hour`` is the
        hour of day (wrapped at 24) of each predicted step.

    Raises:
        ValueError: If the model output cannot be reshaped to 6 x 4 values.
    """
    # Each observation is an 8-feature row: four measurements followed by
    # year, month, day and hour.
    current_data = [lat_present, lon_present, dist2land_present, storm_speed_present,
                    year_present, month_present, day_present, hour_present]
    previous_data = [lat_prev, lon_prev, dist2land_prev, storm_speed_prev,
                     *_previous_time_fields(year_present, month_present, day_present, hour_present)]

    # The model takes one sample of 16 features: the previous row followed by
    # the current row.
    input_data_flat = np.array([previous_data, current_data]).reshape(1, -1)

    loaded_model = load_model()
    predictions = loaded_model.predict(input_data_flat)

    # The flat output holds 6 future steps x 4 predicted quantities.
    predictions_reshaped = predictions.reshape(6, 4)
    columns = ['LAT', 'LON', 'DIST2LAND', 'STORM_SPEED']
    df_predictions = pd.DataFrame(predictions_reshaped, columns=columns)

    # Label every step with its hour of day, advancing 3 hours per row.
    df_predictions['Hour'] = [(hour_present + step * 3) % 24 for step in range(1, 7)]
    return df_predictions
# Plot predictions on a folium map and return an HTML iframe
def plot_predictions_on_map(df_predictions):
    """Draw the predicted track on a folium map and return it as an iframe.

    Args:
        df_predictions: DataFrame with ``LAT`` and ``LON`` columns, one row
            per predicted position.

    Returns:
        str: An ``<iframe>`` tag whose ``src`` is a base64 ``data:`` URI
        holding the rendered map HTML.
    """
    # Pair up the coordinates once; both the markers and the line use them.
    track = list(zip(df_predictions['LAT'].tolist(), df_predictions['LON'].tolist()))

    # The map is centred on the first predicted position.
    start_lat, start_lon = track[0]
    track_map = folium.Map(location=[start_lat, start_lon], zoom_start=6)

    # One marker per predicted position.
    for point_lat, point_lon in track:
        marker = folium.Marker([point_lat, point_lon])
        marker.add_to(track_map)

    # A single line joining the positions in order.
    path_line = folium.PolyLine(track, color='blue', weight=2.5, opacity=0.7)
    path_line.add_to(track_map)

    # Render the map into memory and embed the HTML as a base64 data URI.
    html_buffer = BytesIO()
    track_map.save(html_buffer)
    map_base64 = base64.b64encode(html_buffer.getvalue()).decode('utf-8')
    return f'<iframe src="data:text/html;base64,{map_base64}" width="100%" height="400"></iframe>'
# Gradio interface setup
def main_interface(lat_present, lon_present, dist2land_present, storm_speed_present,
                   year_present, month_present, day_present, hour_present,
                   lat_prev, lon_prev, dist2land_prev, storm_speed_prev):
    """Gradio callback: run the prediction and build the map for it.

    Takes the twelve input values in the order the UI supplies them and
    forwards them unchanged to ``predict_next_6_rows``.

    Returns:
        tuple: ``(df_predictions, map_html)`` - the predicted rows and the
        iframe HTML produced by ``plot_predictions_on_map``.
    """
    forecast = predict_next_6_rows(
        lat_present=lat_present,
        lon_present=lon_present,
        dist2land_present=dist2land_present,
        storm_speed_present=storm_speed_present,
        year_present=year_present,
        month_present=month_present,
        day_present=day_present,
        hour_present=hour_present,
        lat_prev=lat_prev,
        lon_prev=lon_prev,
        dist2land_prev=dist2land_prev,
        storm_speed_prev=storm_speed_prev,
    )
    return forecast, plot_predictions_on_map(forecast)
# Gradio app
with gr.Blocks() as cyclone_predictor:
    gr.Markdown("# Cyclone Path Prediction")

    # Present observation: four measurements plus its timestamp.
    lat_present = gr.Number(label="Current Latitude")
    lon_present = gr.Number(label="Current Longitude")
    dist2land_present = gr.Number(label="Current DIST2LAND")
    storm_speed_present = gr.Number(label="Current STORM_SPEED")
    year_present = gr.Number(label="Current Year")
    month_present = gr.Number(label="Current Month")
    day_present = gr.Number(label="Current Day")
    hour_present = gr.Number(label="Current Hour")

    # Previous observation: measurements only.
    lat_prev = gr.Number(label="Previous Latitude")
    lon_prev = gr.Number(label="Previous Longitude")
    dist2land_prev = gr.Number(label="Previous DIST2LAND")
    storm_speed_prev = gr.Number(label="Previous STORM_SPEED")

    # Result widgets: the predicted rows and the rendered map.
    prediction_output = gr.Dataframe(label="Predicted DataFrame")
    map_output = gr.HTML(label="Predicted Path Map")

    predict_button = gr.Button("Predict")

    # The list order must match the positional parameters of main_interface.
    input_fields = [
        lat_present, lon_present, dist2land_present, storm_speed_present,
        year_present, month_present, day_present, hour_present,
        lat_prev, lon_prev, dist2land_prev, storm_speed_prev,
    ]
    output_fields = [prediction_output, map_output]
    predict_button.click(fn=main_interface, inputs=input_fields, outputs=output_fields)

# Start the web server for the app.
cyclone_predictor.launch()