from typing import Callable, TypedDict
from matplotlib.figure import figaspect
import pandas as pd
from plotly.graph_objects import Figure
import plotly.graph_objects as go
import plotly.express as px
from climateqa.engine.talk_to_data.sql_query import (
indicator_for_given_year_query,
indicator_per_year_at_location_query,
)
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
class Plot(TypedDict):
    """Configuration entry describing one kind of plot in the DRIAS system.

    Each entry bundles everything needed to turn climate data into a figure:
    a human-readable name and description, the parameters the plot requires,
    a factory building the plotting function, and a function building the SQL
    query whose result feeds that plotting function.

    Attributes:
        name (str): Display name of the plot type.
        description (str): Explanation of what the plot shows.
        params (list[str]): Names of the parameters the plot requires.
        plot_function (Callable[..., Callable[..., Figure]]): Factory that
            returns a function producing a plotly Figure.
        sql_query (Callable[..., str]): Function that builds the SQL query
            for the plot.
    """

    # Human-readable identifier of the plot type
    name: str
    # Free-text explanation of what the plot displays
    description: str
    # Parameter names the plot (and its query) need
    params: list[str]
    # Factory returning the Figure-producing function
    plot_function: Callable[..., Callable[..., Figure]]
    # Builder of the SQL query text
    sql_query: Callable[..., str]
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Generates a function to plot indicator evolution over time at a location.

    This function creates a line plot showing how a climate indicator changes
    over time at a specific location. A dashed 10-year rolling average is
    added when at least one full 10-year window of values is available. When
    the data holds several models, yearly values are averaged across models.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot
            - model (str): The climate model to use (not read here;
              presumably consumed by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure

    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")
    rolling_window = 10

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generates the actual plot from the data.

        Args:
            df (pd.DataFrame): DataFrame with at least the ``year``, ``model``
                and indicator columns

        Returns:
            Figure: A plotly Figure object showing the indicator evolution
        """
        fig = go.Figure()

        if df["model"].nunique() != 1:
            # Several models: average them per year (groupby sorts by year).
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            # Single model: sort so the rolling window runs chronologically
            # regardless of the order in which rows were returned.
            df_plot = df.sort_values("year")
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        # min_periods == window: entries before the first full window are NaN.
        sliding_averages = (
            df_plot[indicator]
            .rolling(window=rolling_window, min_periods=rolling_window)
            .mean()
            .astype(float)
            .tolist()
        )

        # Only add the rolling average if at least one full window exists
        if any(pd.notna(value) for value in sliding_averages):
            # Sliding average dashed line
            fig.add_scatter(
                x=years,
                y=sliding_averages,
                mode="lines",
                name=f"{rolling_window} years rolling average",
                line=dict(dash="dash"),
                marker=dict(color="#d62728"),
                # plotly hovertemplates use HTML <br> for line breaks
                hovertemplate=f"{rolling_window}-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}",
            )

        # Indicator per year plot
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}",
        )

        fig.update_layout(
            title=f"Plot of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
        )
        return fig

    return plot_data
# Registry entry: yearly line chart of an indicator at a single location.
indicator_evolution_at_location: Plot = Plot(
    name="Indicator evolution at location",
    description="Plot an evolution of the indicator at a certain location",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_evolution_at_location,
    sql_query=indicator_per_year_at_location_query,
)
def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot the number of days per year for an indicator.

    This function creates a bar chart showing the frequency of certain climate
    events (like days above a temperature threshold) per year at a specific
    location. When the data holds several models, yearly values are averaged
    across models.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot
            - model (str): The climate model to use (not read here;
              presumably consumed by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure thanks to the dataframe

        Args:
            df (pd.DataFrame): pandas dataframe with at least the ``year``,
                ``model`` and indicator columns

        Returns:
            Figure: Plotly figure
        """
        fig = go.Figure()

        if df["model"].nunique() != 1:
            # Several models: average them per year
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        # Bar plot
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                # plotly hovertemplates use HTML <br> for line breaks
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}",
            )
        )

        # Pin the y-axis to [0, max]. max() of an empty list raises, and a
        # [0, 0] or NaN range is degenerate, so in those cases fall back to an
        # autoranged axis that still starts at zero.
        finite_values = [value for value in indicators if pd.notna(value)]
        y_max = max(finite_values, default=0.0)
        yaxis = dict(range=[0, y_max]) if y_max > 0 else dict(rangemode="tozero")

        fig.update_layout(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            yaxis=yaxis,
            bargap=0.5,
            template="plotly_white",
        )
        return fig

    return plot_data
# Registry entry: bar chart of a frequency indicator per year at one location.
indicator_number_of_days_per_year_at_location: Plot = Plot(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot the distribution of an indicator for a year.

    This function creates a histogram (normalised to percentages) showing the
    distribution of a climate indicator across locations for a specific year.
    When the data holds several models, values are first averaged across
    models for each (latitude, longitude) pair.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot
            - model (str): The climate model to use (not read here;
              presumably consumed by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure thanks to the dataframe

        Args:
            df (pd.DataFrame): pandas dataframe with at least the ``latitude``,
                ``longitude``, ``model`` and indicator columns

        Returns:
            Figure: Plotly figure
        """
        fig = go.Figure()

        if df["model"].nunique() != 1:
            # Several models: average them per (latitude, longitude) pair
            df_plot = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()

        fig.add_trace(
            go.Histogram(
                x=indicators,
                opacity=0.8,
                histnorm="percent",
                marker=dict(color="#1f77b4"),
                # plotly hovertemplates use HTML <br> for line breaks
                hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%",
            )
        )

        fig.update_layout(
            title=f"Distribution of {indicator_label} in {year} ({model_label})",
            xaxis_title=f"{indicator_label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
        )
        return fig

    return plot_data
# Registry entry: histogram of an indicator's values across locations for one year.
# The description is presumably shown to a plot-selection step, so it is kept
# grammatical ("a histogram", not "an histogram").
distribution_of_indicator_for_given_year: Plot = {
    "name": "Distribution of an indicator for a given year",
    "description": "Plot a histogram of the distribution for a given year of the values of an indicator",
    "params": ["indicator_column", "model", "year"],
    "plot_function": plot_distribution_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot a map of France for an indicator.

    This function creates a map of France with one colour-coded marker per
    (latitude, longitude) point, showing the spatial distribution of a climate
    indicator for a specific year. When the data holds several models, values
    are first averaged across models for each point.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot
            - model (str): The climate model to use (not read here;
              presumably consumed by the SQL query — TODO confirm)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the map figure thanks to the dataframe

        Args:
            df (pd.DataFrame): pandas dataframe with at least the ``latitude``,
                ``longitude``, ``model`` and indicator columns

        Returns:
            Figure: Plotly figure with one marker per (latitude, longitude) row
        """
        fig = go.Figure()

        if df["model"].nunique() != 1:
            # Several models: average them per (latitude, longitude) pair
            df_plot = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        latitudes = df_plot["latitude"].astype(float).tolist()
        longitudes = df_plot["longitude"].astype(float).tolist()

        # Colour range from finite values only. min()/max() would raise on
        # empty data; None leaves cmin/cmax unset so plotly autoscales.
        finite_values = [value for value in indicators if pd.notna(value)]
        color_min = min(finite_values) if finite_values else None
        color_max = max(finite_values) if finite_values else None

        fig.add_trace(
            go.Scattermapbox(
                lat=latitudes,
                lon=longitudes,
                mode="markers",
                marker=dict(
                    size=10,
                    color=indicators,  # Color mapped to values
                    colorscale="Turbo",  # Color scale (can be 'Plasma', 'Jet', etc.)
                    cmin=color_min,  # Minimum color range
                    cmax=color_max,  # Maximum color range
                    showscale=True,  # Show colorbar
                    # The marker has its own colour scale (no shared coloraxis),
                    # so its colorbar title must be set here; a layout-level
                    # coloraxis_colorbar would not apply to this trace.
                    colorbar=dict(title=dict(text=f"{indicator_label} ({unit})")),
                ),
                # Hover text showing the indicator value
                text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],
                hoverinfo="text",  # Only show the custom text on hover
            )
        )

        fig.update_layout(
            mapbox_style="open-street-map",  # Use OpenStreetMap
            mapbox_zoom=3,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            title=f"{indicator_label} in {year} in France ({model_label}) ",
        )
        return fig

    return plot_data
# Registry entry: map of France coloured by an indicator's values for one year.
# Description typo fixed ("an in indicator" -> "an indicator").
map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
# Every plot configuration defined in this module, in registration order.
PLOTS: list[Plot] = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]