VayuChat / src.py
Nipun's picture
Fix critical system prompt issues
94a079d
raw
history blame
23.2 kB
import os
import pandas as pd
from typing import Tuple
from PIL import Image
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
import matplotlib.pyplot as plt
import json
from datetime import datetime
from huggingface_hub import HfApi
import uuid
# Load variables from .env, letting them override anything already in the process environment.
load_dotenv(override=True)
# API keys; each is None when the corresponding variable is unset.
Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")
# Debug print (remove in production). Only report whether each key is present:
# echoing any part of a secret (even a prefix) leaks it into stdout and server logs.
print(f"Debug - Groq Token: {'Present' if Groq_Token else 'Missing'}")
print(f"Debug - Gemini Token: {'Present' if gemini_token else 'Missing'}")
# Display name offered to users -> provider model identifier.
# ask_question routes "gemini-pro" to Google Generative AI and every other name to Groq.
models = dict(
    [
        ("gpt-oss-20b", "openai/gpt-oss-20b"),
        ("gpt-oss-120b", "openai/gpt-oss-120b"),
        ("llama3.1", "llama-3.1-8b-instant"),
        ("llama3.3", "llama-3.3-70b-versatile"),
        ("deepseek-R1", "deepseek-r1-distill-llama-70b"),
        ("llama4 maverik", "meta-llama/llama-4-maverick-17b-128e-instruct"),
        ("llama4 scout", "meta-llama/llama-4-scout-17b-16e-instruct"),
        ("gemini-pro", "gemini-1.5-pro"),
    ]
)
def log_interaction(user_query, model_name, response_content, generated_code, execution_time, error_message=None, is_image=False):
    """Log user interactions to Hugging Face dataset.

    Writes a single-row parquet file and uploads it to
    ``data/<filename>`` in the ``SustainabilityLabIITGN/VayuChat_logs`` dataset.
    This is best-effort: every failure is printed and swallowed so that logging
    never breaks the request being logged.

    Args:
        user_query: The question the user asked.
        model_name: Display name of the model that handled the query.
        response_content: What was returned to the user (stored via ``str``).
        generated_code: Code that was executed, or a falsy value for none.
        execution_time: Wall-clock seconds spent handling the request.
        error_message: Error description, or None for a successful interaction.
        is_image: Whether the response is an image filename.
    """
    import tempfile

    try:
        if not hf_token or hf_token.strip() == "":
            print("Warning: HF_TOKEN not available, skipping logging")
            return
        # One clock reading so the record's timestamp and its filename agree.
        now = datetime.now()
        log_entry = {
            "timestamp": now.isoformat(),
            # NOTE(review): a fresh UUID per entry, so this does not group a user's session.
            "session_id": str(uuid.uuid4()),
            "user_query": user_query,
            "model_name": model_name,
            "response_content": str(response_content),
            "generated_code": generated_code or "",
            "execution_time_seconds": execution_time,
            "error_message": error_message or "",
            "is_image_output": is_image,
            "success": error_message is None,
        }
        df = pd.DataFrame([log_entry])
        # Unique filename: second-resolution timestamp plus a random suffix.
        timestamp_str = now.strftime("%Y%m%d_%H%M%S")
        random_id = str(uuid.uuid4())[:8]
        filename = f"interaction_log_{timestamp_str}_{random_id}.parquet"
        # Portable temp location instead of a hard-coded /tmp.
        local_path = os.path.join(tempfile.gettempdir(), filename)
        try:
            df.to_parquet(local_path, index=False)
            api = HfApi(token=hf_token)
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=f"data/{filename}",
                repo_id="SustainabilityLabIITGN/VayuChat_logs",
                repo_type="dataset",
            )
        finally:
            # Remove the local copy even when serialization or upload fails,
            # so failed uploads do not accumulate files on disk.
            if os.path.exists(local_path):
                os.remove(local_path)
        print(f"Successfully logged interaction to HuggingFace: {filename}")
    except Exception as e:
        print(f"Error logging interaction: {e}")
def preprocess_and_load_df(path: str) -> pd.DataFrame:
    """Load the CSV at *path* and parse its ``Timestamp`` column as datetimes.

    Args:
        path: Location of the CSV file (anything ``pd.read_csv`` accepts).

    Returns:
        The loaded DataFrame with ``Timestamp`` converted by ``pd.to_datetime``.

    Raises:
        Exception: "Error loading dataframe: ..." for any read/parse failure,
            including a missing ``Timestamp`` column. The underlying error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        df = pd.read_csv(path)
        df["Timestamp"] = pd.to_datetime(df["Timestamp"])
    except Exception as e:
        # Chain the original error; previously the real cause was discarded.
        raise Exception(f"Error loading dataframe: {e}") from e
    return df
def get_from_user(prompt):
    """Wrap *prompt* as a chat message authored by the user.

    Returns:
        A ``{"role": "user", "content": prompt}`` message dict.
    """
    message = dict(role="user")
    message["content"] = prompt
    return message
# Instructions prepended to every code-generation request.
# Fixes vs. earlier text: the plot filename in TECHNICAL REQUIREMENTS used doubled
# braces (this is a plain string, so the model saw `{{...}}`, which as an f-string yields
# a literal, non-unique filename), and the empty-check advice used `return`,
# contradicting the "NO return statements" rule.
_SYSTEM_PROMPT = """Generate Python code to answer the user's question about air quality data.
CRITICAL: Only generate Python code - no explanations, no thinking, just clean executable code.
AVAILABLE LIBRARIES:
You can use these pre-installed libraries:
- pandas, numpy (data manipulation)
- matplotlib, seaborn, plotly (visualization)
- statsmodels (statistical modeling, trend analysis)
- scikit-learn (machine learning, regression)
- geopandas (geospatial analysis)
LIBRARY USAGE RULES:
- For trend analysis: Use numpy.polyfit(x, y, 1) for simple linear trends
- For regression: Use sklearn.linear_model.LinearRegression() for robust regression
- For statistical modeling: Use statsmodels only if needed, otherwise use numpy/sklearn
- Always import libraries at the top: import numpy as np, from sklearn.linear_model import LinearRegression
- Handle missing libraries gracefully with try-except around imports
OUTPUT TYPE REQUIREMENTS:
1. PLOT GENERATION (for "plot", "chart", "visualize", "show trend", "graph"):
- MUST create matplotlib figure with proper labels, title, legend
- MUST save plot: filename = f"plot_{uuid.uuid4().hex[:8]}.png"
- MUST call plt.savefig(filename, dpi=300, bbox_inches='tight')
- MUST call plt.close() to prevent memory leaks
- MUST store filename in 'answer' variable: answer = filename
- Handle empty data gracefully before plotting
2. TEXT ANSWERS (for simple "Which", "What", single values):
- Store direct string answer in 'answer' variable
- Example: answer = "December had the highest pollution"
3. DATAFRAMES (for lists, rankings, comparisons, multiple results):
- Create clean DataFrame with descriptive column names
- Sort appropriately for readability
- Store DataFrame in 'answer' variable: answer = result_df
MANDATORY SAFETY & ROBUSTNESS RULES:
DATA VALIDATION (ALWAYS CHECK):
- Check if DataFrame exists and not empty: if df.empty: answer = "No data available"
- Validate required columns exist: if 'PM2.5' not in df.columns: answer = "Required data not available"
- Check for sufficient data: if len(df) < 10: answer = "Insufficient data for analysis"
- Remove invalid/missing values: df = df.dropna(subset=['PM2.5', 'city', 'Timestamp'])
- Use early exit pattern: if condition: answer = "error message"; else: continue with analysis
OPERATION SAFETY (PREVENT CRASHES):
- Wrap risky operations in try-except blocks
- Check denominators before division: if denominator == 0: continue
- Validate indexing bounds: if idx >= len(array): continue
- Check for empty results after filtering: if result_df.empty: answer = "No data found"
- Convert data types explicitly: pd.to_numeric(), .astype(int), .astype(str)
- Handle timezone issues with datetime operations
- NO return statements - this is script context, use if/else logic flow
PLOT GENERATION (MANDATORY FOR PLOTS):
- Check data exists before plotting: if plot_data.empty: answer = "No data to plot"
- Always create new figure: plt.figure(figsize=(12, 8))
- Add comprehensive labels: plt.title(), plt.xlabel(), plt.ylabel()
- Handle long city names: plt.xticks(rotation=45, ha='right')
- Use tight layout: plt.tight_layout()
- CRITICAL PLOT SAVING SEQUENCE (no return statements):
1. filename = f"plot_{uuid.uuid4().hex[:8]}.png"
2. plt.savefig(filename, dpi=300, bbox_inches='tight')
3. plt.close()
4. answer = filename
- Use if/else logic: if data_valid: create_plot(); answer = filename else: answer = "error"
CRITICAL CODING PRACTICES:
DATA VALIDATION & SAFETY:
- Always check if DataFrames/Series are empty before operations: if df.empty: answer = "No data available"
- Use .dropna() to handle missing values or .fillna() with appropriate defaults
- Validate column names exist before accessing: if 'column' in df.columns
- Check data types before operations: df['col'].dtype, isinstance() checks
- Handle edge cases: empty results, single row/column DataFrames, all NaN columns
- Use .copy() when modifying DataFrames to avoid SettingWithCopyWarning
VARIABLE & TYPE HANDLING:
- Use descriptive variable names (avoid single letters in complex operations)
- Ensure all variables are defined before use - initialize with defaults
- Convert pandas/numpy objects to proper Python types before operations
- Convert datetime/period objects appropriately: .astype(str), .dt.strftime(), int()
- Always cast to appropriate types for indexing: int(), str(), list()
- CRITICAL: Convert pandas/numpy values to int before list indexing: int(value) for calendar.month_name[int(month_value)]
- Use explicit type conversions rather than relying on implicit casting
PANDAS OPERATIONS:
- Reference DataFrame properly: df['column'] not 'column' in operations
- Use .loc/.iloc correctly for indexing - avoid chained indexing
- Use .reset_index() after groupby operations when needed for clean DataFrames
- Sort results for consistent output: .sort_values(), .sort_index()
- Use .round() for numerical results to avoid excessive decimals
- Chain operations carefully - split complex chains for readability
MATPLOTLIB & PLOTTING:
- Always call plt.close() after saving plots to prevent memory leaks
- Use descriptive titles, axis labels, and legends
- Handle cases where no data exists for plotting
- Use proper figure sizing: plt.figure(figsize=(width, height))
- Convert datetime indices to strings for plotting if needed
- Use color palettes consistently
ERROR PREVENTION:
- Use try-except blocks for operations that might fail
- Check denominators before division operations
- Validate array/list lengths before indexing
- Use .get() method for dictionary access with defaults
- Handle timezone-aware vs naive datetime objects consistently
- Use proper string formatting and encoding for text output
TECHNICAL REQUIREMENTS:
- Save final result in variable called 'answer'
- For TEXT: Store the direct answer as a string in 'answer'
- For PLOTS: Save with unique filename f"plot_{uuid.uuid4().hex[:8]}.png" and store filename in 'answer'
- For DATAFRAMES: Store the pandas DataFrame directly in 'answer' (e.g., answer = result_df)
- Always use .iloc or .loc properly for pandas indexing
- Close matplotlib figures with plt.close() to prevent memory leaks
- Use proper column name checks before accessing columns
- For dataframes, ensure proper column names and sorting for readability
"""

# Fixed prefix of the code template: imports, plot styling and data loading.
# It is shown to the model and executed verbatim before the model's completion.
_CODE_TEMPLATE_HEADER = """
import pandas as pd
import matplotlib.pyplot as plt
import uuid
import calendar
import numpy as np
# Set professional matplotlib styling
plt.rcParams.update({
    'font.size': 12,
    'figure.dpi': 400,
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'axes.edgecolor': '#e2e8f0',
    'axes.linewidth': 1.2,
    'axes.labelcolor': '#374151',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'axes.grid': True,
    'grid.color': '#f1f5f9',
    'grid.linewidth': 0.8,
    'grid.alpha': 0.7,
    'xtick.color': '#6b7280',
    'ytick.color': '#6b7280',
    'text.color': '#374151',
    'figure.figsize': [12, 6],
    'axes.prop_cycle': plt.cycler('color', ['#3b82f6', '#ef4444', '#10b981', '#f59e0b', '#8b5cf6', '#06b6d4'])
})
df = pd.read_csv("Data.csv")
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
# Available columns and data types:
"""


def _build_template_code(question, df_preview):
    """Return the template code (without markdown fences) for *question*.

    The result starts and ends with a newline. It lists ``df_preview.dtypes``
    as comments and embeds the question as a single comment line.
    """
    dtype_comment = "\n".join("# " + line for line in str(df_preview.dtypes).split("\n"))
    # This text is executed, not just shown: collapse all whitespace so a multi-line
    # question cannot break out of its comment line and inject Python statements.
    question_line = " ".join(question.split())
    return (
        _CODE_TEMPLATE_HEADER
        + dtype_comment + "\n"
        + f"# Question: {question_line}\n"
        + "# Generate code to answer the question and save result in 'answer' variable\n"
        + "# If creating a plot, save it with a unique filename and store the filename in 'answer'\n"
        + "# If returning text/numbers, store the result directly in 'answer'\n"
    )


def _extract_code(llm_output):
    """Pull the Python code out of a model reply.

    Prefers a ```python fenced block, then any ``` fenced block (dropping a
    language-tag line), else returns the reply unchanged.
    """
    text = llm_output
    # Some reasoning models prefix their answer with a <think>...</think> section,
    # which may itself contain draft code; only look at what follows it.
    if "</think>" in text:
        text = text.split("</think>", 1)[1]
    if "```python" in text:
        return text.split("```python", 1)[1].split("```", 1)[0]
    fence_parts = text.split("```")
    if len(fence_parts) >= 3:
        body = fence_parts[1]
        first_line, newline, rest = body.partition("\n")
        # A bare word right after the fence (e.g. "py") is a language tag, not code.
        if newline and first_line.strip().isidentifier():
            body = rest
        return body
    return text


def _run_generated_code(full_code):
    """Execute *full_code* and return the value it binds to ``answer``.

    Returns a fixed explanatory string when the code never sets ``answer``.
    Exceptions raised by the code propagate to the caller.
    """
    # One dict serves as both globals and locals. With separate dicts, names assigned
    # at the code's top level land in the locals dict and are invisible inside any
    # function, lambda or generator expression the code defines (spurious NameError).
    # SECURITY NOTE(review): this runs model-generated code with full builtins and `os`;
    # it is not a sandbox.
    namespace = {
        'pd': pd,
        'plt': plt,
        'os': os,
        'uuid': __import__('uuid'),
        'calendar': __import__('calendar'),
        'np': __import__('numpy'),
    }
    exec(full_code, namespace)
    return namespace.get('answer', "Code executed but no result was saved in 'answer' variable")


def _describe_code_error(error_msg):
    """Turn an error raised by generated code into a user-facing explanation."""
    lowered = error_msg.lower()
    message = "I encountered an error while analyzing your data. "
    if "unmatched" in lowered or "invalid syntax" in lowered:
        message += "There was a syntax error in the generated code (missing brackets or quotes). Please try rephrasing your question or try again."
    elif "not defined" in lowered:
        message += "There was a variable naming error in the generated code. Please try asking the question again."
    elif "has no attribute" in lowered:
        message += "There was an issue accessing data properties. Please try a simpler version of your question."
    elif "division by zero" in lowered:
        message += "The calculation involved division by zero, possibly due to missing data. Please try a different time period or location."
    elif "empty" in lowered or "no data" in lowered:
        message += "No relevant data was found for your query. Please try adjusting the time period, location, or criteria."
    else:
        message += f"Technical error: {error_msg}"
    message += "\n\n💡 **Suggestions:**\n- Try rephrasing your question\n- Use simpler terms\n- Check if the data exists for your specified criteria"
    return message


def _describe_groq_error(error_msg):
    """Map a Groq client/API error to ``(user-facing text, logged error)``."""
    lowered = error_msg.lower()
    if "organization_restricted" in lowered or "unauthorized" in lowered:
        return (
            "API Key Error: Your Groq API key appears to be invalid, expired, or restricted. Please check your API key in the .env file.",
            f"API key validation failed: {error_msg}",
        )
    if "rate_limit" in lowered:
        return "Rate limit exceeded. Please wait a moment and try again.", "Rate limit exceeded"
    return f"API Connection Error: {error_msg}", error_msg


def _describe_unexpected_error(error_msg):
    """Map any other error to ``(user-facing text, logged error)``."""
    if "organization_restricted" in error_msg:
        return (
            "API Organization Restricted: Your API key access has been restricted. Please check your Groq API key or try generating a new one.",
            "API access restricted",
        )
    if "rate_limit" in error_msg.lower():
        return "Rate limit exceeded. Please wait a moment and try again.", "Rate limit exceeded"
    return f"Error: {error_msg}", error_msg


def _failure_response(question, model_name, start_time, content, error_msg, logged_content=None, generated_code=""):
    """Log a failed interaction and build the assistant reply describing it.

    ``logged_content`` overrides what is logged as the response when it should
    differ from the ``content`` shown to the user.
    """
    execution_time = (datetime.now() - start_time).total_seconds()
    log_interaction(
        user_query=question,
        model_name=model_name,
        response_content=content if logged_content is None else logged_content,
        generated_code=generated_code,
        execution_time=execution_time,
        error_message=error_msg,
        is_image=False,
    )
    return {
        "role": "assistant",
        "content": content,
        "gen_code": generated_code,
        "ex_code": generated_code,
        "last_prompt": question,
        "error": error_msg,
    }


def ask_question(model_name, question):
    """Answer *question* about ``Data.csv`` by having an LLM write pandas code and running it.

    The model named by ``model_name`` (a key of ``models``; "gemini-pro" uses Google,
    everything else uses Groq) completes a fixed code template. The completion is
    executed and whatever it stores in ``answer`` is returned. Every outcome,
    success or failure, is logged via ``log_interaction``.

    Returns:
        dict with keys ``role`` ("assistant"), ``content`` (the answer: a string,
        a plot filename, a DataFrame, ... or a user-facing error message),
        ``gen_code``/``ex_code`` (the executed code, "" when none was produced),
        ``last_prompt`` (the question) and ``error`` (None on success, else a message).
    """
    start_time = datetime.now()
    try:
        # Reload environment variables so keys edited while the app runs are picked up.
        load_dotenv(override=True)
        fresh_groq_token = os.getenv("GROQ_API_KEY")
        fresh_gemini_token = os.getenv("GEMINI_TOKEN")
        print(f"ask_question - Fresh Groq Token: {'Present' if fresh_groq_token else 'Missing'}")

        use_gemini = model_name == "gemini-pro"
        if use_gemini:
            if not fresh_gemini_token or fresh_gemini_token.strip() == "":
                return _failure_response(
                    question, model_name, start_time,
                    content="Gemini API token not available or empty. Please set GEMINI_TOKEN in your environment variables.",
                    error_msg="Missing or empty API token",
                    logged_content="Gemini API token not available or empty",
                )
            llm = ChatGoogleGenerativeAI(
                model=models[model_name],
                google_api_key=fresh_gemini_token,
                temperature=0,
            )
        else:
            if not fresh_groq_token or fresh_groq_token.strip() == "":
                return _failure_response(
                    question, model_name, start_time,
                    content="Groq API token not available or empty. Please set GROQ_API_KEY in your environment variables and restart the application.",
                    error_msg="Missing or empty API token",
                    logged_content="Groq API token not available or empty",
                )
            try:
                llm = ChatGroq(
                    model=models[model_name],
                    api_key=fresh_groq_token,
                    temperature=0.1,
                )
            except Exception as api_error:
                content, log_error_msg = _describe_groq_error(str(api_error))
                return _failure_response(question, model_name, start_time, content=content, error_msg=log_error_msg)

        if not os.path.exists("Data.csv"):
            return _failure_response(
                question, model_name, start_time,
                content="Data.csv file not found. Please ensure the data file is in the correct location.",
                error_msg="Data file not found",
                logged_content="Data.csv file not found",
            )

        # dtypes come from the full file (not just 5 rows) so they match what the
        # template's own read_csv will produce.
        df_preview = pd.read_csv("Data.csv")
        df_preview["Timestamp"] = pd.to_datetime(df_preview["Timestamp"])
        template_code = _build_template_code(question, df_preview.head(5))
        query = (
            f"{_SYSTEM_PROMPT}\n"
            "Complete the following code to answer the user's question:\n"
            f"```python{template_code}```\n"
        )

        # No separate "Test" call: the real request validates the key, so each
        # question costs one API call instead of two.
        if use_gemini:
            response = llm.invoke(query)
        else:
            try:
                response = llm.invoke(query)
            except Exception as api_error:
                content, log_error_msg = _describe_groq_error(str(api_error))
                return _failure_response(question, model_name, start_time, content=content, error_msg=log_error_msg)
        llm_output = response.content

        full_code = ""
        try:
            code_part = _extract_code(llm_output)
            full_code = f"\n{template_code}\n{code_part}\n"
            answer_result = _run_generated_code(full_code)
        except Exception as code_error:
            # Discard half-built figures so they don't leak into the next request.
            plt.close('all')
            error_msg = str(code_error)
            return _failure_response(
                question, model_name, start_time,
                content=_describe_code_error(error_msg),
                error_msg=error_msg,
                generated_code=full_code,
            )

        execution_time = (datetime.now() - start_time).total_seconds()
        is_image = isinstance(answer_result, str) and answer_result.endswith(('.png', '.jpg', '.jpeg'))
        log_interaction(
            user_query=question,
            model_name=model_name,
            response_content=str(answer_result),
            generated_code=full_code,
            execution_time=execution_time,
            error_message=None,
            is_image=is_image,
        )
        return {
            "role": "assistant",
            "content": answer_result,
            "gen_code": full_code,
            "ex_code": full_code,
            "last_prompt": question,
            "error": None,
        }
    except Exception as e:
        content, log_error_msg = _describe_unexpected_error(str(e))
        return _failure_response(question, model_name, start_time, content=content, error_msg=log_error_msg)