timeki commited on
Commit
63f8741
·
1 Parent(s): d561bb5

TTD : remove first unrelevant points

Browse files
Files changed (1) hide show
  1. climateqa/engine/talk_to_data/plot.py +29 -13
climateqa/engine/talk_to_data/plot.py CHANGED
@@ -81,15 +81,29 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
81
  years = df_avg["year"].astype(int).tolist()
82
 
83
  # Compute the 10-year rolling average
 
84
  sliding_averages = (
85
  df_avg[indicator]
86
- .rolling(window=10, min_periods=1)
87
  .mean()
88
  .astype(float)
89
  .tolist()
90
  )
91
  model_label = "Model Average"
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  else:
94
  df_model = df
95
 
@@ -98,15 +112,28 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
98
  years = df_model["year"].astype(int).tolist()
99
 
100
  # Compute the 10-year rolling average
 
101
  sliding_averages = (
102
  df_model[indicator]
103
- .rolling(window=10, min_periods=1)
104
  .mean()
105
  .astype(float)
106
  .tolist()
107
  )
108
  model_label = f"Model : {df['model'].unique()[0]}"
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # Indicator per year plot
112
  fig.add_scatter(
@@ -117,17 +144,6 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
117
  marker=dict(color="#1f77b4"),
118
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
119
  )
120
-
121
- # Sliding average dashed line
122
- fig.add_scatter(
123
- x=years,
124
- y=sliding_averages,
125
- mode="lines",
126
- name="10 years rolling average",
127
- line=dict(dash="dash"),
128
- marker=dict(color="#d62728"),
129
- hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
130
- )
131
  fig.update_layout(
132
  title=f"Plot of {indicator_label} in {location} ({model_label})",
133
  xaxis_title="Year",
 
81
  years = df_avg["year"].astype(int).tolist()
82
 
83
  # Compute the 10-year rolling average
84
+ rolling_window = 10
85
  sliding_averages = (
86
  df_avg[indicator]
87
+ .rolling(window=rolling_window, min_periods=rolling_window)
88
  .mean()
89
  .astype(float)
90
  .tolist()
91
  )
92
  model_label = "Model Average"
93
 
94
+ # Only add rolling average if we have enough data points
95
+ if len([x for x in sliding_averages if pd.notna(x)]) > 0:
96
+ # Sliding average dashed line
97
+ fig.add_scatter(
98
+ x=years,
99
+ y=sliding_averages,
100
+ mode="lines",
101
+ name="10 years rolling average",
102
+ line=dict(dash="dash"),
103
+ marker=dict(color="#d62728"),
104
+ hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
105
+ )
106
+
107
  else:
108
  df_model = df
109
 
 
112
  years = df_model["year"].astype(int).tolist()
113
 
114
  # Compute the 10-year rolling average
115
+ rolling_window = 10
116
  sliding_averages = (
117
  df_model[indicator]
118
+ .rolling(window=rolling_window, min_periods=rolling_window)
119
  .mean()
120
  .astype(float)
121
  .tolist()
122
  )
123
  model_label = f"Model : {df['model'].unique()[0]}"
124
 
125
+ # Only add rolling average if we have enough data points
126
+ if len([x for x in sliding_averages if pd.notna(x)]) > 0:
127
+ # Sliding average dashed line
128
+ fig.add_scatter(
129
+ x=years,
130
+ y=sliding_averages,
131
+ mode="lines",
132
+ name="10 years rolling average",
133
+ line=dict(dash="dash"),
134
+ marker=dict(color="#d62728"),
135
+ hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
136
+ )
137
 
138
  # Indicator per year plot
139
  fig.add_scatter(
 
144
  marker=dict(color="#1f77b4"),
145
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
146
  )
 
 
 
 
 
 
 
 
 
 
 
147
  fig.update_layout(
148
  title=f"Plot of {indicator_label} in {location} ({model_label})",
149
  xaxis_title="Year",