File size: 10,049 Bytes
3ca8396 3653511 26bb643 3ca8396 8bd064f 26bb643 f576373 3653511 f576373 3ca8396 e57556f 3ca8396 8bd064f 3ca8396 8bd064f 3ca8396 e57556f 8bd064f 3653511 d8a656a 3653511 e57556f c2bf2c8 3653511 c2bf2c8 3653511 e57556f 3653511 c2bf2c8 26bb643 3ca8396 f576373 d8a656a 3ca8396 d8a656a 3ca8396 f576373 3ca8396 d8a656a 3ca8396 f576373 3ca8396 26bb643 3ca8396 26bb643 f576373 26bb643 f576373 26bb643 3ca8396 f576373 e57556f d8a656a c2bf2c8 d8a656a c2bf2c8 d8a656a c2bf2c8 f576373 26bb643 f576373 3ca8396 f576373 26bb643 f576373 26bb643 3ca8396 26bb643 f576373 e57556f f576373 3ca8396 f576373 3ca8396 f576373 e57556f f576373 26bb643 f576373 26bb643 e57556f 26bb643 f576373 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
import re
from typing import Annotated, TypedDict
import duckdb
from geopy.geocoders import Nominatim
import ast
from climateqa.engine.llm import get_llm
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
from langchain_core.prompts import ChatPromptTemplate
async def detect_location_with_openai(sentence):
    """Detect the first location mentioned in a sentence using the LLM.

    Args:
        sentence (str): Free-text user input that may mention locations.

    Returns:
        str: The first detected location, or "" when none are found.

    Raises:
        ValueError, SyntaxError: If the LLM reply is not a valid Python
            list literal (propagated from ast.literal_eval).
    """
    llm = get_llm()
    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    response = await llm.ainvoke(prompt)
    # BUG FIX: str.strip(chars) treats its argument as a *character set*, so the
    # previous `.strip("```python\n")` could also eat legitimate leading/trailing
    # characters of the payload. Remove a markdown code fence explicitly instead.
    content = response.content.strip()
    content = re.sub(r"^```(?:python)?\s*", "", content)
    content = re.sub(r"\s*```$", "", content)
    location_list = ast.literal_eval(content)
    # Only the first detected location is used downstream.
    return location_list[0] if location_list else ""
class ArrayOutput(TypedDict):
    """Schema for structured-LLM responses carrying a single array field.

    Used with `llm.with_structured_output(...)` so the model's answer is
    returned as a dict with one well-typed key.

    Attributes:
        array: A string holding a syntactically valid Python list literal.
    """
    array: Annotated[str, "Syntactically valid python array."]
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using the LLM.

    Args:
        sentence (str): Free-text user input that may mention years.

    Returns:
        str: The first detected year, or "" when none are found.

    Raises:
        ValueError, SyntaxError: If the structured answer is not a valid
            Python list literal (propagated from ast.literal_eval).
    """
    llm = get_llm()
    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.
    Sentence: "{sentence}"
    """
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # SECURITY FIX: ast.literal_eval only accepts Python literals, unlike
    # eval() which would execute arbitrary code returned by the model.
    years_list = ast.literal_eval(response["array"])
    if years_list:
        # Coerce to str to honour the declared return type — the model may
        # emit bare integers such as [2050].
        return str(years_list[0])
    return ""
def detectTable(sql_query: str) -> list[str]:
    """Extract the table reference(s) following each FROM clause of a query.

    Handles plain, backtick-, double- and single-quoted identifiers, as well
    as dot-qualified chains such as ``db.schema.table``.

    Args:
        sql_query (str): The SQL text to scan.

    Returns:
        list[str]: Every table reference found directly after a FROM keyword,
        quoting preserved.

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    identifier = r'(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)'
    from_clause = rf'(?i)\bFROM\s+({identifier}(?:\.{identifier})*)'
    return re.findall(from_clause, sql_query)
def loc2coords(location: str) -> tuple[float, float]:
    """Geocode a place name into (latitude, longitude).

    Relies on the Nominatim web service, so network access is required.

    Args:
        location (str): The place name to geocode (e.g. a city).

    Returns:
        tuple[float, float]: The (latitude, longitude) of the best match.

    Raises:
        AttributeError: If Nominatim finds no match (the result is None).
    """
    geocoder = Nominatim(user_agent="city_to_latlong")
    match = geocoder.geocode(location)
    return match.latitude, match.longitude
def coords2loc(coords: tuple[float, float]) -> str:
    """Reverse-geocode (latitude, longitude) into a human-readable address.

    Relies on the Nominatim web service, so network access is required.

    Args:
        coords (tuple[float, float]): The (latitude, longitude) pair.

    Returns:
        str: The resolved address, or "Unknown Location" on any failure
        (service error, no match, etc.).

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    geocoder = Nominatim(user_agent="coords_to_city")
    try:
        place = geocoder.reverse(coords)
        return place.address
    except Exception as e:
        # Best-effort: log and fall back rather than propagate.
        print(f"Error: {e}")
        return "Unknown Location"
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Find the grid point in a DRIAS parquet table closest to a location.

    Searches a +/-0.3 degree bounding box around the (rounded) target
    coordinates and returns the closest point inside it.

    Args:
        location (tuple): (latitude, longitude) of the target point.
        table (str): DRIAS table name (lower-cased to build the parquet path).

    Returns:
        tuple: (latitude, longitude) of the nearest grid point, or
        ("", "") when no grid point falls inside the bounding box.
    """
    lat = round(location[0], 3)
    long = round(location[1], 3)
    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
    # BUG FIX: the previous version returned an *arbitrary* row from the
    # bounding box; order by squared planar distance so the result really is
    # the nearest neighbour.
    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {long}) * (longitude - {long}) "
        f"LIMIT 1"
    ).fetchdf()
    if len(results) == 0:
        return "", ""
    return results['latitude'].iloc[0], results['longitude'].iloc[0]
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identify relevant DRIAS tables for a plot based on user input.

    Uses the LLM to analyze the user's question and the plot description to
    determine which tables in the DRIAS database are most relevant for the
    requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Get all table names
    table_names_list = DRIAS_TABLES
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip()
    # BUG FIX: str.strip("```python\n") treats its argument as a character set
    # and can eat legitimate characters; remove a markdown fence explicitly.
    answer = re.sub(r"^```(?:python)?\s*", "", answer)
    answer = re.sub(r"\s*```$", "", answer)
    return ast.literal_eval(answer)
def replace_coordonates(coords, query, coords_tables):
    """Substitute user coordinates in a query with table grid coordinates.

    Each occurrence of the original (latitude, longitude) pair in `query` is
    replaced, in order, by the matching pair from `coords_tables`.

    Args:
        coords: The original (latitude, longitude) pair.
        query: SQL text containing the original coordinates.
        coords_tables: One replacement (latitude, longitude) pair per
            occurrence of the original latitude in `query`.

    Returns:
        str: The query with coordinates substituted.
    """
    lat_text = str(coords[0])
    long_text = str(coords[1])
    occurrences = query.count(lat_text)
    for index in range(occurrences):
        query = query.replace(lat_text, str(coords_tables[index][0]), 1)
        query = query.replace(long_text, str(coords_tables[index][1]), 1)
    return query
async def detect_relevant_plots(user_question: str, llm):
    """Select the plot types relevant to a user question.

    Builds a catalogue of the available plots (name + description), asks the
    LLM to choose, and parses the answer as a Python list of plot names.

    Args:
        user_question (str): The user's question about climate data.
        llm: The language model instance to use for analysis.

    Returns:
        list[str]: Names of the plots deemed relevant.
    """
    plots_description = ""
    for plot in PLOTS:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"
    # Typo fix in the prompt: "quesiton" -> "question".
    prompt = (
        f"You are helping to answer a question with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant tables to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )
    answer = (await llm.ainvoke(prompt)).content.strip()
    # BUG FIX: str.strip("```python\n") treats its argument as a character set
    # and can eat legitimate characters; remove a markdown fence explicitly.
    answer = re.sub(r"^```(?:python)?\s*", "", answer)
    answer = re.sub(r"\s*```$", "", answer)
    return ast.literal_eval(answer)
# Next Version
# class QueryOutput(TypedDict):
# """Generated SQL query."""
# query: Annotated[str, ..., "Syntactically valid SQL query."]
# class PlotlyCodeOutput(TypedDict):
# """Generated Plotly code"""
# code: Annotated[str, ..., "Syntactically valid Plotly python code."]
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
# """Generate SQL query to fetch information."""
# prompt_params = {
# "dialect": db.dialect,
# "table_info": db.get_table_info(),
# "input": user_input,
# "relevant_tables": relevant_tables,
# "model": "ALADIN63_CNRM-CM5",
# }
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
# structured_llm = llm.with_structured_output(QueryOutput)
# chain = prompt | structured_llm
# result = chain.invoke(prompt_params)
# return result["query"]
# def fetch_data_from_sql_query(db: str, sql_query: str):
# conn = sqlite3.connect(db)
# cursor = conn.cursor()
# cursor.execute(sql_query)
# column_names = [desc[0] for desc in cursor.description]
# values = cursor.fetchall()
# return {"column_names": column_names, "data": values}
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
# """Generate plotly python code for the chart based on the sql query and the user question"""
# class PlotlyCodeOutput(TypedDict):
# """Generated Plotly code"""
# code: Annotated[str, ..., "Syntactically valid Plotly python code."]
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
# chain = prompt | structured_llm
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
# return result["code"]
|