"""Entry points for querying DRIAS climate data through the talk-to-data workflow."""

import ast

from climateqa.engine.talk_to_data.workflow import drias_workflow
from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai")


def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
    """Ask the LLM to rewrite a SQL query so each row shows its source table.

    Args:
        sql_query (str): The original SQL query to modify.
        llm: The language model instance; it must expose ``invoke(prompt)``
            returning an object with a ``content`` attribute.

    Returns:
        str: The LLM's answer, returned as-is (no post-processing is applied).
    """
    prompt = (
        f"Make the following sql query display the source table in the rows {sql_query}. "
        "Just answer the query. The answer should not include ```sql\n"
    )
    return llm.invoke(prompt).content


def ask_llm_column_names(sql_query: str, llm) -> list[str]:
    """Ask the LLM which columns a SQL query selects and parse its answer.

    The LLM is asked to answer with a Python list literal only. The answer is
    parsed with ``ast.literal_eval`` (never ``eval``), so only literals are
    accepted.

    Args:
        sql_query (str): The SQL query to analyze.
        llm: The language model instance; it must expose ``invoke(prompt)``
            returning an object with a ``content`` attribute.

    Returns:
        list[str]: The parsed column names.

    Raises:
        ValueError, SyntaxError: If no Python literal can be parsed from the
            LLM answer.
    """
    answer = llm.invoke(
        "From the given sql query, list the columns that are being selected. "
        "The answer should only be a python list. Just answer the list. "
        f"The SQL query : {sql_query}"
    ).content

    # NOTE: str.strip() treats its argument as a *set of characters*, not as a
    # prefix. This peels a ```python fence because a list literal starts with
    # "[" and ends with "]", neither of which is in that set.
    cleaned = answer.strip("```python\n").strip()
    try:
        return ast.literal_eval(cleaned)
    except (ValueError, SyntaxError):
        # Fallback for answers the character-strip cannot clean (other fence
        # languages, surrounding prose): parse the outermost [...] span.
        start = answer.find("[")
        end = answer.rfind("]")
        if start == -1 or end < start:
            raise
        return ast.literal_eval(answer[start:end + 1])


async def ask_drias(query: str, index_state: int = 0) -> tuple:
    """Run the DRIAS workflow for a user query and return the selected result.

    The workflow's final state is scanned for table states whose status is
    ``'OK'``; their SQL queries, dataframes, figure functions and table names
    are collected, and the entry at ``index_state`` is returned together with
    the full lists so the caller can page through the results.

    Args:
        query (str): The user's question about climate data.
        index_state (int, optional): The index of the result to return.
            Defaults to 0.

    Returns:
        tuple: Always a 9-tuple containing:
            - sql_query (str | None): The SQL query at ``index_state``
            - dataframe: The dataframe at ``index_state`` (None on error)
            - figure: The result of calling the figure function at
              ``index_state`` with that dataframe (None on error)
            - sql_queries (list): All collected SQL queries
            - result_dataframes (list): All collected dataframes
            - figures (list): All collected figure generation functions
            - index_state (int): Current result index (0 on error)
            - table_list (list): Display names of the tables used
            - error (str): Error message from the workflow, "" if none

    Raises:
        IndexError: If ``index_state`` is outside the collected results
            (including when the workflow produced no usable result).
    """
    final_state = await drias_workflow(query)

    # Report workflow errors before touching 'plot_states', so that an error
    # state lacking that key still reaches the caller. The tuple has the same
    # arity as the success path (empty table_list in the eighth slot).
    if "error" in final_state and final_state["error"] != "":
        return None, None, None, [], [], [], 0, [], final_state["error"]

    sql_queries = []
    result_dataframes = []
    figures = []
    table_list = []

    for plot_state in final_state['plot_states'].values():
        for table_state in plot_state['table_states'].values():
            if table_state['status'] != 'OK':
                continue

            if 'table_name' in table_state:
                # "some_table_name" -> "Some table name"
                table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))

            # NOTE(review): each list is filled independently, so a table state
            # missing one of these fields would shift the lists out of
            # alignment -- confirm the workflow always sets all three.
            if table_state.get('sql_query') is not None:
                sql_queries.append(table_state['sql_query'])

            if table_state.get('dataframe') is not None:
                result_dataframes.append(table_state['dataframe'])

            if table_state.get('figure') is not None:
                figures.append(table_state['figure'])

    sql_query = sql_queries[index_state]
    dataframe = result_dataframes[index_state]
    # Figures are stored as functions; build the one requested from its data.
    figure = figures[index_state](dataframe)

    return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""


# Commented-out Vanna-based implementation (disabled):
#
# def ask_vanna(vn,db_vanna_path, query):
#     try :
#         location = detect_location_with_openai(query)
#         if location:
#             coords = loc2coords(location)
#             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
#             relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
#             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
#             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
#             sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
#             return sql_query, result_dataframe, figure
#         else :
#             empty_df = pd.DataFrame()
#             empty_fig = None
#             return "", empty_df, empty_fig
#     except Exception as e:
#         print(f"Error: {e}")
#         empty_df = pd.DataFrame()
#         empty_fig = None
#         return "", empty_df, empty_fig