|
import os |
|
|
|
from typing import Any, Callable, TypedDict, Optional |
|
from numpy import sort |
|
import pandas as pd |
|
import asyncio |
|
from plotly.graph_objects import Figure |
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.talk_to_data import sql_query |
|
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE |
|
from climateqa.engine.talk_to_data.plot import PLOTS, Plot |
|
from climateqa.engine.talk_to_data.sql_query import execute_sql_query |
|
from climateqa.engine.talk_to_data.utils import ( |
|
detect_relevant_plots, |
|
detect_year_with_openai, |
|
loc2coords, |
|
detect_location_with_openai, |
|
nearestNeighbourSQL, |
|
detect_relevant_tables, |
|
) |
|
|
|
|
|
# Directory two levels above the current working directory of the process.
# NOTE: derived from os.getcwd() at import time, so the value depends on
# where the process was started, not on where this file lives.
ROOT_PATH = os.path.split(os.path.split(os.getcwd())[0])[0]
|
|
|
class TableState(TypedDict):
    """State of a single table while it is processed for a plot.

    Every key is required (the TypedDict is total). The nullable entries
    hold ``None`` until the corresponding processing step has produced a
    value.

    Attributes:
        table_name (str): The name of the table in the database
        params (dict[str, Any]): Parameters used for querying the table
        sql_query (Optional[str]): The SQL query used to fetch data, or None
        dataframe (Optional[pd.DataFrame]): The resulting data, or None
        figure (Optional[Callable[..., Figure]]): Function to generate the
            visualization, or None
        status (str): The current status of the table processing
            ('OK' or 'ERROR')
    """
    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    # Previously ``Optional[pd.DataFrame | None]``; the inner ``| None`` was
    # redundant because ``Optional`` already admits None.
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    status: str
|
|
|
class PlotState(TypedDict):
    """Per-plot bookkeeping for the DRIAS workflow.

    Groups everything the workflow records about one plot: the plot's name,
    the tables selected for it, and the processing state of each of those
    tables.

    Attributes:
        plot_name (str): Name identifying the plot
        tables (list[str]): Names of the tables used for the plot
        table_states (dict[str, TableState]): Processing state of each table,
            keyed by table name
    """
    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]
|
|
|
class State(TypedDict):
    """Top-level state carried through the DRIAS workflow.

    Attributes:
        user_input (str): The user's original query text
        plots (list[str]): Names of the plots selected for the query
        plot_states (dict[str, PlotState]): State of each plot, keyed by
            plot name
        error (Optional[str]): Error message describing why no result could
            be produced, if any
    """
    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
|
|
|
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Select the plots that are relevant to the user's query.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        llm: language model object forwarded to ``detect_relevant_plots``

    Returns:
        list[str]: the value produced by ``detect_relevant_plots``
    """
    print("---- Find relevant plots ----")
    question = state['user_input']
    return await detect_relevant_plots(question, llm)
|
|
|
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Select the tables that are relevant to the user's query for one plot.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        plot (Plot): plot description; its ``name`` is logged and the whole
            object is forwarded to ``detect_relevant_tables``
        llm: language model object forwarded to ``detect_relevant_tables``

    Returns:
        list[str]: the value produced by ``detect_relevant_tables``
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    question = state['user_input']
    return await detect_relevant_tables(question, plot, llm)
|
|
|
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Dispatch to the extractor matching the requested parameter.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        param_name (str): name of the desired parameter ('location' or 'year')
        table (str): name of the table, used only for the location lookup

    Returns:
        dict[str, Any] | None: the result of ``find_location`` for
            'location', ``{'year': <year>}`` for 'year', and None for any
            other parameter name
    """
    # Unknown parameter names are ignored rather than treated as an error.
    if param_name not in ('location', 'year'):
        return None

    question = state['user_input']
    if param_name == 'location':
        return await find_location(question, table)
    return {'year': await find_year(question)}
|
|
|
class _LocationBase(TypedDict):
    """Keys of ``Location`` that are always present."""
    location: str


class Location(_LocationBase, total=False):
    """Location extracted from a user query, optionally with coordinates.

    ``location`` is always present. ``latitude`` and ``longitude`` are
    declared as non-required keys because ``find_location`` only adds them
    when a location was actually detected; declaring them required made the
    dict built there inconsistent with this type.

    Attributes:
        location (str): The detected location
        latitude (Optional[str]): Latitude associated with the location
            (key may be absent)
        longitude (Optional[str]): Longitude associated with the location
            (key may be absent)
    """
    latitude: Optional[str]
    longitude: Optional[str]
|
|
|
async def find_location(user_input: str, table: str) -> Location:
    """Detect a location in the user input and attach coordinates for a table.

    The location is detected with ``detect_location_with_openai``. When one
    is found, it is converted with ``loc2coords`` and passed, together with
    the table name, to ``nearestNeighbourSQL``; the first two items of that
    result are stored as ``latitude`` and ``longitude``.

    Args:
        user_input (str): The user's query text
        table (str): Name of the table used for the nearest-neighbour lookup

    Returns:
        Location: always contains ``location``; ``latitude`` and
            ``longitude`` are only added when a location was detected
    """
    print(f"---- Find location in table {table} ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {'location' : detected}

    # Nothing detected: return the bare result without coordinate keys.
    if not detected:
        return result

    nearest = nearestNeighbourSQL(loc2coords(detected), table)
    latitude, longitude = nearest[0], nearest[1]
    result["latitude"] = latitude
    result["longitude"] = longitude
    return result
|
|
|
async def find_year(user_input: str) -> str:
    """Extract the year mentioned in the user input.

    Thin wrapper that logs the step and delegates to
    ``detect_year_with_openai``.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The value returned by ``detect_year_with_openai`` (documented
            upstream as the extracted year, or an empty string if no year
            is found - not verifiable from this file)
    """
    print("---- Find year ---")
    return await detect_year_with_openai(user_input)
|
|
|
def find_indicator_column(table: str) -> str:
    """Look up the indicator column associated with a table.

    The lookup is a plain subscript into ``INDICATOR_COLUMNS_PER_TABLE``.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    indicator_column = INDICATOR_COLUMNS_PER_TABLE[table]
    return indicator_column
|
|
|
|
|
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build the query for one table, run it and prepare its figure callable.

    Steps: copy ``params`` and add the table's indicator column, build the
    SQL text with ``plot['sql_query']``, execute it with
    ``execute_sql_query`` and store the result of ``plot['plot_function']``
    as the figure entry. If the query builder returns an empty string, the
    state is returned early with status 'ERROR' and nothing is executed.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Parameters used for querying the table;
            not mutated, a shallow copy is used
        plot (Plot): The plot object providing ``sql_query`` and
            ``plot_function``

    Returns:
        TableState: The state of the processed table
    """
    # Shallow copy so the per-table 'indicator_column' does not leak into
    # the dict the caller passed in.
    table_params = params.copy()
    table_state = TableState(
        table_name=table,
        params=table_params,
        status='OK',
        dataframe=None,
        sql_query=None,
        figure=None,
    )

    table_params['indicator_column'] = find_indicator_column(table)
    query = plot['sql_query'](table, table_params)

    if query != "":
        table_state['sql_query'] = query
        table_state['dataframe'] = await execute_sql_query(query)
        table_state['figure'] = plot['plot_function'](table_params)
    else:
        # No query could be built for this table/parameter combination.
        table_state['status'] = 'ERROR'

    return table_state
|
|
|
async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    For every relevant plot known in ``PLOTS``, the relevant tables are
    detected, the plot's parameters are extracted from the user input, and up
    to three tables are processed concurrently. ``state['error']`` is set
    when no plot, no relevant table, no SQL query or no data could be
    produced across all plots.

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)
    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags accumulate over ALL plots. They were previously reset on
    # every iteration, so the final error only reflected the last plot.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:
        # Resolve the plot definition; names not present in PLOTS are skipped.
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': relevant_tables,
            'table_states': {}
        }
        # Registered immediately; the same object is filled in below.
        state['plot_states'][plot_name] = plot_state

        # Without any table there is nothing to query. Skipping here also
        # avoids the IndexError that ``relevant_tables[0]`` would raise.
        if not relevant_tables:
            continue
        have_relevant_table = True

        # Parameters are resolved once per plot, against the first table.
        params: dict[str, Any] = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Process at most the first three tables, concurrently.
        results = await asyncio.gather(
            *(process_table(table, params, plot) for table in relevant_tables[:3])
        )

        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            dataframe = table_state['dataframe']
            if dataframe is not None and len(dataframe) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|