File size: 15,288 Bytes
f576373
31a6f8f
f576373
 
 
31a6f8f
f576373
31a6f8f
 
 
 
5fc4afc
 
 
f576373
 
 
d8a656a
 
 
 
 
 
 
 
 
 
 
 
f576373
 
 
 
 
 
 
31a6f8f
d8a656a
 
 
 
 
 
f576373
d8a656a
 
 
 
 
f576373
d8a656a
 
 
 
 
 
 
 
 
f576373
 
31a6f8f
f576373
5fc4afc
f576373
 
d8a656a
 
f576373
d8a656a
 
f576373
d8a656a
f576373
 
26bb643
f576373
 
 
 
 
 
 
63f8741
f576373
 
63f8741
f576373
 
 
 
26bb643
 
63f8741
 
 
 
 
 
 
 
 
 
 
 
 
f576373
26bb643
f576373
 
 
 
 
 
63f8741
f576373
 
63f8741
f576373
 
 
 
26bb643
 
63f8741
 
 
 
 
 
 
 
 
 
 
 
f576373
 
 
 
 
 
 
31a6f8f
5fc4afc
f576373
 
26bb643
f576373
5fc4afc
f576373
 
 
 
 
 
 
31a6f8f
 
 
f576373
31a6f8f
f576373
 
 
 
31a6f8f
 
 
d8a656a
 
 
 
 
f576373
d8a656a
 
 
 
 
f576373
d8a656a
f576373
 
31a6f8f
5fc4afc
 
31a6f8f
 
 
 
 
 
f576373
31a6f8f
 
 
f576373
26bb643
f576373
 
 
 
 
26bb643
f576373
 
26bb643
f576373
 
 
26bb643
 
f576373
 
 
 
 
 
 
 
5fc4afc
f576373
 
 
 
26bb643
f576373
5fc4afc
31a6f8f
f576373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a6f8f
 
 
d8a656a
 
 
 
 
31a6f8f
d8a656a
 
 
 
 
31a6f8f
d8a656a
31a6f8f
 
 
 
5fc4afc
31a6f8f
 
 
 
 
 
 
 
 
 
 
26bb643
31a6f8f
 
 
 
 
 
26bb643
 
31a6f8f
26bb643
31a6f8f
 
 
26bb643
 
31a6f8f
 
 
 
 
 
 
5fc4afc
31a6f8f
 
 
 
26bb643
5fc4afc
 
31a6f8f
 
 
 
 
 
 
 
 
 
 
d8a656a
31a6f8f
 
 
 
 
 
 
 
 
d8a656a
 
 
 
 
31a6f8f
d8a656a
 
 
 
 
31a6f8f
d8a656a
31a6f8f
 
 
 
5fc4afc
31a6f8f
 
 
26bb643
31a6f8f
 
 
 
 
 
 
26bb643
31a6f8f
 
26bb643
31a6f8f
 
 
 
 
26bb643
 
31a6f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
5fc4afc
 
31a6f8f
 
 
 
 
 
 
5fc4afc
26bb643
31a6f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
from typing import Callable, TypedDict
from matplotlib.figure import figaspect
import pandas as pd
from plotly.graph_objects import Figure
import plotly.graph_objects as go
import plotly.express as px

from climateqa.engine.talk_to_data.sql_query import (
    indicator_for_given_year_query,
    indicator_per_year_at_location_query,
)
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT




class Plot(TypedDict):
    """Represents a plot configuration in the DRIAS system.
    
    This class defines the structure for configuring different types of plots
    that can be generated from climate data. Each module-level plot
    configuration in this file (e.g. ``indicator_evolution_at_location``) is a
    ``Plot`` dict pairing a figure factory with an SQL query builder.
    
    Attributes:
        name (str): The human-readable name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): Names of the parameters required for the plot
            (e.g. "indicator_column", "location", "model", "year")
        plot_function (Callable[..., Callable[..., Figure]]): Factory that
            takes a ``params`` dict and returns a function turning a DataFrame
            into a plotly Figure
        sql_query (Callable[..., str]): Function that builds the SQL query for
            the plot (presumably its result is the DataFrame handed to the
            figure function — confirm against the caller)
    """
    # Display name of the plot type.
    name: str
    # Free-text description of what the plot shows.
    description: str
    # Parameter names the plot needs.
    params: list[str]
    # params dict -> (DataFrame -> Figure)
    plot_function: Callable[..., Callable[..., Figure]]
    # Builds the SQL query string for this plot.
    sql_query: Callable[..., str]


def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Generates a function to plot indicator evolution over time at a location.
    
    This function creates a line plot showing how a climate indicator changes
    over time at a specific location. It handles temperature, precipitation,
    and other climate indicators. A dashed 10-year rolling average is drawn
    on top of the yearly values whenever at least one full 10-year window
    is available.

    When the data contains a number of distinct models other than one, the
    indicator is averaged across models for each year and the plot is
    labelled "Model Average". Otherwise the single model's values are plotted
    and the model name is shown in the title.
    
    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot (used in the title)
            - model (str): The climate model to use (not read here; the model
              label is derived from the ``model`` column of the data)
            
    Returns:
        Callable[..., Figure]: A function that takes a DataFrame with ``year``,
        ``model`` and ``<indicator_column>`` columns and returns a plotly Figure
        
    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")
    rolling_window = 10

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generates the actual plot from the data.
        
        Args:
            df (pd.DataFrame): DataFrame containing the data to plot
            
        Returns:
            Figure: A plotly Figure object showing the indicator evolution
        """
        if df["model"].nunique() != 1:
            # Several models (or none): average across models for each year.
            # groupby sorts the result by year.
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            # Single model: sort chronologically so the line and the rolling
            # window do not depend on the row order returned by the query.
            df_plot = df.sort_values("year", kind="stable")
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        # Compute the 10-year rolling average. The first rolling_window - 1
        # entries are NaN because min_periods equals the window size.
        sliding_averages = (
            df_plot[indicator]
            .rolling(window=rolling_window, min_periods=rolling_window)
            .mean()
            .astype(float)
            .tolist()
        )

        fig = go.Figure()

        # Only add rolling average if we have enough data points
        if any(pd.notna(value) for value in sliding_averages):
            # Sliding average dashed line
            fig.add_scatter(
                x=years,
                y=sliding_averages,
                mode="lines",
                name="10 years rolling average",
                line=dict(dash="dash"),
                marker=dict(color="#d62728"),
                hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )

        # Indicator per year plot
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
        )
        return fig

    return plot_data


# Line plot of one indicator over the years at a single location.
indicator_evolution_at_location: Plot = Plot(
    name="Indicator evolution at location",
    description="Plot an evolution of the indicator at a certain location",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_evolution_at_location,
    sql_query=indicator_per_year_at_location_query,
)


def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot the number of days per year for an indicator.
    
    This function creates a bar chart showing the frequency of certain climate
    events (like days above a temperature threshold) per year at a specific location.

    When the data contains a number of distinct models other than one, the
    indicator is averaged across models for each year ("Model Average").
    Otherwise the single model's values are plotted.
    
    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location to plot (used in the title)
            - model (str): The climate model to use (not read here)
            
    Returns:
        Callable[..., Figure]: A function that takes a DataFrame with ``year``,
        ``model`` and ``<indicator_column>`` columns and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure thanks to the dataframe

        Args:
            df (pd.DataFrame): pandas dataframe with the required data

        Returns:
            Figure: Plotly figure
        """
        if df["model"].nunique() != 1:
            # Several models (or none): average across models for each year
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        fig = go.Figure()

        # Bar plot
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )
        )

        fig.update_layout(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            # Keep the axis anchored at 0 but let plotly pick the top. The
            # previous fixed range [0, max(indicators)] raised on empty data,
            # could become NaN when NaN values were present, and clipped the
            # tallest bar flush against the frame.
            yaxis=dict(rangemode="tozero"),
            bargap=0.5,
            template="plotly_white",
        )

        return fig

    return plot_data


# Bar chart of yearly event counts (frequency indicators) at one location.
indicator_number_of_days_per_year_at_location: Plot = Plot(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)


def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a plotting function showing how an indicator is distributed in a year.

    The returned function draws a percent-normalised histogram of the
    indicator values across grid points for the requested year. When the data
    does not contain exactly one model, the values are first averaged across
    models for each (latitude, longitude) pair.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot (used in the title)
            - model (str): The climate model to use (not read here)

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(part.capitalize() for part in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Render the histogram for ``df``.

        Args:
            df (pd.DataFrame): Rows with ``model``, ``latitude``, ``longitude``
                and the indicator column

        Returns:
            Figure: Plotly figure
        """
        fig = go.Figure()

        single_model = df["model"].nunique() == 1
        if single_model:
            source = df
            model_label = f"Model : {df['model'].unique()[0]}"
        else:
            # Collapse the models into one value per grid point
            source = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
            model_label = "Model Average"

        # Plain Python floats keep pandas types out of the plotly payload
        values = source[indicator].astype(float).tolist()

        histogram = go.Histogram(
            x=values,
            opacity=0.8,
            histnorm="percent",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
        )
        fig.add_trace(histogram)

        fig.update_layout(
            title=f"Distribution of {indicator_label} in {year} ({model_label})",
            xaxis_title=f"{indicator_label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
        )
        return fig

    return plot_data


# Histogram of an indicator's values across locations for one year.
distribution_of_indicator_for_given_year: Plot = Plot(
    name="Distribution of an indicator for a given year",
    description="Plot an histogram of the distribution for a given year of the values of an indicator",
    params=["indicator_column", "model", "year"],
    plot_function=plot_distribution_of_indicator_for_given_year,
    sql_query=indicator_for_given_year_query,
)


def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Generates a function to plot a map of France for an indicator.
    
    This function creates a map of France with one coloured marker per grid
    point, showing the spatial distribution of a climate indicator for a
    specific year. When the data does not contain exactly one model, the
    values are averaged across models for each (latitude, longitude) pair.
    
    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year to plot (used in the title)
            - model (str): The climate model to use (not read here)
            
    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the map figure from the dataframe.

        Args:
            df (pd.DataFrame): Rows with ``model``, ``latitude``, ``longitude``
                and the indicator column

        Returns:
            Figure: Plotly figure with a Scattermapbox trace centred on France
        """
        if df["model"].nunique() != 1:
            # Several models (or none): one averaged value per grid point
            df_plot = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            df_plot = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to list to avoid pandas encoding
        indicators = df_plot[indicator].astype(float).tolist()
        latitudes = df_plot["latitude"].astype(float).tolist()
        longitudes = df_plot["longitude"].astype(float).tolist()

        # NaN-safe colour bounds. Plain min()/max() raise on empty data and
        # are order-dependent with NaN. None leaves the bound unset, so
        # plotly autoscales.
        valid_values = [value for value in indicators if pd.notna(value)]
        color_min = min(valid_values) if valid_values else None
        color_max = max(valid_values) if valid_values else None

        fig = go.Figure()
        fig.add_trace(
            go.Scattermapbox(
                lat=latitudes,
                lon=longitudes,
                mode="markers",
                marker=dict(
                    size=10,
                    color=indicators,  # Color mapped to values
                    colorscale="Turbo",  # Color scale (can be 'Plasma', 'Jet', etc.)
                    cmin=color_min,  # Minimum color range
                    cmax=color_max,  # Maximum color range
                    showscale=True,  # Show colorbar
                    # The colorbar belongs to the marker, not to a layout
                    # coloraxis, so its title must be set here to be displayed.
                    colorbar=dict(title=f"{indicator_label} ({unit})"),
                ),
                text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Add hover text showing the indicator value
                hoverinfo="text"  # Only show the custom text on hover
            )
        )

        fig.update_layout(
            mapbox_style="open-street-map",  # Use OpenStreetMap
            mapbox_zoom=3,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            title=f"{indicator_label} in {year} in France ({model_label}) " # Title
        )
        return fig

    return plot_data


# Map of France with one coloured marker per grid point for a given year.
# Description typo fixed ("an in indicator" -> "an indicator"); this text
# describes the plot to whoever selects among PLOTS.
map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}


# Registry of every plot configuration defined in this module, in order.
PLOTS: list[Plot] = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]