armanddemasson's picture
chore: remove prints in talk to drias workflow
1bd73b6
import os
from typing import Any, Callable, TypedDict, Optional
from numpy import sort
import pandas as pd
import asyncio
from plotly.graph_objects import Figure
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data import sql_query
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
from climateqa.engine.talk_to_data.utils import (
detect_relevant_plots,
detect_year_with_openai,
loc2coords,
detect_location_with_openai,
nearestNeighbourSQL,
detect_relevant_tables,
)
# Directory two levels above the process's current working directory.
# NOTE(review): this is resolved from os.getcwd() at import time, not from this
# file's location, so its value depends on where the process was started.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
class TableState(TypedDict):
    """Per-table processing record of the DRIAS workflow.

    One instance is produced for every table handled by ``process_table``. It
    keeps together the query parameters, the SQL query built from them, the
    data returned by that query and the object produced by the plot function.

    Attributes:
        table_name (str): Name of the table in the database.
        params (dict[str, Any]): Parameters used to build the SQL query and
            the plot (includes the table's indicator column once processed).
        sql_query (str, optional): SQL query built for the table; ``None``
            until a query has been produced.
        dataframe (pd.DataFrame, optional): Result of executing the query;
            ``None`` until the query has been executed.
        figure (Callable[..., Figure], optional): Callable that builds the
            visualization; ``None`` until it has been created.
        status (str): Processing status of the table ('OK' or 'ERROR').
    """

    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    # Optional[...] already admits None, so no extra "| None" is needed.
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    status: str
class PlotState(TypedDict):
    """Per-plot processing record of the DRIAS workflow.

    Groups, for a single plot, the tables selected for it and the processing
    state of each of those tables.

    Attributes:
        plot_name (str): Name of the plot.
        tables (list[str]): Names of the tables selected for the plot.
        table_states (dict[str, TableState]): Processing state of each table,
            keyed by table name.
    """

    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]
class State(TypedDict):
    """Global state of the DRIAS workflow.

    Attributes:
        user_input (str): The user's original query text.
        plots (list[str]): Names of the plots selected for the query.
        plot_states (dict[str, PlotState]): Processing state of each plot,
            keyed by plot name.
        error (str, optional): Message describing why no result could be
            produced, if any.
    """

    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Select the plots that are relevant to the user's query.

    Delegates to ``detect_relevant_plots`` with the user input stored in the
    workflow state.

    Args:
        state (State): State of the workflow; only ``user_input`` is read.
        llm: Language model handed over to ``detect_relevant_plots``.

    Returns:
        list[str]: The value returned by ``detect_relevant_plots``
            (presumably plot names; verify against that helper).
    """
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm)
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Select the tables that are relevant to the user's query for one plot.

    Delegates to ``detect_relevant_tables`` with the user input stored in the
    workflow state and the given plot.

    Args:
        state (State): State of the workflow; only ``user_input`` is read.
        plot (Plot): Plot for which tables are looked up.
        llm: Language model handed over to ``detect_relevant_tables``.

    Returns:
        list[str]: The value returned by ``detect_relevant_tables``
            (presumably table names; verify against that helper).
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm)
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Resolve one query parameter from the user's input.

    Dispatches on ``param_name`` to the matching extraction helper.

    Args:
        state (State): State of the workflow; ``user_input`` is read when the
            parameter is supported.
        param_name (str): Name of the desired parameter ('location' or 'year').
        table (str): Name of the table; only used for the 'location' parameter.

    Returns:
        dict[str, Any] | None: For 'year', ``{'year': <value>}``; for
            'location', the mapping returned by ``find_location``; ``None``
            for any other parameter name.
    """
    if param_name == 'year':
        return {'year': await find_year(state['user_input'])}

    if param_name == 'location':
        return await find_location(state['user_input'], table)

    # Unsupported parameter: nothing to resolve.
    return None
class Location(TypedDict):
    """Location extracted from a user query, with its matched coordinates.

    Attributes:
        location (str): Location detected in the user's query.
        latitude (str, optional): Latitude associated with the location.
        longitude (str, optional): Longitude associated with the location.
    """

    location: str
    latitude: Optional[str]
    longitude: Optional[str]
async def find_location(user_input: str, table: str) -> Location:
    """Extract a location from the user's query and match it to a table point.

    The location is detected with ``detect_location_with_openai``. When one is
    found, it is converted to coordinates with ``loc2coords`` and
    ``nearestNeighbourSQL`` is queried with those coordinates and the table.

    Args:
        user_input (str): The user's query text.
        table (str): Name of the table passed to ``nearestNeighbourSQL``.

    Returns:
        Location: Mapping with the detected ``location``. ``latitude`` and
            ``longitude`` are only present when a location was detected; they
            are the first and second items returned by ``nearestNeighbourSQL``
            (presumably the closest point available in the table; confirm
            against that helper).
    """
    print(f"---- Find location in table {table} ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {'location' : detected}

    # No location detected: return without coordinates.
    if not detected:
        return result

    nearest = nearestNeighbourSQL(loc2coords(detected), table)
    latitude, longitude = nearest[0], nearest[1]
    result["latitude"] = latitude
    result["longitude"] = longitude
    return result
async def find_year(user_input: str) -> str:
    """Extract the year mentioned in the user's query.

    Delegates to ``detect_year_with_openai``; the result is used by the
    workflow as the 'year' query parameter.

    Args:
        user_input (str): The user's query text.

    Returns:
        str: The value returned by ``detect_year_with_openai`` (presumably an
            empty string when no year is found; confirm against that helper).
    """
    print("---- Find year ---")
    return await detect_year_with_openai(user_input)
def find_indicator_column(table: str) -> str:
    """Return the indicator column associated with a table.

    The column is looked up in the ``INDICATOR_COLUMNS_PER_TABLE`` mapping.

    Args:
        table (str): Name of the table in the database.

    Returns:
        str: Name of the indicator column registered for the table.

    Raises:
        KeyError: If the table has no entry in the mapping.
    """
    print(f"---- Find indicator column in table {table} ----")
    indicator_column = INDICATOR_COLUMNS_PER_TABLE[table]
    return indicator_column
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build and run the SQL query of a plot for one table.

    The table's indicator column is added to a copy of ``params`` (the
    caller's dict is not modified). The plot's ``sql_query`` builder is then
    called; if it yields an empty string the table is flagged as 'ERROR' and
    nothing is executed. Otherwise the query is executed and the plot's
    ``plot_function`` is called with the parameters.

    Args:
        table (str): The name of the table to process.
        params (dict[str, Any]): Parameters shared by the tables of the plot.
        plot (Plot): Plot definition; its ``sql_query`` and ``plot_function``
            entries are used.

    Returns:
        TableState: State of the processed table. On error, ``sql_query``,
            ``dataframe`` and ``figure`` stay ``None``.

    Raises:
        KeyError: If the table has no registered indicator column.
    """
    # Work on a copy so the parameters shared between tables stay untouched.
    query_params = params.copy()
    query_params['indicator_column'] = find_indicator_column(table)

    result: TableState = {
        'table_name': table,
        'params': query_params,
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None,
    }

    query = plot['sql_query'](table, query_params)
    if query == "":
        # The plot could not build a query from these parameters.
        result['status'] = 'ERROR'
        return result

    result['sql_query'] = query
    result['dataframe'] = await execute_sql_query(query)
    result['figure'] = plot['plot_function'](query_params)
    return result
async def drias_workflow(user_input: str) -> State:
    """Run the complete Talk To Drias workflow for a user query.

    Steps:
        1. Select the relevant plots for the query.
        2. For each known plot, select the relevant tables.
        3. Resolve the plot's parameters from the query (location lookups use
           the first relevant table).
        4. Process at most the first three relevant tables concurrently
           (SQL query, dataframe, figure).

    Args:
        user_input (str): Initial user input.

    Returns:
        State: Final state with all the results. ``error`` holds a message
            when no plot was selected, or when no plot at all yielded a
            relevant table, an SQL query, or a non-empty dataframe; it is an
            empty string otherwise.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)
    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags accumulate over ALL plots: an error is reported only when
    # no plot at all produced a table / query / data, rather than depending
    # on whichever plot happened to be processed last.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:
        # Find the associated plot object; unknown plot names are skipped.
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {}
        }

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
        plot_state['tables'] = relevant_tables

        if not relevant_tables:
            # Nothing to query for this plot. Record it and move on instead
            # of indexing relevant_tables[0] below, which would raise.
            state['plot_states'][plot_name] = plot_state
            continue
        have_relevant_table = True

        # Parameters are resolved once per plot, against the first table.
        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Only the first three relevant tables are processed, concurrently.
        tasks = [process_table(table, params, plot) for table in relevant_tables[:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"
    return state
# def make_write_query_node():
# def write_query(state):
# print("---- Write query ----")
# for table in state["tables"]:
# sql_query = QUERIES[state[table]['query_type']](
# table=table,
# indicator_column=state[table]["columns"],
# longitude=state[table]["longitude"],
# latitude=state[table]["latitude"],
# )
# state[table].update({"sql_query": sql_query})
# return state
# return write_query
# def make_fetch_data_node(db_path):
# def fetch_data(state):
# print("---- Fetch data ----")
# for table in state["tables"]:
# results = execute_sql_query(db_path, state[table]['sql_query'])
# state[table].update(results)
# return state
# return fetch_data
## V2
# def make_fetch_data_node(db_path: str, llm):
# def fetch_data(state):
# print("---- Fetch data ----")
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
# output = {}
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
# # TO DO : Add query checker
# print(f"SQL query : {sql_query}")
# output["sql_query"] = sql_query
# output.update(fetch_data_from_sql_query(db_path, sql_query))
# return output
# return fetch_data