import calendar
import os
import uuid
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from huggingface_hub import HfApi
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

# Load environment variables; .env values override anything already set.
load_dotenv(override=True)

Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")

# Display name -> provider model ID for every model VayuChat supports.
models = {
    "gpt-oss-120b": "openai/gpt-oss-120b",
    "qwen3-32b": "qwen/qwen3-32b",
    "gpt-oss-20b": "openai/gpt-oss-20b",
    "llama4 maverick": "meta-llama/llama-4-maverick-17b-128e-instruct",
    "llama3.3": "llama-3.3-70b-versatile",
    "deepseek-R1": "deepseek-r1-distill-llama-70b",
    "gemini-2.5-flash": "gemini-2.5-flash",
    "gemini-2.5-pro": "gemini-2.5-pro",
    "gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
    "gemini-2.0-flash": "gemini-2.0-flash",
    "gemini-2.0-flash-lite": "gemini-2.0-flash-lite",
}


def log_interaction(user_query, model_name, response_content, generated_code,
                    execution_time, error_message=None, is_image=False):
    """Log a user interaction to the Hugging Face logs dataset."""
    try:
        if not hf_token or hf_token.strip() == "":
            print("Warning: HF_TOKEN not available, skipping logging")
            return

        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "session_id": str(uuid.uuid4()),
            "user_query": user_query,
            "model_name": model_name,
            "response_content": str(response_content),
            "generated_code": generated_code or "",
            "execution_time_seconds": execution_time,
            "error_message": error_message or "",
            "is_image_output": is_image,
            "success": error_message is None,
        }

        df = pd.DataFrame([log_entry])

        # Timestamp plus a short random suffix keeps filenames unique
        # across concurrent sessions.
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_id = str(uuid.uuid4())[:8]
        filename = f"interaction_log_{timestamp_str}_{random_id}.parquet"

        local_path = f"/tmp/{filename}"
        df.to_parquet(local_path, index=False)

        api = HfApi(token=hf_token)
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"data/{filename}",
            repo_id="SustainabilityLabIITGN/VayuChat_logs",
            repo_type="dataset",
        )

        # Remove the temporary file once the upload succeeds.
        if os.path.exists(local_path):
            os.remove(local_path)

        print(f"Successfully logged interaction to HuggingFace: {filename}")

    except Exception as e:
        # Logging must never take down the app; report and move on.
        print(f"Error logging interaction: {e}")
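
# Hypothetical call-site sketch (illustrative values, not from a real run):
# log_interaction(
#     user_query="Which city had the highest PM2.5 in 2023?",
#     model_name="gemini-2.0-flash",
#     response_content="Delhi",
#     generated_code="answer = ...",
#     execution_time=4.2,
# )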


def preprocess_and_load_df(path: str) -> pd.DataFrame:
    """Load the CSV at `path` and parse its Timestamp column."""
    try:
        df = pd.read_csv(path)
        df["Timestamp"] = pd.to_datetime(df["Timestamp"])
        return df
    except Exception as e:
        raise Exception(f"Error loading dataframe: {e}") from e
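
# Example usage (assumes the file has a "Timestamp" column, as AQ_met_data.csv does):
# df = preprocess_and_load_df("AQ_met_data.csv")
# print(df.dtypes)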


def get_from_user(prompt):
    """Format user prompt"""
    return {"role": "user", "content": prompt}


def ask_question(model_name, question):
    """Ask a question with comprehensive error handling and logging."""
    start_time = datetime.now()

    def make_error_response(msg, log_msg, content=None):
        """Build an error response for the UI and log it."""
        execution_time = (datetime.now() - start_time).total_seconds()
        log_interaction(
            user_query=question,
            model_name=model_name,
            response_content=content or msg,
            generated_code="",
            execution_time=execution_time,
            error_message=log_msg,
            is_image=False,
        )
        return {
            "role": "assistant",
            "content": content or msg,
            "gen_code": "",
            "ex_code": "",
            "last_prompt": question,
            "error": log_msg,
        }

    def validate_api_token(token, token_name, msg_if_missing):
        """Return an error response if the token is missing or empty, else None."""
        if not token or token.strip() == "":
            return make_error_response(
                msg=f"Missing or empty {token_name}",
                log_msg=f"Missing or empty {token_name}",
                content=msg_if_missing,
            )
        return None

    def run_safe_exec(full_code, df=None, extra_globals=None):
        """Execute generated code and return (answer, error_message)."""
        local_vars = {}
        global_vars = {
            "pd": pd, "plt": plt, "os": os,
            "sns": sns,
            "uuid": uuid,
            "calendar": calendar,
            "np": np,
            "df": df,
        }

        if extra_globals:
            global_vars.update(extra_globals)

        try:
            exec(full_code, global_vars, local_vars)
            return (
                local_vars.get("answer", "Code executed but no result was saved in 'answer' variable"),
                None,
            )
        except Exception as code_error:
            return None, str(code_error)
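
    # Hypothetical sketch of the contract run_safe_exec enforces: the generated
    # snippet must leave its result in a variable named `answer`, e.g.
    #   answer = df.groupby("station")["PM2.5"].mean().idxmax()
    # (column names here are illustrative; the real ones come from AQ_met_data.csv)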

    # Re-read the .env file so newly added or rotated keys take effect
    # without restarting the app.
    load_dotenv(override=True)
    fresh_groq_token = os.getenv("GROQ_API_KEY")
    fresh_gemini_token = os.getenv("GEMINI_TOKEN")

    try:
        if "gemini" in model_name:
            token_error = validate_api_token(
                fresh_gemini_token,
                "GEMINI_TOKEN",
                "Gemini API token not available or empty. Please set GEMINI_TOKEN in your environment variables.",
            )
            if token_error:
                return token_error

            try:
                llm = ChatGoogleGenerativeAI(
                    model=models[model_name],
                    google_api_key=fresh_gemini_token,
                    temperature=0,
                )
                # Cheap call to verify the key before doing real work.
                llm.invoke("Test")
            except Exception as api_error:
                return make_error_response(
                    msg="API Connection Error",
                    log_msg=str(api_error),
                    content="API Key Error: Your Gemini API key appears to be invalid, expired, or restricted. Please check your GEMINI_TOKEN in the .env file."
                    if "organization_restricted" in str(api_error).lower() or "unauthorized" in str(api_error).lower()
                    else f"API Connection Error: {api_error}",
                )

        else:
            token_error = validate_api_token(
                fresh_groq_token,
                "GROQ_API_KEY",
                "Groq API token not available or empty. Please set GROQ_API_KEY in your environment variables and restart the application.",
            )
            if token_error:
                return token_error

            try:
                llm = ChatGroq(
                    model=models[model_name],
                    api_key=fresh_groq_token,
                    temperature=0,
                )
                # Cheap call to verify the key before doing real work.
                llm.invoke("Test")
            except Exception as api_error:
                return make_error_response(
                    msg="API Connection Error",
                    log_msg=str(api_error),
                    content="API Key Error: Your Groq API key appears to be invalid, expired, or restricted. Please check your GROQ_API_KEY in the .env file."
                    if "organization_restricted" in str(api_error).lower() or "unauthorized" in str(api_error).lower()
                    else f"API Connection Error: {api_error}",
                )
    except Exception as e:
        return make_error_response(str(e), str(e))

    if not os.path.exists("AQ_met_data.csv"):
        return make_error_response(
            msg="Data file not found",
            log_msg="Data file not found",
            content="AQ_met_data.csv file not found. Please ensure the data file is in the correct location.",
        )

    df = pd.read_csv("AQ_met_data.csv")
    df["Timestamp"] = pd.to_datetime(df["Timestamp"])
    new_line = "\n"  # f-string expressions below cannot contain a backslash
    states_df = pd.read_csv("states_data.csv")
    ncap_df = pd.read_csv("ncap_funding_data.csv")

    template = f"""```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import uuid
import calendar
import numpy as np
# Set professional matplotlib styling
plt.style.use('vayuchat.mplstyle')
df = pd.read_csv("AQ_met_data.csv")
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
states_df = pd.read_csv("states_data.csv")
ncap_df = pd.read_csv("ncap_funding_data.csv")
# df is a pandas DataFrame of daily air quality data for India from 2017 to 2024, with the following columns and dtypes:
{new_line.join(map(lambda x: '# '+x, str(df.dtypes).split(new_line)))}
# states_df is a pandas DataFrame of state-wise population, area, and union-territory status for India.
{new_line.join(map(lambda x: '# '+x, str(states_df.dtypes).split(new_line)))}
# ncap_df is a pandas DataFrame of funding given to Indian cities from 2019 to 2022 under the National Clean Air Programme (NCAP).
{new_line.join(map(lambda x: '# '+x, str(ncap_df.dtypes).split(new_line)))}
# Question: {question.strip()}
# Generate code to answer the question and save the result in the 'answer' variable
# If creating a plot, save it with a unique filename and store the filename in 'answer'
# If returning text/numbers, store the result directly in 'answer'
```"""
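
    # For reference, each dtype dump above renders as comment lines in the
    # prompt, one per column, e.g. (illustrative):
    #   # Timestamp    datetime64[ns]
    #   # PM2.5        float64
    #   # dtype: object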

    with open("new_system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read().strip()

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"""Complete the following code to answer the user's question:
{template}""",
        },
    ]

    try:
        response = llm.invoke(messages)
        answer = response.content
    except Exception as e:
        return make_error_response(f"Error: {e}", str(e))

    # Pull the fenced code out of the model's reply; fall back to the raw
    # reply if no ```python fence is present.
    code_part = answer.split("```python")[1].split("```")[0] if "```python" in answer else answer
    full_code = f"""
{template.split("```python")[1].split("```")[0]}
{code_part}
"""
    answer_result, code_error = run_safe_exec(
        full_code, df, extra_globals={"states_df": states_df, "ncap_df": ncap_df}
    )

    execution_time = (datetime.now() - start_time).total_seconds()
    if code_error:
        # Map common failure modes to friendlier guidance.
        msg = "I encountered an error while analyzing your data. "
        if "syntax" in code_error.lower():
            msg += "There was a syntax error in the generated code. Please try rephrasing your question."
        elif "not defined" in code_error.lower():
            msg += "A variable naming error occurred. Please try asking the question again."
        elif "division by zero" in code_error.lower():
            msg += "The calculation involved division by zero, possibly due to missing data."
        elif "no data" in code_error.lower() or "empty" in code_error.lower():
            msg += "No relevant data was found for your query."
        else:
            msg += f"Technical error: {code_error}"

        msg += "\n\n💡 **Suggestions:**\n- Try rephrasing your question\n- Use simpler terms\n- Check if the data exists for your specified criteria"

        log_interaction(
            user_query=question,
            model_name=model_name,
            response_content=msg,
            generated_code=full_code,
            execution_time=execution_time,
            error_message=code_error,
            is_image=False,
        )
        return {
            "role": "assistant",
            "content": msg,
            "gen_code": full_code,
            "ex_code": full_code,
            "last_prompt": question,
            "error": code_error,
        }

    # Treat string results that point at an image file as plot outputs.
    is_image = isinstance(answer_result, str) and answer_result.endswith((".png", ".jpg", ".jpeg"))
    log_interaction(
        user_query=question,
        model_name=model_name,
        response_content=str(answer_result),
        generated_code=full_code,
        execution_time=execution_time,
        error_message=None,
        is_image=is_image,
    )

    return {
        "role": "assistant",
        "content": answer_result,
        "gen_code": full_code,
        "ex_code": full_code,
        "last_prompt": question,
        "error": None,
    }
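

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming GROQ_API_KEY is set and the data
    # files (AQ_met_data.csv, states_data.csv, ncap_funding_data.csv) plus
    # new_system_prompt.txt are in the working directory.
    result = ask_question("llama3.3", "What is the overall mean PM2.5 across all stations?")
    print(result["content"])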