import os
from typing import Any, Callable, TypedDict, Optional
from numpy import sort
import pandas as pd
import asyncio
from plotly.graph_objects import Figure
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data import sql_query
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
from climateqa.engine.talk_to_data.utils import (
    detect_relevant_plots,
    detect_year_with_openai,
    loc2coords,
    detect_location_with_openai,
    nearestNeighbourSQL,
    detect_relevant_tables,
)

ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))

# Maximum number of relevant tables queried (concurrently) for a single plot.
MAX_TABLES_PER_PLOT = 3


class TableState(TypedDict):
    """Represents the state of a table in the DRIAS workflow.

    This class defines the structure for tracking the state of a table during the
    data processing workflow, including its name, parameters, SQL query, and results.

    Attributes:
        table_name (str): The name of the table in the database
        params (dict[str, Any]): Parameters used for querying the table
        sql_query (str, optional): The SQL query used to fetch data
        dataframe (pd.DataFrame | None, optional): The resulting data
        figure (Callable[..., Figure], optional): Function to generate visualization
        status (str): The current status of the table processing ('OK' or 'ERROR')
    """
    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    status: str


class PlotState(TypedDict):
    """Represents the state of a plot in the DRIAS workflow.

    This class defines the structure for tracking the state of a plot during the
    data processing workflow, including its name and associated tables.

    Attributes:
        plot_name (str): The name of the plot
        tables (list[str]): List of tables used in the plot
        table_states (dict[str, TableState]): States of the tables used in the plot
    """
    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]


class State(TypedDict):
    """Overall state of one Talk To DRIAS run.

    Attributes:
        user_input (str): The user's question
        plots (list[str]): Names of the plots selected to answer the question
        plot_states (dict[str, PlotState]): Per-plot results, keyed by plot name
        error (str, optional): Empty string when the run succeeded, otherwise a
            user-facing explanation of why no answer could be produced
    """
    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]


async def find_relevant_plots(state: State, llm) -> list[str]:
    """Ask the LLM which plots can help answer the user's question.

    Args:
        state (State): state of the workflow (only ``user_input`` is read)
        llm: language model handed to ``detect_relevant_plots``

    Returns:
        list[str]: names of the relevant plots (matched later against ``PLOTS``)
    """
    print("---- Find relevant plots ----")
    relevant_plots = await detect_relevant_plots(state['user_input'], llm)
    return relevant_plots


async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Ask the LLM which database tables can feed the given plot.

    Args:
        state (State): state of the workflow (only ``user_input`` is read)
        plot (Plot): the plot definition the tables must serve
        llm: language model handed to ``detect_relevant_tables``

    Returns:
        list[str]: names of the relevant tables; may be empty
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
    return relevant_tables


async def find_param(state: State, param_name: str, table: str) -> dict[str, Any] | None:
    """Perform the good method to retrieve the desired parameter.

    Args:
        state (State): state of the workflow
        param_name (str): name of the desired parameter
        table (str): name of the table

    Returns:
        dict[str, Any] | None: a mapping of query parameters to merge into the
        query params — the ``Location`` dict for ``'location'``, ``{'year': ...}``
        for ``'year'`` — or None when ``param_name`` is not supported.
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], table)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None


class Location(TypedDict):
    """Location extracted from the user input.

    ``latitude`` / ``longitude`` are only filled by ``find_location`` when a
    location was detected; otherwise only ``location`` is present.
    """
    location: str
    # NOTE(review): typed as str, but the values come straight from
    # nearestNeighbourSQL — confirm whether they are actually floats.
    latitude: Optional[str]
    longitude: Optional[str]


async def find_location(user_input: str, table: str) -> Location:
    """Detect a location in the user input and resolve its coordinates in ``table``.

    Args:
        user_input (str): The user's query text
        table (str): table whose coordinates are searched by ``nearestNeighbourSQL``

    Returns:
        Location: always contains ``location``; contains ``latitude`` and
        ``longitude`` (first and second items of the nearest-neighbour result)
        only when a location was detected.
    """
    print(f"---- Find location in table {table} ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {'location': location}
    if location:
        coords = loc2coords(location)
        # Presumably snaps the geocoded point onto the closest point stored in
        # the table — confirm against nearestNeighbourSQL.
        neighbour = nearestNeighbourSQL(coords, table)
        output.update({
            "latitude": neighbour[0],
            "longitude": neighbour[1],
        })
    return output


async def find_year(user_input: str) -> str:
    """Extracts year information from user input using LLM.

    This function uses an LLM to identify and extract year information from the
    user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The extracted year, or empty string if no year found
    """
    print("---- Find year ---")
    year = await detect_year_with_openai(user_input)
    return year


def find_indicator_column(table: str) -> str:
    """Retrieves the name of the indicator column within a table.

    This function maps table names to their corresponding indicator columns
    using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    return INDICATOR_COLUMNS_PER_TABLE[table]


async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Processes a table to extract relevant data and generate visualizations.

    This function builds the SQL query for the specified table, executes it,
    and generates a visualization based on the results.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Parameters used for querying the table. The
            dict is copied, so the caller's mapping is never mutated.
        plot (Plot): The plot object containing SQL query and visualization function

    Returns:
        TableState: The state of the processed table. ``status`` is 'ERROR' (and
        no query is run) when the plot's query builder returns an empty string.

    Raises:
        KeyError: If ``table`` has no entry in INDICATOR_COLUMNS_PER_TABLE
    """
    table_state: TableState = {
        'table_name': table,
        # Copy so that per-table keys (indicator_column) do not leak into the
        # params shared by the other tables of the same plot.
        'params': params.copy(),
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None,
    }

    table_state['params']['indicator_column'] = find_indicator_column(table)
    # Named `query` to avoid shadowing the imported `sql_query` module.
    query = plot['sql_query'](table, table_state['params'])

    if query == "":
        table_state['status'] = 'ERROR'
        return table_state

    table_state['sql_query'] = query
    df = await execute_sql_query(query)

    table_state['dataframe'] = df
    table_state['figure'] = plot['plot_function'](table_state['params'])

    return table_state


async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results. ``error`` stays '' unless no
        plot, no relevant table, no SQL query or no non-empty dataframe was
        produced across *all* selected plots.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': '',
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)

    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags aggregate over ALL plots: an error is reported only when no
    # plot at all yielded a table / a query / a non-empty dataframe. (They used
    # to be reset inside the loop, so only the last plot counted.)
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:
        # Map the LLM-chosen name back to its Plot definition; skip unknown names.
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {},
        }

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
        if not relevant_tables:
            # Nothing can be queried for this plot. Without this guard the
            # parameter lookup below indexed relevant_tables[0] -> IndexError.
            state['plot_states'][plot_name] = plot_state
            continue

        have_relevant_table = True
        plot_state['tables'] = relevant_tables

        # Parameters are resolved once, against the first relevant table, and
        # shared by every table of this plot.
        params: dict[str, Any] = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        tasks = [
            process_table(table, params, plot)
            for table in plot_state['tables'][:MAX_TABLES_PER_PLOT]
        ]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state


# Legacy node implementations, kept for reference only.

# def make_write_query_node():
#
#     def write_query(state):
#         print("---- Write query ----")
#         for table in state["tables"]:
#             sql_query = QUERIES[state[table]['query_type']](
#                 table=table,
#                 indicator_column=state[table]["columns"],
#                 longitude=state[table]["longitude"],
#                 latitude=state[table]["latitude"],
#             )
#             state[table].update({"sql_query": sql_query})
#
#         return state
#
#     return write_query

# def make_fetch_data_node(db_path):
#
#     def fetch_data(state):
#         print("---- Fetch data ----")
#         for table in state["tables"]:
#             results = execute_sql_query(db_path, state[table]['sql_query'])
#             state[table].update(results)
#
#         return state
#
#     return fetch_data


## V2

# def make_fetch_data_node(db_path: str, llm):
#     def fetch_data(state):
#         print("---- Fetch data ----")
#         db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
#         output = {}
#         sql_query = write_sql_query(state["query"], db, state["tables"], llm)
#         # TO DO : Add query checker
#         print(f"SQL query : {sql_query}")
#         output["sql_query"] = sql_query
#         output.update(fetch_data_from_sql_query(db_path, sql_query))
#         return output
#
#     return fetch_data