File size: 10,408 Bytes
2b0419c
 
1f4519f
22ffcc6
2b0419c
22ffcc6
2b0419c
 
22ffcc6
2b0419c
 
 
 
 
 
 
 
 
 
 
 
22ffcc6
2b0419c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4519f
 
 
2b0419c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4519f
2b0419c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4519f
 
2b0419c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ffcc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b0419c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import os

from typing import Any, Callable, TypedDict, Optional
from numpy import sort
import pandas as pd
import asyncio
from plotly.graph_objects import Figure
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data import sql_query
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
from climateqa.engine.talk_to_data.utils import (
    detect_relevant_plots,
    detect_year_with_openai,
    loc2coords,
    detect_location_with_openai,
    nearestNeighbourSQL,
    detect_relevant_tables,
)


# Directory two levels above the process's current working directory.
# NOTE: derived from os.getcwd() at import time, not from this file's location,
# so the value depends on where the application is launched from.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))

class TableState(TypedDict):
    """State of a single table while it is processed in the DRIAS workflow.

    One instance is built per table by ``process_table`` and records the
    inputs used to query the table together with the outcome of the query.

    Attributes:
        table_name (str): The name of the table in the database
        params (dict[str, Any]): Parameters used for querying the table
            (a copy of the caller's params, plus ``'indicator_column'``)
        sql_query (str, optional): The SQL query used to fetch data; stays
            ``None`` when no query could be built
        dataframe (pd.DataFrame | None, optional): The result of executing
            ``sql_query``; ``None`` until the query has been run
        figure (Callable[..., Figure], optional): Function to generate the
            visualization, as returned by the plot's ``plot_function``
        status (str): The current status of the table processing
            ('OK' or 'ERROR')
    """
    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame | None]
    figure: Optional[Callable[..., Figure]]
    status: str

class PlotState(TypedDict):
    """State of a single plot while it is processed in the DRIAS workflow.

    Groups the tables selected for a plot with the per-table processing
    results.

    Attributes:
        plot_name (str): The name of the plot
        tables (list[str]): Names of the tables judged relevant for the plot
        table_states (dict[str, TableState]): Processing state of each table
            that was actually queried, keyed by table name
    """
    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]

class State(TypedDict):
    """Overall state of one run of the DRIAS workflow.

    Attributes:
        user_input (str): The user's original question
        plots (list[str]): Names of the plots selected for the question
        plot_states (dict[str, PlotState]): State of each processed plot,
            keyed by plot name
        error (str, optional): Human-readable error message; the workflow
            initialises it to an empty string and fills it when no plot,
            table, query or data could be found
    """
    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
    
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Ask the LLM which plots can answer the user's question.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        llm: language model client forwarded to ``detect_relevant_plots``

    Returns:
        list[str]: the result of ``detect_relevant_plots`` (plot names)
    """
    print("---- Find relevant plots ----")
    question = state['user_input']
    return await detect_relevant_plots(question, llm)

async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Ask the LLM which tables are relevant for a given plot.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        plot (Plot): plot for which tables are being selected
        llm: language model client forwarded to ``detect_relevant_tables``

    Returns:
        list[str]: the result of ``detect_relevant_tables`` (table names)
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    question = state['user_input']
    return await detect_relevant_tables(question, plot, llm)

async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Perform the good method to retrieve the desired parameter

    Args:
        state (State): state of the workflow
        param_name (str): name of the desired parameter
        table (str): name of the table

    Returns:
        dict[str, Any] | None: 
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], table)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None

class Location(TypedDict):
    """Location information extracted from the user's question.

    Attributes:
        location (str): The location detected in the user input
        latitude (str, optional): First element of the nearest-neighbour
            lookup result; only filled by ``find_location`` when a location
            was detected
        longitude (str, optional): Second element of the nearest-neighbour
            lookup result; only filled by ``find_location`` when a location
            was detected

    NOTE(review): latitude/longitude are presumably the coordinates of the
    closest point available in the queried table -- confirm against
    ``nearestNeighbourSQL``.
    """
    location: str
    latitude: Optional[str]
    longitude: Optional[str]

async def find_location(user_input: str, table: str) -> Location:
    """Detect the location mentioned by the user and match it to the table.

    Args:
        user_input (str): The user's query text
        table (str): Name of the table searched for the nearest neighbour

    Returns:
        Location: always contains ``'location'``; ``'latitude'`` and
        ``'longitude'`` are added only when a location was detected
    """
    print(f"---- Find location in table {table} ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {'location' : detected}
    if not detected:
        return result

    nearest = nearestNeighbourSQL(loc2coords(detected), table)
    latitude, longitude = nearest[0], nearest[1]
    result["latitude"] = latitude
    result["longitude"] = longitude
    return result

async def find_year(user_input: str) -> str:
    """Extract the year mentioned in the user's question.

    Delegates to ``detect_year_with_openai``; the result is later used as a
    query parameter.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The extracted year, or empty string if no year found
    """
    print(f"---- Find year ---")
    return await detect_year_with_openai(user_input)

def find_indicator_column(table: str) -> str:
    """Look up the indicator column of a table.

    The lookup goes through the ``INDICATOR_COLUMNS_PER_TABLE`` mapping.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    indicator_column = INDICATOR_COLUMNS_PER_TABLE[table]
    return indicator_column


async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build and run the query for one table and prepare its figure.

    The plot's ``sql_query`` builder produces the query, which is executed
    with ``execute_sql_query``; the plot's ``plot_function`` then provides
    the figure generator. An empty query marks the table as ``'ERROR'`` and
    nothing is executed.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Parameters used for querying the table
            (copied, the caller's dict is not modified)
        plot (Plot): The plot object containing SQL query and visualization function

    Returns:
        TableState: The state of the processed table
    """
    # Work on a private copy so concurrent calls do not share parameters.
    query_params = params.copy()
    result: TableState = {
        'table_name': table,
        'params': query_params,
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None
    }

    query_params['indicator_column'] = find_indicator_column(table)
    query = plot['sql_query'](table, query_params)

    if query != "":
        result['sql_query'] = query
        result['dataframe'] = await execute_sql_query(query)
        result['figure'] = plot['plot_function'](query_params)
    else:
        # No query could be built for this table/params combination.
        result['status'] = 'ERROR'

    return result

async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    Steps:
        1. Ask the LLM which plots are relevant to the question.
        2. For each plot known in ``PLOTS``, ask which tables are relevant.
        3. Resolve the plot's parameters (using the first relevant table).
        4. Process up to three relevant tables concurrently.
        5. Set ``state['error']`` when no table, query or data was found
           for any plot.

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results. ``error`` is an empty
        string on success.
    """
    # Only the first few relevant tables of each plot are queried.
    max_tables_per_plot = 3

    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)

    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags accumulate over ALL plots: an error is reported only when
    # no plot at all yielded a table / a query / data.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': relevant_tables,
            'table_states': {}
        }
        state['plot_states'][plot_name] = plot_state

        if not relevant_tables:
            # Nothing to query for this plot; also avoids indexing an empty
            # list when resolving the parameters below.
            continue
        have_relevant_table = True

        params: dict[str, Any] = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        tasks = [
            process_table(table, params, plot)
            for table in relevant_tables[:max_tables_per_plot]
        ]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            dataframe = table_state['dataframe']
            if dataframe is not None and len(dataframe) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state

# def make_write_query_node():

#     def write_query(state):
#         print("---- Write query ----")
#         for table in state["tables"]:
#             sql_query = QUERIES[state[table]['query_type']](
#                 table=table,
#                 indicator_column=state[table]["columns"],
#                 longitude=state[table]["longitude"],
#                 latitude=state[table]["latitude"],
#             )
#             state[table].update({"sql_query": sql_query})

#         return state

#     return write_query

# def make_fetch_data_node(db_path):

#     def fetch_data(state):
#         print("---- Fetch data ----")
#         for table in state["tables"]:
#             results = execute_sql_query(db_path, state[table]['sql_query'])
#             state[table].update(results)

#         return state
        
#     return fetch_data



## V2


# def make_fetch_data_node(db_path: str, llm):
#     def fetch_data(state):
#         print("---- Fetch data ----")
#         db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
#         output = {}
#         sql_query = write_sql_query(state["query"], db, state["tables"], llm)
#         # TO DO : Add query checker
#         print(f"SQL query  : {sql_query}")
#         output["sql_query"] = sql_query
#         output.update(fetch_data_from_sql_query(db_path, sql_query))
#         return output

#     return fetch_data