File size: 5,202 Bytes
f576373
28684d8
 
 
 
 
d8a656a
 
 
 
 
 
 
 
 
 
 
 
 
28684d8
 
 
d8a656a
 
 
 
 
 
 
 
 
 
 
 
 
28684d8
 
 
 
e57556f
d8a656a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e57556f
f576373
 
 
f688422
f576373
 
 
31a6f8f
f688422
 
31a6f8f
 
 
 
 
 
26bb643
c2bf2c8
26bb643
 
 
 
 
f688422
 
 
f576373
 
8bd064f
f576373
 
 
28684d8
f576373
 
28684d8
f576373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import ast
import re

from climateqa.engine.talk_to_data.workflow import drias_workflow
from climateqa.engine.llm import get_llm

# Module-level LLM client, built once at import time via get_llm(provider="openai").
# NOTE(review): the helper functions in this module take their own `llm` argument,
# which shadows this name; confirm whether other modules import this object before
# removing or changing it.
llm = get_llm(provider="openai")

def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
    """Ask the LLM to rewrite a SQL query so every result row names its source table.

    The model is instructed to answer with the bare query (no ```sql fence);
    its reply is returned verbatim, without any post-processing.

    Args:
        sql_query (str): The SQL query to rewrite.
        llm: Chat model exposing ``invoke(prompt)`` whose response object
            carries the generated text in ``.content``.

    Returns:
        str: The rewritten SQL query exactly as produced by the model.
    """
    prompt = (
        f"Make the following sql query display the source table in the rows {sql_query}. "
        "Just answer the query. The answer should not include ```sql\n"
    )
    response = llm.invoke(prompt)
    return response.content

# Opening Markdown fence with an optional language tag (```python, ```json, ```Python ...).
_OPENING_FENCE_RE = re.compile(r"^```[\w+-]*")
# Closing Markdown fence at the very end of the (already stripped) text.
_CLOSING_FENCE_RE = re.compile(r"```$")


def _strip_code_fences(text: str) -> str:
    """Remove one surrounding Markdown code fence (and optional language tag) from *text*."""
    text = text.strip()
    text = _OPENING_FENCE_RE.sub("", text, count=1)
    text = _CLOSING_FENCE_RE.sub("", text, count=1)
    return text.strip()


def ask_llm_column_names(sql_query: str, llm) -> list[str]:
    """Extract the selected column names of a SQL query using an LLM.

    The model is asked to answer with a Python list literal; a Markdown code
    fence around the answer (``` or ```<lang>) and surrounding whitespace are
    tolerated and removed before parsing.

    Args:
        sql_query (str): The SQL query to analyze.
        llm: Chat model exposing ``invoke(prompt)`` whose response object
            carries the generated text in ``.content``.

    Returns:
        list[str]: The column names being selected, as parsed from the model's answer.

    Raises:
        ValueError, SyntaxError: If the model's answer is not a Python literal
            (propagated from ``ast.literal_eval``).
    """
    columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
    # The model frequently wraps its answer in a fence despite the prompt.
    # (The previous str.strip("```python\n") stripped a *set of characters*,
    # not that prefix, so leading whitespace or other tags broke parsing.)
    # ast.literal_eval only evaluates literals, so it is safe on model output.
    return ast.literal_eval(_strip_code_fences(columns))

def _item_at(items: list, index: int):
    """Return ``items[index]``, or None when *index* is out of range."""
    return items[index] if -len(items) <= index < len(items) else None


async def ask_drias(query: str, index_state: int = 0) -> tuple:
    """Run the DRIAS workflow for a user query and return one page of results.

    The workflow may produce several tables, each with its SQL query, dataframe
    and figure-building function; ``index_state`` selects which one is
    materialised as the "current" result.

    Args:
        query (str): The user's question about climate data.
        index_state (int, optional): Index of the result to return. Defaults to 0.

    Returns:
        tuple: Always a 9-tuple, both on success and on error:
            - sql_query (str | None): SQL query of the selected result
            - dataframe (pd.DataFrame | None): Data of the selected result
            - figure: Value returned by calling the selected figure function on
              ``dataframe``; None when unavailable
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure-building functions (not yet called)
            - index_state (int): Current result index (0 on error)
            - table_list (list): Human-readable names of the tables used
            - error (str): Error message, or "" on success
    """
    final_state = await drias_workflow(query)

    # Check the workflow error first: an errored state may lack usable
    # plot_states, and anything collected would be discarded anyway.
    error = final_state.get("error", "")
    if error != "":
        return None, None, None, [], [], [], 0, [], error

    sql_queries = []
    result_dataframes = []
    figures = []
    table_list = []

    for plot_state in final_state['plot_states'].values():
        for table_state in plot_state['table_states'].values():
            if table_state['status'] != 'OK':
                continue
            if 'table_name' in table_state:
                # e.g. "mean_annual_temperature" -> "Mean annual temperature"
                table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
            if table_state.get('sql_query') is not None:
                sql_queries.append(table_state['sql_query'])
            # Compare with None explicitly: a DataFrame has no truth value.
            if table_state.get('dataframe') is not None:
                result_dataframes.append(table_state['dataframe'])
                # Figures are only collected for tables that produced a dataframe.
                if table_state.get('figure') is not None:
                    figures.append(table_state['figure'])

    if not result_dataframes:
        # Nothing to display: report it through the error slot instead of
        # crashing with IndexError below.
        return None, None, None, [], [], [], 0, [], "No data found for this query."

    # NOTE(review): the three lists are only index-aligned when every OK table
    # has a sql_query, a dataframe and a figure; missing entries shift indices.
    # sql_query / figure fall back to None rather than raising IndexError.
    sql_query = _item_at(sql_queries, index_state)
    dataframe = result_dataframes[index_state]
    figure_fn = _item_at(figures, index_state)
    figure = figure_fn(dataframe) if figure_fn is not None else None

    return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""

# def ask_vanna(vn,db_vanna_path, query):

#     try :
#         location = detect_location_with_openai(query)
#         if location:

#             coords = loc2coords(location)
#             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
            
#             relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
#             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
#             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)

            # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)

#             return sql_query, result_dataframe, figure
#         else : 
#             empty_df = pd.DataFrame()
#             empty_fig = None
#             return "", empty_df, empty_fig
#     except Exception as e:
#         print(f"Error: {e}")
#         empty_df = pd.DataFrame()
#         empty_fig = None
#         return "", empty_df, empty_fig