File size: 5,681 Bytes
711bc31 bc61879 28684d8 711bc31 d8a656a 711bc31 d8a656a 711bc31 d8a656a 711bc31 d8a656a 711bc31 28684d8 711bc31 d8a656a 711bc31 28684d8 711bc31 d8a656a 711bc31 f576373 711bc31 f576373 711bc31 c2bf2c8 26bb643 711bc31 26bb643 f688422 711bc31 bc61879 f688422 711bc31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
from climateqa.logging import log_drias_interaction_to_huggingface
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
"""Main function to process a DRIAS query and return results.
This function orchestrates the DRIAS workflow, processing a user query to generate
SQL queries, dataframes, and visualizations. It handles multiple results and allows
pagination through them.
Args:
query (str): The user's question about climate data
index_state (int, optional): The index of the result to return. Defaults to 0.
Returns:
tuple: A tuple containing:
- sql_query (str): The SQL query used
- dataframe (pd.DataFrame): The resulting data
- figure (Callable): Function to generate the visualization
- sql_queries (list): All generated SQL queries
- result_dataframes (list): All resulting dataframes
- figures (list): All figure generation functions
- index_state (int): Current result index
- table_list (list): List of table names used
- error (str): Error message if any
"""
final_state = await drias_workflow(query)
sql_queries = []
result_dataframes = []
figures = []
plot_title_list = []
plot_informations = []
for output_title, output in final_state['outputs'].items():
if output['status'] == 'OK':
if output['table'] is not None:
plot_title_list.append(output_title)
if output['plot_information'] is not None:
plot_informations.append(output['plot_information'])
if output['sql_query'] is not None:
sql_queries.append(output['sql_query'])
if output['dataframe'] is not None:
result_dataframes.append(output['dataframe'])
if output['figure'] is not None:
figures.append(output['figure'])
if "error" in final_state and final_state["error"] != "":
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
return None, None, None, None, [], [], [], 0, [], final_state["error"]
sql_query = sql_queries[index_state]
dataframe = result_dataframes[index_state]
figure = figures[index_state](dataframe)
plot_information = plot_informations[index_state]
log_drias_interaction_to_huggingface(query, sql_query, user_id)
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
"""Main function to process a DRIAS query and return results.
This function orchestrates the DRIAS workflow, processing a user query to generate
SQL queries, dataframes, and visualizations. It handles multiple results and allows
pagination through them.
Args:
query (str): The user's question about climate data
index_state (int, optional): The index of the result to return. Defaults to 0.
Returns:
tuple: A tuple containing:
- sql_query (str): The SQL query used
- dataframe (pd.DataFrame): The resulting data
- figure (Callable): Function to generate the visualization
- sql_queries (list): All generated SQL queries
- result_dataframes (list): All resulting dataframes
- figures (list): All figure generation functions
- index_state (int): Current result index
- table_list (list): List of table names used
- error (str): Error message if any
"""
final_state = await ipcc_workflow(query)
sql_queries = []
result_dataframes = []
figures = []
plot_title_list = []
plot_informations = []
for output_title, output in final_state['outputs'].items():
if output['status'] == 'OK':
if output['table'] is not None:
plot_title_list.append(output_title)
if output['plot_information'] is not None:
plot_informations.append(output['plot_information'])
if output['sql_query'] is not None:
sql_queries.append(output['sql_query'])
if output['dataframe'] is not None:
result_dataframes.append(output['dataframe'])
if output['figure'] is not None:
figures.append(output['figure'])
if "error" in final_state and final_state["error"] != "":
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
return None, None, None, None, [], [], [], 0, [], final_state["error"]
sql_query = sql_queries[index_state]
dataframe = result_dataframes[index_state]
figure = figures[index_state](dataframe)
plot_information = plot_informations[index_state]
log_drias_interaction_to_huggingface(query, sql_query, user_id)
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|