armanddemasson committed
Commit 26bb643 · 1 Parent(s): 3c5863d

feat: model filtering and UI upgrade for TTD

app.py CHANGED
@@ -9,14 +9,14 @@ from climateqa.engine.embeddings import get_embeddings_function
9
  from climateqa.engine.llm import get_llm
10
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
- from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
  from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
- from front.tabs import (create_config_modal, cqa_tab, create_about_tab)
19
- from front.tabs import (MainTabPanel, ConfigPanel)
20
  from front.utils import process_figures
21
  from gradio_modal import Modal
22
 
@@ -25,14 +25,14 @@ from utils import create_user_id
25
  import logging
26
 
27
  logging.basicConfig(level=logging.WARNING)
28
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
29
  logging.getLogger().setLevel(logging.WARNING)
30
 
31
 
32
-
33
  # Load environment variables in local mode
34
  try:
35
  from dotenv import load_dotenv
 
36
  load_dotenv()
37
  except Exception as e:
38
  pass
@@ -63,42 +63,105 @@ share_client = service.get_share_client(file_share_name)
63
  user_id = create_user_id()
64
 
65
 
66
-
67
  # Create vectorstore and retriever
68
  embeddings_function = get_embeddings_function()
69
- vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
70
- vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
71
- vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
72
 
73
- llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
74
  if os.environ["GRADIO_ENV"] == "local":
75
  reranker = get_reranker("nano")
76
- else :
77
  reranker = get_reranker("large")
78
 
79
- agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
80
- agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
81
-
82
- #Vanna object
83
-
84
- vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
85
- db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
86
- vn.connect_to_sqlite(db_vanna_path)
87
 
88
  # def ask_vanna_query(query):
89
  # return ask_vanna(vn, db_vanna_path, query)
90
 
91
- def ask_drias_query(query: str, index_state: int, drias_model: str):
92
- return ask_drias(db_vanna_path, query, index_state, drias_model)
93
 
94
- async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
95
  print("chat cqa - message received")
96
- async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
97
  yield event
98
-
99
- async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
100
  print("chat poc - message received")
101
- async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
102
  yield event
103
 
104
 
@@ -106,14 +169,17 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
106
  # Gradio
107
  # --------------------------------------------------------------------
108
 
 
109
  # Function to update modal visibility
110
  def update_config_modal_visibility(config_open):
111
  print(config_open)
112
  new_config_visibility_status = not config_open
113
  return Modal(visible=new_config_visibility_status), new_config_visibility_status
114
-
115
 
116
- def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
 
 
 
117
  sources_number = sources_textbox.count("<h2>")
118
  figures_number = figures_cards.count("<h2>")
119
  graphs_number = current_graphs.count("<iframe")
@@ -122,9 +188,18 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
122
  figures_notif_label = f"Figures ({figures_number})"
123
  graphs_notif_label = f"Graphs ({graphs_number})"
124
  papers_notif_label = f"Papers ({papers_number})"
125
- recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
126
 
127
- return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
128
 
129
  # def create_drias_tab():
130
  # with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
@@ -141,24 +216,112 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
141
  # vanna_display = gr.Plot()
142
  # vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
143
144
  def create_drias_tab():
145
  with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
 
 
 
 
146
  with gr.Row():
147
- drias_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here", elem_id="direct-question", interactive=True)
148
- model_selection = gr.Dropdown(label="Model", choices=DRIAS_MODELS ,elem_id="drias-model", value="ALL", interactive=True)
149
-
150
- with gr.Accordion(label="SQL Query Used"):
151
- drias_sql_query = gr.Textbox(label="", elem_id="sql-query", interactive=False)
 
 
 
 
 
152
 
153
- with gr.Accordion(label='Data used', open=False):
154
- drias_table = gr.DataFrame([], elem_id="vanna-table")
 
 
155
 
156
- with gr.Accordion(label="Chart"):
 
 
 
157
  drias_display = gr.Plot(elem_id="vanna-plot")
158
-
 
 
 
 
 
 
 
159
  with gr.Row():
160
- prev_button = gr.Button("Previous")
161
- next_button = gr.Button("Next")
162
 
163
  sql_queries_state = gr.State([])
164
  dataframes_state = gr.State([])
@@ -166,96 +329,104 @@ def create_drias_tab():
166
  index_state = gr.State(0)
167
 
168
  drias_direct_question.submit(
169
- ask_drias_query,
170
- inputs=[drias_direct_question, index_state, model_selection],
171
- outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
172
  )
173
 
174
  model_selection.change(
175
- ask_drias_query,
176
- inputs=[drias_direct_question, index_state, model_selection],
177
- outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
178
  )
179
 
180
  def show_previous(index, sql_queries, dataframes, plots):
181
  if index > 0:
182
  index -= 1
183
- return sql_queries[index], dataframes[index], plots[index], index
 
 
 
 
 
184
 
185
  def show_next(index, sql_queries, dataframes, plots):
186
  if index < len(sql_queries) - 1:
187
  index += 1
188
- return sql_queries[index], dataframes[index], plots[index], index
 
 
 
 
 
189
 
190
  prev_button.click(
191
- show_previous,
192
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
193
- outputs=[drias_sql_query, drias_table, drias_display, index_state]
 
 
 
 
194
  )
195
 
196
  next_button.click(
197
- show_next,
198
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
199
- outputs=[drias_sql_query, drias_table, drias_display, index_state]
 
 
 
 
200
  )
201
-
202
- # # UI Layout Components
203
- def cqa_tab(tab_name):
204
- # State variables
205
- current_graphs = gr.State([])
206
- with gr.Tab(tab_name):
207
- with gr.Row(elem_id="chatbot-row"):
208
- # Left column - Chat interface
209
- with gr.Column(scale=2):
210
- chatbot, textbox, config_button = create_chat_interface(tab_name)
211
-
212
- # Right column - Content panels
213
- with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
214
- with gr.Tabs(elem_id="right_panel_tab") as tabs:
215
- # Examples tab
216
- with gr.TabItem("Examples", elem_id="tab-examples", id=0):
217
- examples_hidden = create_examples_tab(tab_name)
218
-
219
- # Sources tab
220
- with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
221
- sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
222
-
223
-
224
- # Recommended content tab
225
- with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
226
- with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
227
- # Figures subtab
228
- with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
229
- sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
230
-
231
- # Papers subtab
232
- with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
233
- papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
234
-
235
- # Graphs subtab
236
- with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
237
- graphs_container = gr.HTML(
238
- "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
239
- elem_id="graphs-container"
240
- )
241
-
242
-
243
- def config_event_handling(main_tabs_components : list[MainTabPanel], config_componenets : ConfigPanel):
244
  config_open = config_componenets.config_open
245
  config_modal = config_componenets.config_modal
246
  close_config_modal = config_componenets.close_config_modal_button
247
-
248
- for button in [close_config_modal] + [main_tab_component.config_button for main_tab_component in main_tabs_components]:
 
 
249
  button.click(
250
  fn=update_config_modal_visibility,
251
  inputs=[config_open],
252
- outputs=[config_modal, config_open]
253
- )
254
-
 
255
  def event_handling(
256
- main_tab_components : MainTabPanel,
257
- config_components : ConfigPanel,
258
- tab_name="ClimateQ&A"
259
  ):
260
  chatbot = main_tab_components.chatbot
261
  textbox = main_tab_components.textbox
@@ -279,7 +450,7 @@ def event_handling(
279
  graphs_container = main_tab_components.graph_container
280
  follow_up_examples = main_tab_components.follow_up_examples
281
  follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
282
-
283
  dropdown_sources = config_components.dropdown_sources
284
  dropdown_reports = config_components.dropdown_reports
285
  dropdown_external_sources = config_components.dropdown_external_sources
@@ -288,91 +459,302 @@ def event_handling(
288
  after = config_components.after
289
  output_query = config_components.output_query
290
  output_language = config_components.output_language
291
-
292
  new_sources_hmtl = gr.State([])
293
  ttd_data = gr.State([])
294
 
295
-
296
  if tab_name == "ClimateQ&A":
297
  print("chat cqa - message sent")
298
 
299
  # Event for textbox
300
- (textbox
301
- .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
302
- .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs, follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
303
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
304
  )
305
  # Event for examples_hidden
306
- (examples_hidden
307
- .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
308
- .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
309
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
310
  )
311
- (follow_up_examples_hidden
312
- .change(start_chat, [follow_up_examples_hidden, chatbot, search_only], [follow_up_examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
313
- .then(chat, [follow_up_examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
314
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}")
315
  )
316
-
317
  elif tab_name == "Beta - POC Adapt'Action":
318
  print("chat poc - message sent")
319
  # Event for textbox
320
- (textbox
321
- .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
322
- .then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
323
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
324
  )
325
  # Event for examples_hidden
326
- (examples_hidden
327
- .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
328
- .then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
329
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
330
  )
331
- (follow_up_examples_hidden
332
- .change(start_chat, [follow_up_examples_hidden, chatbot, search_only], [follow_up_examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
333
- .then(chat, [follow_up_examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
334
- .then(finish_chat, None, [textbox], api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}")
335
  )
336
-
337
-
338
- new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
339
- current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
340
- new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
 
 
 
 
 
 
 
341
 
342
  # Update sources numbers
343
  for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
344
- component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
345
-
 
 
 
 
346
  # Search for papers
347
  for component in [textbox, examples_hidden, papers_direct_search]:
348
- component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
349
-
 
 
 
350
 
351
  # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
352
  # # Drias search
353
  # textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
354
 
 
355
  def main_ui():
356
  # config_open = gr.State(True)
357
- with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme, elem_id="main-component") as demo:
358
- config_components = create_config_modal()
359
-
 
 
 
 
 
360
  with gr.Tabs():
361
- cqa_components = cqa_tab(tab_name = "ClimateQ&A")
362
- local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
363
  create_drias_tab()
364
-
365
  create_about_tab()
366
-
367
- event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
368
- event_handling(local_cqa_components, config_components, tab_name = "Beta - POC Adapt'Action")
369
-
370
- config_event_handling([cqa_components,local_cqa_components] ,config_components)
371
-
 
 
372
  demo.queue()
373
-
374
  return demo
375
 
376
-
377
  demo = main_ui()
378
  demo.launch(ssr_mode=False)
 
9
  from climateqa.engine.llm import get_llm
10
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
+ from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
  from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
+ from front.tabs import create_config_modal, cqa_tab, create_about_tab
19
+ from front.tabs import MainTabPanel, ConfigPanel
20
  from front.utils import process_figures
21
  from gradio_modal import Modal
22
 
 
25
  import logging
26
 
27
  logging.basicConfig(level=logging.WARNING)
28
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppresses INFO and WARNING logs
29
  logging.getLogger().setLevel(logging.WARNING)
30
 
31
 
 
32
  # Load environment variables in local mode
33
  try:
34
  from dotenv import load_dotenv
35
+
36
  load_dotenv()
37
  except Exception as e:
38
  pass
 
63
  user_id = create_user_id()
64
 
65
 
 
66
  # Create vectorstore and retriever
67
  embeddings_function = get_embeddings_function()
68
+ vectorstore = get_pinecone_vectorstore(
69
+ embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
70
+ )
71
+ vectorstore_graphs = get_pinecone_vectorstore(
72
+ embeddings_function,
73
+ index_name=os.getenv("PINECONE_API_INDEX_OWID"),
74
+ text_key="description",
75
+ )
76
+ vectorstore_region = get_pinecone_vectorstore(
77
+ embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
78
+ )
79
 
80
+ llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
81
  if os.environ["GRADIO_ENV"] == "local":
82
  reranker = get_reranker("nano")
83
+ else:
84
  reranker = get_reranker("large")
85
 
86
+ agent = make_graph_agent(
87
+ llm=llm,
88
+ vectorstore_ipcc=vectorstore,
89
+ vectorstore_graphs=vectorstore_graphs,
90
+ vectorstore_region=vectorstore_region,
91
+ reranker=reranker,
92
+ threshold_docs=0.2,
93
+ )
94
+ agent_poc = make_graph_agent_poc(
95
+ llm=llm,
96
+ vectorstore_ipcc=vectorstore,
97
+ vectorstore_graphs=vectorstore_graphs,
98
+ vectorstore_region=vectorstore_region,
99
+ reranker=reranker,
100
+ threshold_docs=0,
101
+ version="v4",
102
+ ) # TODO put back default 0.2
103
+
104
+ # Vanna object
105
+
106
+ # vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
107
+ # db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
108
+ # vn.connect_to_sqlite(db_vanna_path)
109
 
110
  # def ask_vanna_query(query):
111
  # return ask_vanna(vn, db_vanna_path, query)
112
 
 
 
113
 
114
+ def ask_drias_query(query: str, index_state: int):
115
+ return ask_drias(query, index_state)
116
+
117
+
118
+ async def chat(
119
+ query,
120
+ history,
121
+ audience,
122
+ sources,
123
+ reports,
124
+ relevant_content_sources_selection,
125
+ search_only,
126
+ ):
127
  print("chat cqa - message received")
128
+ async for event in chat_stream(
129
+ agent,
130
+ query,
131
+ history,
132
+ audience,
133
+ sources,
134
+ reports,
135
+ relevant_content_sources_selection,
136
+ search_only,
137
+ share_client,
138
+ user_id,
139
+ ):
140
  yield event
141
+
142
+
143
+ async def chat_poc(
144
+ query,
145
+ history,
146
+ audience,
147
+ sources,
148
+ reports,
149
+ relevant_content_sources_selection,
150
+ search_only,
151
+ ):
152
  print("chat poc - message received")
153
+ async for event in chat_stream(
154
+ agent_poc,
155
+ query,
156
+ history,
157
+ audience,
158
+ sources,
159
+ reports,
160
+ relevant_content_sources_selection,
161
+ search_only,
162
+ share_client,
163
+ user_id,
164
+ ):
165
  yield event
166
 
167
 
 
169
  # Gradio
170
  # --------------------------------------------------------------------
171
 
172
+
173
  # Function to update modal visibility
174
  def update_config_modal_visibility(config_open):
175
  print(config_open)
176
  new_config_visibility_status = not config_open
177
  return Modal(visible=new_config_visibility_status), new_config_visibility_status
 
178
 
179
+
180
+ def update_sources_number_display(
181
+ sources_textbox, figures_cards, current_graphs, papers_html
182
+ ):
183
  sources_number = sources_textbox.count("<h2>")
184
  figures_number = figures_cards.count("<h2>")
185
  graphs_number = current_graphs.count("<iframe")
 
188
  figures_notif_label = f"Figures ({figures_number})"
189
  graphs_notif_label = f"Graphs ({graphs_number})"
190
  papers_notif_label = f"Papers ({papers_number})"
191
+ recommended_content_notif_label = (
192
+ f"Recommended content ({figures_number + graphs_number + papers_number})"
193
+ )
194
+
195
+ return (
196
+ gr.update(label=recommended_content_notif_label),
197
+ gr.update(label=sources_notif_label),
198
+ gr.update(label=figures_notif_label),
199
+ gr.update(label=graphs_notif_label),
200
+ gr.update(label=papers_notif_label),
201
+ )
202
 
 
203
 
204
  # def create_drias_tab():
205
  # with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
 
216
  # vanna_display = gr.Plot()
217
  # vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
218
 
219
+
220
+ def show_results(sql_queries_state, dataframes_state, plots_state):
221
+ if not sql_queries_state or not dataframes_state or not plots_state:
222
+ # If all results are empty, show "No result"
223
+ return (
224
+ gr.update(visible=True),
225
+ gr.update(visible=False),
226
+ gr.update(visible=False),
227
+ gr.update(visible=False),
228
+ gr.update(visible=False),
229
+ gr.update(visible=False),
230
+ gr.update(visible=False),
231
+ )
232
+ else:
233
+ # Show the appropriate components with their data
234
+ return (
235
+ gr.update(visible=False),
236
+ gr.update(visible=True),
237
+ gr.update(visible=True),
238
+ gr.update(visible=True),
239
+ gr.update(visible=True),
240
+ gr.update(visible=True),
241
+ gr.update(visible=True),
242
+ )
243
+
244
+
245
+ def filter_by_model(dataframes, figures, index_state, model_selection):
246
+ df = dataframes[index_state]
247
+ if model_selection != "ALL":
248
+ df = df[df["model"] == model_selection]
249
+ figure = figures[index_state](df)
250
+ return df, figure
251
+
252
+
253
+ def update_pagination(index, sql_queries):
254
+ pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
255
+ return pagination
256
+
257
+
258
  def create_drias_tab():
259
+ details_text = """
260
+ Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
261
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
262
+
263
+ ❓ **How to use?**
264
+ You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
265
+ You can specify **location** and/or **year**.
266
+ You can choose from a list of climate models. By default, we take the **average of each model**.
267
+
268
+ For example, you can ask:
269
+ - What will the temperature be like in Paris?
270
+ - What will be the total rainfall in France in 2030?
271
+ - How frequent will extreme events be in Lyon?
272
+
273
+ **Example of indicators in the data**:
274
+ - Mean temperature (annual, winter, summer)
275
+ - Total precipitation (annual, winter, summer)
276
+ - Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
277
+
278
+ ⚠️ **Limitations**:
279
+ - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
280
+ - You can only ask about **locations in France**.
281
+ - If you specify a year, there may be **no data for that year for some models**.
282
+ - You **cannot compare two models**.
283
+
284
+ 🛈 **Information**
285
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
286
+ """
287
  with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
288
+
289
+ with gr.Accordion(label="Details"):
290
+ gr.Markdown(details_text)
291
+
292
  with gr.Row():
293
+ drias_direct_question = gr.Textbox(
294
+ label="Direct Question",
295
+ placeholder="You can write direct question here",
296
+ elem_id="direct-question",
297
+ interactive=True,
298
+ )
299
+
300
+ result_text = gr.Textbox(
301
+ label="", elem_id="no-result-label", interactive=False, visible=True
302
+ )
303
 
304
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
305
+ drias_sql_query = gr.Textbox(
306
+ label="", elem_id="sql-query", interactive=False
307
+ )
308
 
309
+ with gr.Accordion(label="Chart", visible=False) as chart_accordion:
310
+ model_selection = gr.Dropdown(
311
+ label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
312
+ )
313
  drias_display = gr.Plot(elem_id="vanna-plot")
314
+
315
+ with gr.Accordion(
316
+ label="Data used", open=False, visible=False
317
+ ) as table_accordion:
318
+ drias_table = gr.DataFrame([], elem_id="vanna-table")
319
+
320
+ pagination_display = gr.Markdown(value="", visible=False, elem_id="pagination-display")
321
+
322
  with gr.Row():
323
+ prev_button = gr.Button("Previous", visible=False)
324
+ next_button = gr.Button("Next", visible=False)
325
 
326
  sql_queries_state = gr.State([])
327
  dataframes_state = gr.State([])
 
329
  index_state = gr.State(0)
330
 
331
  drias_direct_question.submit(
332
+ ask_drias_query,
333
+ inputs=[drias_direct_question, index_state],
334
+ outputs=[
335
+ drias_sql_query,
336
+ drias_table,
337
+ drias_display,
338
+ sql_queries_state,
339
+ dataframes_state,
340
+ plots_state,
341
+ index_state,
342
+ result_text,
343
+ ],
344
+ ).then(
345
+ show_results,
346
+ inputs=[sql_queries_state, dataframes_state, plots_state],
347
+ outputs=[
348
+ result_text,
349
+ query_accordion,
350
+ table_accordion,
351
+ chart_accordion,
352
+ prev_button,
353
+ next_button,
354
+ pagination_display
355
+ ],
356
+ ).then(
357
+ update_pagination,
358
+ inputs=[index_state, sql_queries_state],
359
+ outputs=[pagination_display],
360
  )
361
 
362
  model_selection.change(
363
+ filter_by_model,
364
+ inputs=[dataframes_state, plots_state, index_state, model_selection],
365
+ outputs=[drias_table, drias_display],
366
  )
367
 
368
  def show_previous(index, sql_queries, dataframes, plots):
369
  if index > 0:
370
  index -= 1
371
+ return (
372
+ sql_queries[index],
373
+ dataframes[index],
374
+ plots[index](dataframes[index]),
375
+ index,
376
+ )
377
 
378
  def show_next(index, sql_queries, dataframes, plots):
379
  if index < len(sql_queries) - 1:
380
  index += 1
381
+ return (
382
+ sql_queries[index],
383
+ dataframes[index],
384
+ plots[index](dataframes[index]),
385
+ index,
386
+ )
387
 
388
  prev_button.click(
389
+ show_previous,
390
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
391
+ outputs=[drias_sql_query, drias_table, drias_display, index_state],
392
+ ).then(
393
+ update_pagination,
394
+ inputs=[index_state, sql_queries_state],
395
+ outputs=[pagination_display],
396
  )
397
 
398
  next_button.click(
399
+ show_next,
400
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
401
+ outputs=[drias_sql_query, drias_table, drias_display, index_state],
402
+ ).then(
403
+ update_pagination,
404
+ inputs=[index_state, sql_queries_state],
405
+ outputs=[pagination_display],
406
  )
407
+
408
+
409
+ def config_event_handling(
410
+ main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
411
+ ):
412
  config_open = config_componenets.config_open
413
  config_modal = config_componenets.config_modal
414
  close_config_modal = config_componenets.close_config_modal_button
415
+
416
+ for button in [close_config_modal] + [
417
+ main_tab_component.config_button for main_tab_component in main_tabs_components
418
+ ]:
419
  button.click(
420
  fn=update_config_modal_visibility,
421
  inputs=[config_open],
422
+ outputs=[config_modal, config_open],
423
+ )
424
+
425
+
426
  def event_handling(
427
+ main_tab_components: MainTabPanel,
428
+ config_components: ConfigPanel,
429
+ tab_name="ClimateQ&A",
430
  ):
431
  chatbot = main_tab_components.chatbot
432
  textbox = main_tab_components.textbox
 
450
  graphs_container = main_tab_components.graph_container
451
  follow_up_examples = main_tab_components.follow_up_examples
452
  follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
453
+
454
  dropdown_sources = config_components.dropdown_sources
455
  dropdown_reports = config_components.dropdown_reports
456
  dropdown_external_sources = config_components.dropdown_external_sources
 
459
  after = config_components.after
460
  output_query = config_components.output_query
461
  output_language = config_components.output_language
462
+
463
  new_sources_hmtl = gr.State([])
464
  ttd_data = gr.State([])
465
 
 
466
  if tab_name == "ClimateQ&A":
467
  print("chat cqa - message sent")
468
 
469
  # Event for textbox
470
+ (
471
+ textbox.submit(
472
+ start_chat,
473
+ [textbox, chatbot, search_only],
474
+ [textbox, tabs, chatbot, sources_raw],
475
+ queue=False,
476
+ api_name=f"start_chat_{textbox.elem_id}",
477
+ )
478
+ .then(
479
+ chat,
480
+ [
481
+ textbox,
482
+ chatbot,
483
+ dropdown_audience,
484
+ dropdown_sources,
485
+ dropdown_reports,
486
+ dropdown_external_sources,
487
+ search_only,
488
+ ],
489
+ [
490
+ chatbot,
491
+ new_sources_hmtl,
492
+ output_query,
493
+ output_language,
494
+ new_figures,
495
+ current_graphs,
496
+ follow_up_examples.dataset,
497
+ ],
498
+ concurrency_limit=8,
499
+ api_name=f"chat_{textbox.elem_id}",
500
+ )
501
+ .then(
502
+ finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
503
+ )
504
  )
505
  # Event for examples_hidden
506
+ (
507
+ examples_hidden.change(
508
+ start_chat,
509
+ [examples_hidden, chatbot, search_only],
510
+ [examples_hidden, tabs, chatbot, sources_raw],
511
+ queue=False,
512
+ api_name=f"start_chat_{examples_hidden.elem_id}",
513
+ )
514
+ .then(
515
+ chat,
516
+ [
517
+ examples_hidden,
518
+ chatbot,
519
+ dropdown_audience,
520
+ dropdown_sources,
521
+ dropdown_reports,
522
+ dropdown_external_sources,
523
+ search_only,
524
+ ],
525
+ [
526
+ chatbot,
527
+ new_sources_hmtl,
528
+ output_query,
529
+ output_language,
530
+ new_figures,
531
+ current_graphs,
532
+ follow_up_examples.dataset,
533
+ ],
534
+ concurrency_limit=8,
535
+ api_name=f"chat_{examples_hidden.elem_id}",
536
+ )
537
+ .then(
538
+ finish_chat,
539
+ None,
540
+ [textbox],
541
+ api_name=f"finish_chat_{examples_hidden.elem_id}",
542
+ )
543
  )
544
+ (
545
+ follow_up_examples_hidden.change(
546
+ start_chat,
547
+ [follow_up_examples_hidden, chatbot, search_only],
548
+ [follow_up_examples_hidden, tabs, chatbot, sources_raw],
549
+ queue=False,
550
+ api_name=f"start_chat_{examples_hidden.elem_id}",
551
+ )
552
+ .then(
553
+ chat,
554
+ [
555
+ follow_up_examples_hidden,
556
+ chatbot,
557
+ dropdown_audience,
558
+ dropdown_sources,
559
+ dropdown_reports,
560
+ dropdown_external_sources,
561
+ search_only,
562
+ ],
563
+ [
564
+ chatbot,
565
+ new_sources_hmtl,
566
+ output_query,
567
+ output_language,
568
+ new_figures,
569
+ current_graphs,
570
+ follow_up_examples.dataset,
571
+ ],
572
+ concurrency_limit=8,
573
+ api_name=f"chat_{examples_hidden.elem_id}",
574
+ )
575
+ .then(
576
+ finish_chat,
577
+ None,
578
+ [textbox],
579
+ api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
580
+ )
581
  )
582
+
583
  elif tab_name == "Beta - POC Adapt'Action":
584
  print("chat poc - message sent")
585
  # Event for textbox
586
+ (
587
+ textbox.submit(
588
+ start_chat,
589
+ [textbox, chatbot, search_only],
590
+ [textbox, tabs, chatbot, sources_raw],
591
+ queue=False,
592
+ api_name=f"start_chat_{textbox.elem_id}",
593
+ )
594
+ .then(
595
+ chat_poc,
596
+ [
597
+ textbox,
598
+ chatbot,
599
+ dropdown_audience,
600
+ dropdown_sources,
601
+ dropdown_reports,
602
+ dropdown_external_sources,
603
+ search_only,
604
+ ],
605
+ [
606
+ chatbot,
607
+ new_sources_hmtl,
608
+ output_query,
609
+ output_language,
610
+ new_figures,
611
+ current_graphs,
612
+ ],
613
+ concurrency_limit=8,
614
+ api_name=f"chat_{textbox.elem_id}",
615
+ )
616
+ .then(
617
+ finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
618
+ )
619
  )
620
  # Event for examples_hidden
621
+ (
622
+ examples_hidden.change(
623
+ start_chat,
624
+ [examples_hidden, chatbot, search_only],
625
+ [examples_hidden, tabs, chatbot, sources_raw],
626
+ queue=False,
627
+ api_name=f"start_chat_{examples_hidden.elem_id}",
628
+ )
629
+ .then(
630
+ chat_poc,
631
+ [
632
+ examples_hidden,
633
+ chatbot,
634
+ dropdown_audience,
635
+ dropdown_sources,
636
+ dropdown_reports,
637
+ dropdown_external_sources,
638
+ search_only,
639
+ ],
640
+ [
641
+ chatbot,
642
+ new_sources_hmtl,
643
+ output_query,
644
+ output_language,
645
+ new_figures,
646
+ current_graphs,
647
+ ],
648
+ concurrency_limit=8,
649
+ api_name=f"chat_{examples_hidden.elem_id}",
650
+ )
651
+ .then(
652
+ finish_chat,
653
+ None,
654
+ [textbox],
655
+ api_name=f"finish_chat_{examples_hidden.elem_id}",
656
+ )
657
  )
658
+ (
659
+ follow_up_examples_hidden.change(
660
+ start_chat,
661
+ [follow_up_examples_hidden, chatbot, search_only],
662
+ [follow_up_examples_hidden, tabs, chatbot, sources_raw],
663
+ queue=False,
664
+ api_name=f"start_chat_{examples_hidden.elem_id}",
665
+ )
666
+ .then(
667
+ chat,
668
+ [
669
+ follow_up_examples_hidden,
670
+ chatbot,
671
+ dropdown_audience,
672
+ dropdown_sources,
673
+ dropdown_reports,
674
+ dropdown_external_sources,
675
+ search_only,
676
+ ],
677
+ [
678
+ chatbot,
679
+ new_sources_hmtl,
680
+ output_query,
681
+ output_language,
682
+ new_figures,
683
+ current_graphs,
684
+ follow_up_examples.dataset,
685
+ ],
686
+ concurrency_limit=8,
687
+ api_name=f"chat_{examples_hidden.elem_id}",
688
+ )
689
+ .then(
690
+ finish_chat,
691
+ None,
692
+ [textbox],
693
+ api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
694
+ )
695
  )
696
+
697
+ new_sources_hmtl.change(
698
+ lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
699
+ )
700
+ current_graphs.change(
701
+ lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
702
+ )
703
+ new_figures.change(
704
+ process_figures,
705
+ inputs=[sources_raw, new_figures],
706
+ outputs=[sources_raw, figures_cards, gallery_component],
707
+ )
708
 
709
  # Update sources numbers
710
  for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
711
+ component.change(
712
+ update_sources_number_display,
713
+ [sources_textbox, figures_cards, current_graphs, papers_html],
714
+ [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
715
+ )
716
+
717
  # Search for papers
718
  for component in [textbox, examples_hidden, papers_direct_search]:
719
+ component.submit(
720
+ find_papers,
721
+ [component, after, dropdown_external_sources],
722
+ [papers_html, citations_network, papers_summary],
723
+ )
724
 
725
  # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
726
  # # Drias search
727
  # textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
728
 
729
+
730
  def main_ui():
731
  # config_open = gr.State(True)
732
+ with gr.Blocks(
733
+ title="Climate Q&A",
734
+ css_paths=os.getcwd() + "/style.css",
735
+ theme=theme,
736
+ elem_id="main-component",
737
+ ) as demo:
738
+ config_components = create_config_modal()
739
+
740
  with gr.Tabs():
741
+ cqa_components = cqa_tab(tab_name="ClimateQ&A")
742
+ local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
743
  create_drias_tab()
744
+
745
  create_about_tab()
746
+
747
+ event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
748
+ event_handling(
749
+ local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action"
750
+ )
751
+
752
+ config_event_handling([cqa_components, local_cqa_components], config_components)
753
+
754
  demo.queue()
755
+
756
  return demo
757
 
758
+
759
  demo = main_ui()
760
  demo.launch(ssr_mode=False)
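For context, a minimal, self-contained sketch of the model-filtering behaviour this commit adds to the Talk to DRIAS tab: the dropdown filters the already-fetched dataframe on its `model` column and rebuilds the plot by calling the stored figure factory on the filtered frame (this mirrors `filter_by_model` above; the sample dataframe, indicator column and model names are invented for illustration):

```python
# Sketch only, not part of the commit. Assumes the figure callables stored in
# plots_state accept a dataframe and return a Plotly figure, as in the diff.
import pandas as pd
import plotly.graph_objects as go


def make_figure(df: pd.DataFrame) -> go.Figure:
    # Stand-in for a stored figure factory: plot the indicator per year.
    fig = go.Figure()
    fig.add_scatter(x=df["year"].tolist(), y=df["mean_annual_temperature"].tolist(), mode="lines+markers")
    return fig


def filter_by_model(dataframes, figures, index_state, model_selection):
    # Same logic as in app.py: "ALL" keeps every model, otherwise keep one,
    # then rebuild the currently displayed figure from the filtered data.
    df = dataframes[index_state]
    if model_selection != "ALL":
        df = df[df["model"] == model_selection]
    figure = figures[index_state](df)
    return df, figure


dataframes = [pd.DataFrame({
    "year": [2030, 2030, 2050, 2050],
    "model": ["model_a", "model_b", "model_a", "model_b"],  # placeholder names
    "mean_annual_temperature": [12.1, 12.4, 13.0, 13.6],
})]
figures = [make_figure]

filtered_df, fig = filter_by_model(dataframes, figures, 0, "model_a")
print(filtered_df)
```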
climateqa/engine/talk_to_data/config.py ADDED
@@ -0,0 +1,33 @@
+ DRIAS_TABLES = [
+     "total_winter_precipitation",
+     "total_summer_precipiation",
+     "total_annual_precipitation",
+     "total_remarkable_daily_precipitation",
+     "frequency_of_remarkable_daily_precipitation",
+     "extreme_precipitation_intensity",
+     "mean_winter_temperature",
+     "mean_summer_temperature",
+     "mean_annual_temperature",
+     "number_of_tropical_nights",
+     "maximum_summer_temperature",
+     "number_of_days_with_tx_above_30",
+     "number_of_days_with_tx_above_35",
+     "number_of_days_with_a_dry_ground",
+ ]
+
+ INDICATOR_COLUMNS_PER_TABLE = {
+     "total_winter_precipitation": "total_winter_precipitation",
+     "total_summer_precipiation": "total_summer_precipitation",
+     "total_annual_precipitation": "total_annual_precipitation",
+     "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
+     "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
+     "extreme_precipitation_intensity": "extreme_precipitation_intensity",
+     "mean_winter_temperature": "mean_winter_temperature",
+     "mean_summer_temperature": "mean_summer_temperature",
+     "mean_annual_temperature": "mean_annual_temperature",
+     "number_of_tropical_nights": "number_tropical_nights",
+     "maximum_summer_temperature": "maximum_summer_temperature",
+     "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
+     "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
+     "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
+ }
climateqa/engine/talk_to_data/main.py CHANGED
@@ -13,8 +13,8 @@ def ask_llm_column_names(sql_query, llm):
      columns_list = ast.literal_eval(columns.strip("```python\n").strip())
      return columns_list

- def ask_drias(db_drias_path:str, query:str, index_state: int = 0, drias_model: str = "ALL"):
-     final_state = drias_workflow(db_drias_path, query, drias_model)
+ def ask_drias(query:str, index_state: int = 0):
+     final_state = drias_workflow(query)
      sql_queries = []
      result_dataframes = []
      figures = []
@@ -28,10 +28,15 @@ def ask_drias(db_drias_path:str, query:str, index_state: int = 0, drias_model: str = "ALL"):
      if 'dataframe' in table_state and table_state['dataframe'] is not None:
          result_dataframes.append(table_state['dataframe'])
      if 'figure' in table_state and table_state['figure'] is not None:
-         figures.append(table_state['figure'](table_state['dataframe']))
+         figures.append(table_state['figure'])

-
-     return sql_queries[index_state], result_dataframes[index_state], figures[index_state], sql_queries, result_dataframes, figures, index_state
+     if "error" in final_state and final_state["error"] != "":
+         return None, None, None, [], [], [], 0, final_state["error"]
+
+     sql_query = sql_queries[index_state]
+     dataframe = result_dataframes[index_state]
+     figure = figures[index_state](dataframe)
+     return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, ""

  DRIAS_MODELS = [
      'ALL',
climateqa/engine/talk_to_data/plot.py CHANGED
@@ -29,7 +29,6 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
29
  Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
30
  """
31
  indicator = params["indicator_column"]
32
- model = params["model"]
33
  location = params["location"]
34
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
35
 
@@ -43,7 +42,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
43
  Figure: Plotly figure
44
  """
45
  fig = go.Figure()
46
- if model == "ALL":
47
  df_avg = df.groupby("year", as_index=False)[indicator].mean()
48
 
49
  # Transform to list to avoid pandas encoding
@@ -58,8 +57,10 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
58
  .astype(float)
59
  .tolist()
60
  )
 
 
61
  else:
62
- df_model = df[df["model"] == model]
63
 
64
  # Transform to list to avoid pandas encoding
65
  indicators = df_model[indicator].astype(float).tolist()
@@ -73,6 +74,8 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
73
  .astype(float)
74
  .tolist()
75
  )
 
 
76
 
77
  # Indicator per year plot
78
  fig.add_scatter(
@@ -93,7 +96,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
93
  marker=dict(color="#d62728"),
94
  )
95
  fig.update_layout(
96
- title=f"Plot of {indicator_label} in {location} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}",
97
  xaxis_title="Year",
98
  yaxis_title=indicator_label,
99
  template="plotly_white",
@@ -125,7 +128,6 @@ def plot_indicator_number_of_days_per_year_at_location(
125
  """
126
 
127
  indicator = params["indicator_column"]
128
- model = params["model"]
129
  location = params["location"]
130
 
131
  def plot_data(df: pd.DataFrame) -> Figure:
@@ -138,19 +140,21 @@ def plot_indicator_number_of_days_per_year_at_location(
138
  Figure: Plotly figure
139
  """
140
  fig = go.Figure()
141
- if model == "ALL":
142
  df_avg = df.groupby("year", as_index=False)[indicator].mean()
143
 
144
  # Transform to list to avoid pandas encoding
145
  indicators = df_avg[indicator].astype(float).tolist()
146
  years = df_avg["year"].astype(int).tolist()
 
147
 
148
  else:
149
- df_model = df[df["model"] == model]
150
-
151
  # Transform to list to avoid pandas encoding
152
  indicators = df_model[indicator].astype(float).tolist()
153
  years = df_model["year"].astype(int).tolist()
 
 
154
 
155
  # Bar plot
156
  fig.add_trace(
@@ -165,7 +169,7 @@ def plot_indicator_number_of_days_per_year_at_location(
165
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
166
 
167
  fig.update_layout(
168
- title=f"{indicator_label} in {location} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}",
169
  xaxis_title="Year",
170
  yaxis_title=indicator,
171
  yaxis=dict(range=[0, max(indicators)]),
@@ -199,7 +203,6 @@ def plot_distribution_of_indicator_for_given_year(
199
  Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
200
  """
201
  indicator = params["indicator_column"]
202
- model = params["model"]
203
  year = params["year"]
204
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
205
 
@@ -213,18 +216,22 @@ def plot_distribution_of_indicator_for_given_year(
213
  Figure: Plotly figure
214
  """
215
  fig = go.Figure()
216
- if params["model"] == "ALL":
217
  df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
218
  indicator
219
  ].mean()
220
 
221
  # Transform to list to avoid pandas encoding
222
  indicators = df_avg[indicator].astype(float).tolist()
 
 
223
  else:
224
- df_model = df[df["model"] == model]
225
 
226
  # Transform to list to avoid pandas encoding
227
  indicators = df_model[indicator].astype(float).tolist()
 
 
228
 
229
  fig.add_trace(
230
  go.Histogram(
@@ -236,7 +243,7 @@ def plot_distribution_of_indicator_for_given_year(
236
  )
237
 
238
  fig.update_layout(
239
- title=f"Distribution of {indicator_label} in {year} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}",
240
  xaxis_title=indicator_label,
241
  yaxis_title="Frequency",
242
  plot_bgcolor="rgba(0, 0, 0, 0)",
@@ -270,13 +277,12 @@ def plot_map_of_france_of_indicator_for_given_year(
270
  """
271
 
272
  indicator = params["indicator_column"]
273
- model = params["model"]
274
  year = params["year"]
275
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
276
 
277
  def plot_data(df: pd.DataFrame) -> Figure:
278
  fig = go.Figure()
279
- if model == "ALL":
280
  df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
281
  indicator
282
  ].mean()
@@ -284,14 +290,17 @@ def plot_map_of_france_of_indicator_for_given_year(
284
  indicators = df_avg[indicator].astype(float).tolist()
285
  latitudes = df_avg["latitude"].astype(float).tolist()
286
  longitudes = df_avg["longitude"].astype(float).tolist()
 
287
 
288
  else:
289
- df_model = df[df["model"] == model]
290
 
291
  # Transform to list to avoid pandas encoding
292
  indicators = df_model[indicator].astype(float).tolist()
293
  latitudes = df_model["latitude"].astype(float).tolist()
294
  longitudes = df_model["longitude"].astype(float).tolist()
 
 
295
 
296
  fig.add_trace(
297
  go.Scattermapbox(
@@ -314,7 +323,7 @@ def plot_map_of_france_of_indicator_for_given_year(
314
  mapbox_zoom=3,
315
  mapbox_center={"lat": 46.6, "lon": 2.0},
316
  coloraxis_colorbar=dict(title=f"{indicator_label}"), # Add legend
317
- title=f"{indicator_label} in {year} in France {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'} " # Title
318
  )
319
  return fig
320
 
 
29
  Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
30
  """
31
  indicator = params["indicator_column"]
 
32
  location = params["location"]
33
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
34
 
 
42
  Figure: Plotly figure
43
  """
44
  fig = go.Figure()
45
+ if df['model'].nunique() != 1:
46
  df_avg = df.groupby("year", as_index=False)[indicator].mean()
47
 
48
  # Transform to list to avoid pandas encoding
 
57
  .astype(float)
58
  .tolist()
59
  )
60
+ model_label = "Model Average"
61
+
62
  else:
63
+ df_model = df
64
 
65
  # Transform to list to avoid pandas encoding
66
  indicators = df_model[indicator].astype(float).tolist()
 
74
  .astype(float)
75
  .tolist()
76
  )
77
+ model_label = f"Model : {df['model'].unique()[0]}"
78
+
79
 
80
  # Indicator per year plot
81
  fig.add_scatter(
 
96
  marker=dict(color="#d62728"),
97
  )
98
  fig.update_layout(
99
+ title=f"Plot of {indicator_label} in {location} ({model_label})",
100
  xaxis_title="Year",
101
  yaxis_title=indicator_label,
102
  template="plotly_white",
 
128
  """
129
 
130
  indicator = params["indicator_column"]
 
131
  location = params["location"]
132
 
133
  def plot_data(df: pd.DataFrame) -> Figure:
 
140
  Figure: Plotly figure
141
  """
142
  fig = go.Figure()
143
+ if df['model'].nunique() != 1:
144
  df_avg = df.groupby("year", as_index=False)[indicator].mean()
145
 
146
  # Transform to list to avoid pandas encoding
147
  indicators = df_avg[indicator].astype(float).tolist()
148
  years = df_avg["year"].astype(int).tolist()
149
+ model_label = "Model Average"
150
 
151
  else:
152
+ df_model = df
 
153
  # Transform to list to avoid pandas encoding
154
  indicators = df_model[indicator].astype(float).tolist()
155
  years = df_model["year"].astype(int).tolist()
156
+ model_label = f"Model : {df['model'].unique()[0]}"
157
+
158
 
159
  # Bar plot
160
  fig.add_trace(
 
169
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
170
 
171
  fig.update_layout(
172
+ title=f"{indicator_label} in {location} ({model_label})",
173
  xaxis_title="Year",
174
  yaxis_title=indicator,
175
  yaxis=dict(range=[0, max(indicators)]),
 
203
  Callable[..., Figure]: Function which can be call to create the figure with the associated dataframe
204
  """
205
  indicator = params["indicator_column"]
 
206
  year = params["year"]
207
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
208
 
 
216
  Figure: Plotly figure
217
  """
218
  fig = go.Figure()
219
+ if df['model'].nunique() != 1:
220
  df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
221
  indicator
222
  ].mean()
223
 
224
  # Transform to list to avoid pandas encoding
225
  indicators = df_avg[indicator].astype(float).tolist()
226
+ model_label = "Model Average"
227
+
228
  else:
229
+ df_model = df
230
 
231
  # Transform to list to avoid pandas encoding
232
  indicators = df_model[indicator].astype(float).tolist()
233
+ model_label = f"Model : {df['model'].unique()[0]}"
234
+
235
 
236
  fig.add_trace(
237
  go.Histogram(
 
243
  )
244
 
245
  fig.update_layout(
246
+ title=f"Distribution of {indicator_label} in {year} ({model_label})",
247
  xaxis_title=indicator_label,
248
  yaxis_title="Frequency",
249
  plot_bgcolor="rgba(0, 0, 0, 0)",
 
277
  """
278
 
279
  indicator = params["indicator_column"]
 
280
  year = params["year"]
281
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
282
 
283
  def plot_data(df: pd.DataFrame) -> Figure:
284
  fig = go.Figure()
285
+ if df['model'].nunique() != 1:
286
  df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
287
  indicator
288
  ].mean()
 
290
  indicators = df_avg[indicator].astype(float).tolist()
291
  latitudes = df_avg["latitude"].astype(float).tolist()
292
  longitudes = df_avg["longitude"].astype(float).tolist()
293
+ model_label = "Model Average"
294
 
295
  else:
296
+ df_model = df
297
 
298
  # Transform to list to avoid pandas encoding
299
  indicators = df_model[indicator].astype(float).tolist()
300
  latitudes = df_model["latitude"].astype(float).tolist()
301
  longitudes = df_model["longitude"].astype(float).tolist()
302
+ model_label = f"Model : {df['model'].unique()[0]}"
303
+
304
 
305
  fig.add_trace(
306
  go.Scattermapbox(
 
323
  mapbox_zoom=3,
324
  mapbox_center={"lat": 46.6, "lon": 2.0},
325
  coloraxis_colorbar=dict(title=f"{indicator_label}"), # Add legend
326
+ title=f"{indicator_label} in {year} in France ({model_label}) " # Title
327
  )
328
  return fig
329
 
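The plotting changes above all follow the same pattern: instead of receiving a `model` parameter, each figure now derives its title label from the dataframe it is given. A small sketch of that logic, with a made-up dataframe:

```python
# Sketch only: the label logic used in the updated plot titles.
import pandas as pd


def model_label(df: pd.DataFrame) -> str:
    # Several models in the data -> the plot shows the model average;
    # exactly one model -> name it explicitly in the title.
    if df["model"].nunique() != 1:
        return "Model Average"
    return f"Model : {df['model'].unique()[0]}"


df = pd.DataFrame({"model": ["model_a", "model_a"], "year": [2030, 2050]})
print(model_label(df))  # -> "Model : model_a"
```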
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -1,41 +1,23 @@
- import sqlite3
- from typing import Any, TypedDict
+ from typing import TypedDict
+ import duckdb
+ import pandas as pd

-
- class SqlQueryOutput(TypedDict):
-     labels: list[str]
-     data: list[list[Any]]
-
-
- def execute_sql_query(db_path: str, sql_query: str) -> SqlQueryOutput:
+ def execute_sql_query(sql_query: str) -> pd.DataFrame:
      """Execute the SQL Query on the sqlite database

      Args:
-         db_ (str): path to the sqlite database
          sql_query (str): sql query to execute

      Returns:
          SqlQueryOutput: labels of the selected column and fetched data
      """

-     # Connect to sqlite3 database
-     conn = sqlite3.connect(db_path)
-     cursor = conn.cursor()
-
      # Execute the query
-     cursor.execute(sql_query)
-
-     # Fetch labels of selected columns
-     labels = [desc[0] for desc in cursor.description]
+     results = duckdb.sql(sql_query)

-     # Fetch data
-     data = cursor.fetchall()
-     conn.close()
-
-     return {
-         "labels": labels,
-         "data": data,
-     }
+     # return fetched data
+     return results.fetchdf()


  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
@@ -60,15 +42,13 @@ def indicator_per_year_at_location_query(
      indicator_column = params.get("indicator_column")
      latitude = params.get("latitude")
      longitude = params.get("longitude")
-     model = params.get('model')

      if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
          return ""

-     if model == 'ALL':
-         sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
-     else:
-         sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nAnd model = '{model}' \nOrder by Year"
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"

      return sql_query

@@ -91,12 +71,10 @@ def indicator_for_given_year_query(
      """
      indicator_column = params.get("indicator_column")
      year = params.get('year')
-     model = params.get('model')
      if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
          return ""

-     if model == 'ALL':
-         sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
-     else:
-         sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nAnd model = '{model}'"
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
      return sql_query
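With this change the query layer no longer opens a local SQLite file: `execute_sql_query` runs the generated SQL through DuckDB against parquet files hosted on the Hugging Face Hub (`hf://datasets/timeki/drias_db/...`). A minimal sketch of that pattern, assuming the dataset and the `mean_annual_temperature` table are reachable; the coordinates are placeholders:

```python
# Sketch only, not part of the commit. Requires a DuckDB build with httpfs
# support so that hf:// parquet paths can be read directly.
import duckdb

table = "'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'"
sql_query = (
    f"SELECT year, mean_annual_temperature, model "
    f"FROM {table} "
    f"WHERE latitude = 48.85 AND longitude = 2.35 "  # placeholder grid point
    f"ORDER BY year"
)

# Same shape as the new execute_sql_query(): run the query, return a DataFrame.
df = duckdb.sql(sql_query).fetchdf()
print(df.head())
```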
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -1,11 +1,10 @@
  import re
  from typing import Annotated, TypedDict
-
- from sympy import use
  from geopy.geocoders import Nominatim
- import sqlite3
  import ast
  from climateqa.engine.llm import get_llm
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
  from langchain_core.prompts import ChatPromptTemplate

@@ -35,7 +34,7 @@ class ArrayOutput(TypedDict):
      array: Annotated[str, ..., "Syntactically valid python array."]

- def detect_year_with_openai(sentence: str):
      """
      Detects years in a sentence using OpenAI's API via LangChain.
      """
@@ -56,7 +55,7 @@ def detect_year_with_openai(sentence: str):
      if len(years_list) > 0:
          return years_list[0]
      else:
-         return None


  def detectTable(sql_query):
@@ -81,24 +80,26 @@ def coords2loc(coords: tuple):
          return "Unknown Location"


- def nearestNeighbourSQL(db: str, location: tuple, table: str):
-     conn = sqlite3.connect(db)
      long = round(location[1], 3)
      lat = round(location[0], 3)
-     cursor = conn.cursor()
-     cursor.execute(
          f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
-     )
      # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
-     results = cursor.fetchall()
-     return results[0]


- def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list[str]:
      """Detect relevant tables regarding the plot and the user input

      Args:
-         db (str): database path
          user_question (str): initial user input
          plot (Plot): plot object for which we wanna plot
          llm (_type_): LLM
@@ -106,19 +107,21 @@ def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list
      Returns:
          list[str]: list of table names
      """
-     conn = sqlite3.connect(db)
-     cursor = conn.cursor()

      # Get all table names
-     cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
-     table_names_list = cursor.fetchall()

      prompt = (
          f"You are helping to build a plot following this description : {plot['description']}."
          f"Based on the description of the plot, which table are appropriate for that kind of plot."
-         f"The different tables are {table_names_list}."
-         f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
      )
      table_names = ast.literal_eval(
          llm.invoke(prompt).content.strip("```python\n").strip()
      )
@@ -141,17 +144,28 @@ def detect_relevant_plots(user_question: str, llm):
          plots_description += " - Description: " + plot["description"] + "\n"

      prompt = (
-         f"You are helping to answer a question with insightful visualizations. "
-         f"Given a list of plots with their name and description: "
-         f"{plots_description} "
-         f"The user question is: {user_question}. "
-         f"Choose the most relevant plots to answer the question. "
-         f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
-         f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
      )
-
-     response = llm.invoke(prompt).content
-     return eval(response)


  # Next Version
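The deleted detect_relevant_plots above parsed the model's reply with eval(); the rewritten version further down switches to ast.literal_eval after stripping an optional ```python fence, the same pattern detect_relevant_tables already used. A small sketch of that parsing step with a hypothetical model reply (the plot name is illustrative):

import ast

# Hypothetical raw LLM answer, possibly wrapped in a Markdown code fence.
raw_answer = "```python\n['IndicatorEvolutionAtLocation']\n```"

# str.strip("```python\n") removes fence characters from both ends, and
# ast.literal_eval only accepts Python literals, so arbitrary expressions in
# the model output are rejected instead of being executed, as eval() would do.
plot_names = ast.literal_eval(raw_answer.strip("```python\n").strip())
print(plot_names)  # ['IndicatorEvolutionAtLocation']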
 
  import re
  from typing import Annotated, TypedDict
+ import duckdb
  from geopy.geocoders import Nominatim
  import ast
  from climateqa.engine.llm import get_llm
+ from climateqa.engine.talk_to_data.config import DRIAS_TABLES
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
  from langchain_core.prompts import ChatPromptTemplate

      array: Annotated[str, ..., "Syntactically valid python array."]

+ def detect_year_with_openai(sentence: str) -> str:
      """
      Detects years in a sentence using OpenAI's API via LangChain.
      """

      if len(years_list) > 0:
          return years_list[0]
      else:
+         return ""


  def detectTable(sql_query):

          return "Unknown Location"


+ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
      long = round(location[1], 3)
      lat = round(location[0], 3)
+
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     results = duckdb.sql(
          f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
+     ).fetchdf()
+
+     if len(results) == 0:
+         return "", ""
      # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
+     return results['latitude'].iloc[0], results['longitude'].iloc[0]


+ def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
      """Detect relevant tables regarding the plot and the user input

      Args:
          user_question (str): initial user input
          plot (Plot): plot object for which we wanna plot
          llm (_type_): LLM

      Returns:
          list[str]: list of table names
      """

      # Get all table names
+     table_names_list = DRIAS_TABLES

      prompt = (
          f"You are helping to build a plot following this description : {plot['description']}."
+         f"You are given a list of tables and a user question."
          f"Based on the description of the plot, which table are appropriate for that kind of plot."
+         f"Write the 3 most relevant tables to use. Answer only a Python list of table names."
+         f"### List of tables : {table_names_list}"
+         f"### User question : {user_question}"
+         f"### List of table names : "
      )
+
      table_names = ast.literal_eval(
          llm.invoke(prompt).content.strip("```python\n").strip()
      )

          plots_description += " - Description: " + plot["description"] + "\n"

      prompt = (
+         f"You are helping to answer a question with insightful visualizations."
+         f"You are given a user question and a list of plots with their name and description."
+         f"Based on the descriptions of the plots, which plot is appropriate to answer this question."
+         f"Write the most relevant plots to use. Answer only a Python list of plot names."
+         f"### Descriptions of the plots : {plots_description}"
+         f"### User question : {user_question}"
+         f"### Name of the plot : "
      )
+     # prompt = (
+     #     f"You are helping to answer a question with insightful visualizations. "
+     #     f"Given a list of plots with their name and description: "
+     #     f"{plots_description} "
+     #     f"The user question is: {user_question}. "
+     #     f"Choose the most relevant plots to answer the question. "
+     #     f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
+     #     f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
+     # )
+
+     plot_names = ast.literal_eval(
+         llm.invoke(prompt).content.strip("```python\n").strip()
+     )
+     return plot_names


  # Next Version
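nearestNeighbourSQL now resolves the user's location directly against the remote Parquet file: coordinates are rounded to three decimals and any grid point within a 0.3-degree window is accepted, with a pair of empty strings as fallback. A standalone sketch of the same lookup, again assuming the mean_annual_temperature table and hf:// support in DuckDB; the Paris coordinates are illustrative:

import duckdb

lat, long = round(48.8566, 3), round(2.3522, 3)  # user location (Paris), rounded as in the helper
table = "'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'"

# Accept any DRIAS grid point within a 0.3-degree box around the requested location.
results = duckdb.sql(
    f"SELECT latitude, longitude FROM {table} "
    f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
    f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
).fetchdf()

if len(results) == 0:
    neighbour = ("", "")  # same fallback as the rewritten helper
else:
    neighbour = (results['latitude'].iloc[0], results['longitude'].iloc[0])
print(neighbour)

Note that, like the helper itself, this takes the first match in the window rather than the geometrically nearest point.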
climateqa/engine/talk_to_data/workflow.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd

  from plotly.graph_objects import Figure
  from climateqa.engine.llm import get_llm
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
  from climateqa.engine.talk_to_data.utils import (
@@ -37,12 +38,12 @@ class State(TypedDict):
      user_input: str
      plots: list[str]
      plot_states: dict[str, PlotState]

- def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
      """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

      Args:
-         db_drias_path (str): path to the drias database
          user_input (str): initial user input

      Returns:
@@ -60,8 +61,12 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
      state['plots'] = plots

      if not state['plots']:
          return state

      for plot_name in state['plots']:

          plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
@@ -76,21 +81,23 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:

          plot_state['plot_name'] = plot_name

-         relevant_tables = find_relevant_tables_per_plot(state, plot, db_drias_path, llm)

          plot_state['tables'] = relevant_tables

-         for table in plot_state['tables']:
              table_state: TableState = {
                  'table_name': table,
                  'params': {},
                  'status': 'OK'
              }
-             table_state['params'] = {
-                 'model': model
-             }
              for param_name in plot['params']:
-                 param = find_param(state, param_name, table, db_drias_path)
                  if param:
                      table_state['params'].update(param)

@@ -99,17 +106,30 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
              if sql_query == "":
                  table_state['status'] = 'ERROR'
                  continue

              table_state['sql_query'] = sql_query
-             results = execute_sql_query(db_drias_path, sql_query)

-             df = pd.DataFrame(results['data'], columns=results['labels'])
              figure = plot['plot_function'](table_state['params'])
              table_state['dataframe'] = df
              table_state['figure'] = figure
              plot_state['table_states'][table] = table_state

          state['plot_states'][plot_name] = plot_state
      return state


@@ -118,26 +138,25 @@ def find_relevant_plots(state: State, llm) -> list[str]:
      relevant_plots = detect_relevant_plots(state['user_input'], llm)
      return relevant_plots

- def find_relevant_tables_per_plot(state: State, plot: Plot, db_path: str, llm) -> list[str]:
      print(f"---- Find relevant tables for {plot['name']} ----")
-     relevant_tables = detect_relevant_tables(db_path, state['user_input'], plot, llm)
      return relevant_tables


- def find_param(state: State, param_name:str, table: str, db_path: str) -> dict[str, Any] | None:
      """Perform the good method to retrieve the desired parameter

      Args:
          state (State): state of the workflow
          param_name (str): name of the desired parameter
          table (str): name of the table
-         db_path (str): path to the databse

      Returns:
          dict[str, Any] | None:
      """
      if param_name == 'location':
-         location = find_location(state['user_input'], table, db_path)
          return location
      if param_name == 'indicator_column':
          indicator_column = find_indicator_column(table)
@@ -153,13 +172,13 @@ class Location(TypedDict):
      latitude: NotRequired[str]
      longitude: NotRequired[str]

- def find_location(user_input: str, table: str, db_path: str) -> Location:
      print(f"---- Find location in table {table} ----")
      location = detect_location_with_openai(user_input)
      output: Location = {'location' : location}
      if location:
          coords = loc2coords(location)
-         neighbour = nearestNeighbourSQL(db_path, coords, table)
          output.update({
              "latitude": neighbour[0],
              "longitude": neighbour[1],
@@ -182,23 +201,8 @@ def find_indicator_column(table: str) -> str:
      """

      print(f"---- Find indicator column in table {table} ----")
-     indicator_columns_per_table = {
-         "total_winter_precipitation": "total_winter_precipitation",
-         "total_summer_precipiation": "total_summer_precipitation",
-         "total_annual_precipitation": "total_annual_precipitation",
-         "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
-         "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
-         "extreme_precipitation_intensity": "extreme_precipitation_intensity",
-         "mean_winter_temperature": "mean_winter_temperature",
-         "mean_summer_temperature": "mean_summer_temperature",
-         "mean_annual_temperature": "mean_annual_temperature",
-         "number_of_tropical_nights": "number_tropical_nights",
-         "maximum_summer_temperature": "maximum_summer_temperature",
-         "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
-         "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
-         "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
-     }
-     return indicator_columns_per_table[table]


  # def make_write_query_node():
 
  from plotly.graph_objects import Figure
  from climateqa.engine.llm import get_llm
+ from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
  from climateqa.engine.talk_to_data.utils import (

      user_input: str
      plots: list[str]
      plot_states: dict[str, PlotState]
+     error: NotRequired[str]

+ def drias_workflow(user_input: str) -> State:
      """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

      Args:
          user_input (str): initial user input

      Returns:

      state['plots'] = plots

      if not state['plots']:
+         state['error'] = 'There is no plot to answer the question'
          return state

+     have_relevant_table = False
+     have_sql_query = False
+     have_dataframe = False
      for plot_name in state['plots']:

          plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object

          plot_state['plot_name'] = plot_name

+         relevant_tables = find_relevant_tables_per_plot(state, plot, llm)
+         if len(relevant_tables) > 0:
+             have_relevant_table = True

          plot_state['tables'] = relevant_tables

+         for n, table in enumerate(plot_state['tables']):
+             if n > 2:
+                 break
+
              table_state: TableState = {
                  'table_name': table,
                  'params': {},
                  'status': 'OK'
              }
              for param_name in plot['params']:
+                 param = find_param(state, param_name, table)
                  if param:
                      table_state['params'].update(param)

              if sql_query == "":
                  table_state['status'] = 'ERROR'
                  continue
+             else:
+                 have_sql_query = True

              table_state['sql_query'] = sql_query
+             df = execute_sql_query(sql_query)
+
+             if len(df) > 0:
+                 have_dataframe = True

              figure = plot['plot_function'](table_state['params'])
              table_state['dataframe'] = df
              table_state['figure'] = figure
              plot_state['table_states'][table] = table_state

          state['plot_states'][plot_name] = plot_state
+
+     if not have_relevant_table:
+         state['error'] = "There is no relevant table in our database to answer your question"
+     elif not have_sql_query:
+         state['error'] = "There is no relevant SQL query on our database that can help to answer your question"
+     elif not have_dataframe:
+         state['error'] = "There is no data in our tables that can answer your question"
+
      return state


      relevant_plots = detect_relevant_plots(state['user_input'], llm)
      return relevant_plots

+ def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
      print(f"---- Find relevant tables for {plot['name']} ----")
+     relevant_tables = detect_relevant_tables(state['user_input'], plot, llm)
      return relevant_tables


+ def find_param(state: State, param_name: str, table: str) -> dict[str, Any] | None:
      """Perform the good method to retrieve the desired parameter

      Args:
          state (State): state of the workflow
          param_name (str): name of the desired parameter
          table (str): name of the table

      Returns:
          dict[str, Any] | None:
      """
      if param_name == 'location':
+         location = find_location(state['user_input'], table)
          return location
      if param_name == 'indicator_column':
          indicator_column = find_indicator_column(table)

      latitude: NotRequired[str]
      longitude: NotRequired[str]

+ def find_location(user_input: str, table: str) -> Location:
      print(f"---- Find location in table {table} ----")
      location = detect_location_with_openai(user_input)
      output: Location = {'location' : location}
      if location:
          coords = loc2coords(location)
+         neighbour = nearestNeighbourSQL(coords, table)
          output.update({
              "latitude": neighbour[0],
              "longitude": neighbour[1],

      """

      print(f"---- Find indicator column in table {table} ----")
+
+     return INDICATOR_COLUMNS_PER_TABLE[table]


  # def make_write_query_node():
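workflow.py now leans on a new climateqa/engine/talk_to_data/config.py module that is not included in this diff. A plausible sketch of what it has to expose, assuming it simply hoists the indicator mapping deleted from find_indicator_column above and derives the table list from it (keys copied verbatim, including the original total_summer_precipiation spelling):

# Hypothetical reconstruction of climateqa/engine/talk_to_data/config.py (not shown in this commit).
INDICATOR_COLUMNS_PER_TABLE = {
    "total_winter_precipitation": "total_winter_precipitation",
    "total_summer_precipiation": "total_summer_precipitation",
    "total_annual_precipitation": "total_annual_precipitation",
    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
    "mean_winter_temperature": "mean_winter_temperature",
    "mean_summer_temperature": "mean_summer_temperature",
    "mean_annual_temperature": "mean_annual_temperature",
    "number_of_tropical_nights": "number_tropical_nights",
    "maximum_summer_temperature": "maximum_summer_temperature",
    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground",
}

# Table names double as Parquet file names under hf://datasets/timeki/drias_db/.
DRIAS_TABLES = list(INDICATOR_COLUMNS_PER_TABLE.keys())

With the database path gone from every signature, the whole Talk-To-Drias pipeline is driven by a single call. A short, hypothetical invocation (the question is illustrative, the keys come from the State and TableState handling above, and the LLM API keys are assumed to be configured):

from climateqa.engine.talk_to_data.workflow import drias_workflow

state = drias_workflow("Plot the mean annual temperature in Marseille")
if 'error' in state:
    print(state['error'])
else:
    for plot_name, plot_state in state['plot_states'].items():
        for table_name, table_state in plot_state['table_states'].items():
            print(plot_name, table_name)
            print(table_state['sql_query'])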
style.css CHANGED
@@ -644,17 +644,23 @@ a {
      overflow-y:scroll;
  }

+ #sql-query textarea{
+     min-height: 100px !important;
+ }
+
  #sql-query span{
      display: none;
  }
  div#tab-vanna{
      max-height: 100vh;
-     overflow-y:scroll;
+     overflow-y: hidden;
  }
  #vanna-plot{
      max-height:500px
  }

- #drias-model{
-     max-width: 25%;
+ #pagination-display{
+     text-align: center;
+     font-weight: bold;
+     font-size: 16px;
  }