import ast

from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.workflow import drias_workflow

llm = get_llm(provider="openai")

def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
    """Adds table names to the SQL query result rows using an LLM.

    This function modifies the SQL query to include the source table name in each row
    of the result set, making it easier to track which data comes from which table.

    Args:
        sql_query (str): The original SQL query to modify
        llm: The language model instance to use for generating the modified query

    Returns:
        str: The modified SQL query with table names included in the result rows
    """
    sql_with_table_names = llm.invoke(
        f"Make the following sql query display the source table in the rows {sql_query}. "
        "Just answer the query. The answer should not include ```sql\n"
    ).content
    return sql_with_table_names
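
# Hypothetical usage sketch (the query and table name below are illustrative only; the
# exact output depends on the LLM response):
#
#   modified_query = ask_llm_to_add_table_names(
#       "SELECT year, temperature FROM mean_annual_temperature", llm
#   )
#   # modified_query is expected to also select the source table name in every row.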

def ask_llm_column_names(sql_query: str, llm) -> list[str]:
    """Extracts column names from a SQL query using an LLM.

    This function analyzes a SQL query to identify which columns are being selected
    in the result set.

    Args:
        sql_query (str): The SQL query to analyze
        llm: The language model instance to use for column extraction

    Returns:
        list[str]: A list of column names being selected in the query
    """
    columns = llm.invoke(
        "From the given sql query, list the columns that are being selected. "
        "The answer should only be a python list. Just answer the list. "
        f"The SQL query : {sql_query}"
    ).content
    # The LLM may wrap its answer in a ```python fence; strip the fence characters
    # before parsing the list literal.
    columns_list = ast.literal_eval(columns.strip("```python\n").strip())
    return columns_list
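
# Hypothetical usage sketch (illustrative query; the LLM is expected to answer with a
# Python list literal such as ["year", "temperature"], parsed by ast.literal_eval):
#
#   selected_columns = ask_llm_column_names(
#       "SELECT year, temperature FROM mean_annual_temperature", llm
#   )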

async def ask_drias(query: str, index_state: int = 0) -> tuple:
    """Main function to process a DRIAS query and return results.

    This function orchestrates the DRIAS workflow, processing a user query to generate
    SQL queries, dataframes, and visualizations. It handles multiple results and allows
    pagination through them.

    Args:
        query (str): The user's question about climate data
        index_state (int, optional): The index of the result to return. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - sql_query (str): The SQL query used
            - dataframe (pd.DataFrame): The resulting data
            - figure (Callable): Function to generate the visualization
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - index_state (int): Current result index
            - table_list (list): List of table names used
            - error (str): Error message if any
    """
    final_state = await drias_workflow(query)

    sql_queries = []
    result_dataframes = []
    figures = []
    table_list = []

    # Collect the SQL query, dataframe and figure factory of every table that was
    # processed successfully.
    for plot_state in final_state['plot_states'].values():
        for table_state in plot_state['table_states'].values():
            if table_state['status'] == 'OK':
                if 'table_name' in table_state:
                    # Turn a table name such as "mean_annual_temperature" into a
                    # display label ("Mean annual temperature").
                    table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
                if 'sql_query' in table_state and table_state['sql_query'] is not None:
                    sql_queries.append(table_state['sql_query'])
                if 'dataframe' in table_state and table_state['dataframe'] is not None:
                    result_dataframes.append(table_state['dataframe'])
                if 'figure' in table_state and table_state['figure'] is not None:
                    figures.append(table_state['figure'])

    if "error" in final_state and final_state["error"] != "":
        # Return an empty result set alongside the error message, keeping the same
        # tuple shape as the success case.
        return None, None, None, [], [], [], 0, [], final_state["error"]

    sql_query = sql_queries[index_state]
    dataframe = result_dataframes[index_state]
    figure = figures[index_state](dataframe)

    return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""

# def ask_vanna(vn, db_vanna_path, query):
#     try:
#         location = detect_location_with_openai(query)
#         if location:
#             coords = loc2coords(location)
#             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
#             relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
#             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
#             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
#             sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
#             return sql_query, result_dataframe, figure
#         else:
#             empty_df = pd.DataFrame()
#             empty_fig = None
#             return "", empty_df, empty_fig
#     except Exception as e:
#         print(f"Error: {e}")
#         empty_df = pd.DataFrame()
#         empty_fig = None
#         return "", empty_df, empty_fig
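
# Minimal sketch of how ask_drias might be run directly, assuming the DRIAS workflow and
# its database are available. The question below is only an illustrative example.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        results = await ask_drias("What will the average temperature be in Paris in 2050?")
        sql_query, dataframe, _figure, *_rest, table_list, error = results
        if error:
            print(f"Workflow error: {error}")
        else:
            print(f"Tables used: {table_list}")
            print(f"SQL query: {sql_query}")
            print(dataframe.head())

    asyncio.run(_demo())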