"""Helpers for the talk-to-data pipeline.

Groups the small utilities used to turn a user question into a data query:
LLM-based extraction of locations / years / relevant tables / relevant plots,
geocoding through Nominatim, and grid-point lookup in the DRIAS parquet
tables through DuckDB.
"""

import ast
import math
import re
from typing import Annotated, TypedDict, Union

import duckdb
from geopy.geocoders import Nominatim
from langchain_core.prompts import ChatPromptTemplate

from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
from climateqa.engine.talk_to_data.plot import PLOTS, Plot

# Opening/closing Markdown code fence, whatever its language tag is.
_CODE_FENCE_RE = re.compile(r"^```[\w-]*\s*|\s*```$")
# Widest "[ ... ]" span in a text, used to dig a list out of surrounding prose.
_LIST_LITERAL_RE = re.compile(r"\[.*\]", re.DOTALL)
# Matches the (possibly quoted / dotted) identifier that follows a FROM keyword.
_FROM_TABLE_RE = re.compile(
    r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
)


def _parse_llm_list(text: str) -> list:
    """Parse a Python list literal out of a raw LLM answer.

    The answer may be a bare literal, a literal wrapped in a Markdown code
    fence (with any language tag), or a literal embedded in a sentence.
    Only literals are evaluated (``ast.literal_eval``), never arbitrary code.

    Args:
        text (str): Raw text returned by the model

    Returns:
        list: The parsed list (a tuple literal is converted to a list)

    Raises:
        ValueError: If the text is not a valid literal or is not a list/tuple
        SyntaxError: If the text cannot be parsed as a Python literal
    """
    cleaned = _CODE_FENCE_RE.sub("", text.strip())
    try:
        parsed = ast.literal_eval(cleaned)
    except (ValueError, SyntaxError):
        # The model may have added prose around the list: retry on the
        # bracketed part only, and surface the original error otherwise.
        bracketed = _LIST_LITERAL_RE.search(cleaned)
        if bracketed is None:
            raise
        parsed = ast.literal_eval(bracketed.group(0))

    if not isinstance(parsed, (list, tuple)):
        # Indexing a scalar (e.g. a bare string) would silently return garbage.
        raise ValueError(f"Expected a Python list from the LLM, got: {text!r}")
    return list(parsed)


async def detect_location_with_openai(sentence: str) -> str:
    """Detects the first location mentioned in a sentence using an LLM.

    Args:
        sentence (str): The sentence to analyze

    Returns:
        str: The first location returned by the model, or "" if the model
            returned an empty list

    Raises:
        ValueError: If the model answer cannot be parsed as a list
        SyntaxError: If the model answer is not a valid Python literal
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.
    
    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    location_list = _parse_llm_list(response.content)
    if location_list:
        return location_list[0]
    return ""


class ArrayOutput(TypedDict):
    """Represents the output of a function that returns an array.

    This class is used to type-hint functions that return arrays,
    ensuring consistent return types across the codebase.

    Attributes:
        array (str): A syntactically valid Python array string
    """

    array: Annotated[str, "Syntactically valid python array."]


async def detect_year_with_openai(sentence: str) -> str:
    """Detects the first year mentioned in a sentence using an LLM.

    The model is asked for a structured ``ArrayOutput`` whose ``array`` field
    holds a Python list literal. That literal is parsed with
    ``ast.literal_eval`` rather than ``eval`` so that model output can never
    execute code.

    Args:
        sentence (str): The sentence to analyze

    Returns:
        str: The first element of the list returned by the model, exactly as
            the model wrote it (it is not converted, so it may be an int if
            the model emitted numbers), or "" if the list is empty

    Raises:
        ValueError: If the model answer cannot be parsed as a list
        SyntaxError: If the model answer is not a valid Python literal
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.
    
    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})

    years_list = _parse_llm_list(response['array'])
    if years_list:
        return years_list[0]
    return ""


def detectTable(sql_query: str) -> list[str]:
    """Extracts table names from a SQL query.

    This function uses a regular expression to find every identifier that
    directly follows a FROM keyword (case-insensitive). Identifiers may be
    bare, quoted with backticks, double or single quotes, and dotted.
    Quotes are kept in the returned names.

    Args:
        sql_query (str): The SQL query to analyze

    Returns:
        list[str]: A list of table names found in the query

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    return _FROM_TABLE_RE.findall(sql_query)


def loc2coords(location: str) -> tuple[float, float]:
    """Converts a location name to geographic coordinates.

    This function uses the Nominatim geocoding service to convert
    a location name (e.g., city name) to its latitude and longitude.

    Args:
        location (str): The name of the location to geocode

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude)

    Raises:
        AttributeError: If the location cannot be found
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    coords = geolocator.geocode(location)
    # geocode() yields None for an unknown place; the attribute access below
    # then raises the AttributeError documented above.
    return (coords.latitude, coords.longitude)


def coords2loc(coords: tuple[float, float]) -> str:
    """Converts geographic coordinates to a location name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to a human-readable location name.
    Any failure (service error or no result) is printed and mapped to
    "Unknown Location" instead of being raised.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        str: The address of the location, or "Unknown Location" if not found

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    geolocator = Nominatim(user_agent="coords_to_city")
    try:
        location = geolocator.reverse(coords)
        return location.address
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"


def nearestNeighbourSQL(
    location: tuple, table: str
) -> Union[tuple[float, float], tuple[str, str]]:
    """Finds the grid point of a DRIAS table closest to a location.

    The table is read from ``hf://datasets/timeki/drias_db/<table>.parquet``.
    Candidates are limited to a +/-0.3 degree box around the location, and
    the closest one is selected in SQL so that a single row is fetched.

    Args:
        location (tuple): A (latitude, longitude) pair
        table (str): The table name (lower-cased to build the parquet path)

    Returns:
        tuple: (latitude, longitude) of the closest grid point as stored in
            the table, or ("", "") if no point lies inside the search box
    """
    lon = round(location[1], 3)
    lat = round(location[0], 3)
    # The name is interpolated inside a quoted SQL string: double any single
    # quote so it cannot terminate the literal.
    safe_table = table.lower().replace("'", "''")
    table_path = f"'hf://datasets/timeki/drias_db/{safe_table}.parquet'"
    # A degree of longitude shrinks with latitude; scale it so the ordering
    # reflects ground distance rather than raw degree differences.
    lon_scale = math.cos(math.radians(lat))

    query = (
        f"SELECT latitude, longitude FROM {table_path} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {lon - 0.3} AND {lon + 0.3} "
        f"ORDER BY (latitude - ({lat})) * (latitude - ({lat})) "
        f"+ ({lon_scale} * (longitude - ({lon}))) * ({lon_scale} * (longitude - ({lon}))) "
        f"LIMIT 1"
    )
    results = duckdb.sql(query).fetchdf()

    if len(results) == 0:
        return "", ""
    return results['latitude'].iloc[0], results['longitude'].iloc[0]


async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Raises:
        ValueError: If the model answer cannot be parsed as a list
        SyntaxError: If the model answer is not a valid Python literal

    Example:
        >>> await detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Get all table names
    table_names_list = DRIAS_TABLES

    # NOTE: the prompt text is left untouched on purpose (including the
    # missing spaces between sentences) so the model input does not change.
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    response = await llm.ainvoke(prompt)
    return _parse_llm_list(response.content)


def replace_coordonates(coords, query, coords_tables):
    """Replaces a location's coordinates in a query with table coordinates.

    The latitude ``coords[0]`` is counted in ``query``; for the i-th of those
    occurrences, the first remaining latitude and the first remaining
    longitude are replaced with ``coords_tables[i]``.

    Args:
        coords: A (latitude, longitude) pair as it appears in the query
        query (str): The SQL query to rewrite
        coords_tables: One (latitude, longitude) pair per occurrence, in the
            order the occurrences appear in the query

    Returns:
        str: The query with the coordinates substituted

    Raises:
        IndexError: If ``coords_tables`` has fewer entries than the number of
            latitude occurrences in the query
    """
    lat_text = str(coords[0])
    lon_text = str(coords[1])
    occurrences = query.count(lat_text)

    for i in range(occurrences):
        table_lat, table_lon = coords_tables[i][0], coords_tables[i][1]
        query = query.replace(lat_text, str(table_lat), 1)
        query = query.replace(lon_text, str(table_lon), 1)
    return query


async def detect_relevant_plots(user_question: str, llm):
    """Identifies the plots that are relevant to answer a user question.

    Every plot in ``PLOTS`` is presented to the LLM with its name and
    description, and the model is asked for a Python list of plot names.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis

    Returns:
        list: The plot names chosen by the model

    Raises:
        ValueError: If the model answer cannot be parsed as a list
        SyntaxError: If the model answer is not a valid Python literal
    """
    plots_description = "".join(
        "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
        for plot in PLOTS
    )

    # NOTE: the prompt text is left untouched on purpose (including its
    # typos) so the model input does not change.
    prompt = (
        f"You are helping to answer a quesiton with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant tables to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )
    # prompt = (
    #     f"You are helping to answer a question with insightful visualizations. "
    #     f"Given a list of plots with their name and description: "
    #     f"{plots_description} "
    #     f"The user question is: {user_question}. "
    #     f"Choose the most relevant plots to answer the question. "
    #     f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
    #     f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
    # )

    response = await llm.ainvoke(prompt)
    return _parse_llm_list(response.content)


# Next Version
# class QueryOutput(TypedDict):
#     """Generated SQL query."""

#     query: Annotated[str, ..., "Syntactically valid SQL query."]


# class PlotlyCodeOutput(TypedDict):
#     """Generated Plotly code"""

#     code: Annotated[str, ..., "Synatically valid Plotly python code."]
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
#     """Generate SQL query to fetch information."""
#     prompt_params = {
#         "dialect": db.dialect,
#         "table_info": db.get_table_info(),
#         "input": user_input,
#         "relevant_tables": relevant_tables,
#         "model": "ALADIN63_CNRM-CM5",
#     }

#     prompt = ChatPromptTemplate.from_template(query_prompt_template)
#     structured_llm = llm.with_structured_output(QueryOutput)
#     chain = prompt | structured_llm
#     result = chain.invoke(prompt_params)

#     return result["query"]

# def fetch_data_from_sql_query(db: str, sql_query: str):
#     conn = sqlite3.connect(db)
#     cursor = conn.cursor()
#     cursor.execute(sql_query)
#     column_names = [desc[0] for desc in cursor.description]
#     values = cursor.fetchall()
#     return {"column_names": column_names, "data": values}

# def generate_chart_code(user_input: str, sql_query: list[str], llm):
#     """ "Generate plotly python code for the chart based on the sql query and the user question"""

#     class PlotlyCodeOutput(TypedDict):
#         """Generated Plotly code"""

#         code: Annotated[str, ..., "Synatically valid Plotly python code."]

#     prompt = ChatPromptTemplate.from_template(plot_prompt_template)
#     structured_llm = llm.with_structured_output(PlotlyCodeOutput)
#     chain = prompt | structured_llm
#     result = chain.invoke({"input": user_input, "sql_query": sql_query})
#     return result["code"]