# pandasai_chart / app.py
# Web-page header captured together with the file (not part of the program):
#   srivatsavdamaraju's picture | Update app.py | 2e3c703 verified | raw | history blame | 8.15 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import os
import tempfile
import base64
from io import BytesIO
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
# Groq API key.
# Prefer the GROQ_API_KEY environment variable so the real key never has to be
# committed to source control. When the variable is unset or blank, API_KEY
# falls back to the placeholder below, which the rest of the app treats as
# "not configured".
_API_KEY_PLACEHOLDER = "gsk_YOUR_ACTUAL_API_KEY_HERE"
API_KEY = (os.environ.get("GROQ_API_KEY") or "").strip() or _API_KEY_PLACEHOLDER

# Module-level state shared by the Gradio callbacks.
current_df = None  # pandas DataFrame loaded from the most recent upload
llm = None  # Groq chat client, created lazily by initialize_llm()
def initialize_llm(model_name=None):
    """Create the Groq chat client and store it in the module-level ``llm``.

    Args:
        model_name: Groq model identifier. When omitted, the ``GROQ_MODEL``
            environment variable is used, and failing that a built-in default.

    Returns:
        A ``(status_message, client)`` tuple. ``client`` is ``None`` when the
        API key is still the placeholder or the client could not be created;
        in both cases the module-level ``llm`` is left untouched.
    """
    global llm

    if API_KEY == "gsk_YOUR_ACTUAL_API_KEY_HERE":
        return "❌ Please replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual Groq API key", None

    # NOTE(review): the previous hard-coded model, "mixtral-8x7b-32768", appears
    # on Groq's deprecations list as retired — confirm the default below is
    # available to your account, or set GROQ_MODEL / pass model_name explicitly.
    model_name = model_name or os.environ.get("GROQ_MODEL") or "llama-3.3-70b-versatile"

    try:
        llm = ChatGroq(
            model_name=model_name,
            api_key=API_KEY,
            temperature=0,
        )
    except Exception as e:
        return f"❌ Failed to initialize Groq LLM: {str(e)}", None

    return "✅ Groq LLM initialized successfully", llm
def process_csv(file):
    """Load an uploaded CSV into the module-level ``current_df``.

    Args:
        file: The upload handed over by the file component. May be ``None``
            (nothing uploaded), a path (``str`` / ``os.PathLike``), or an
            object whose ``name`` attribute holds the path.

    Returns:
        A ``(status, preview_html, info_markdown)`` tuple. The last two items
        are ``None`` when nothing was uploaded or the file could not be read.
    """
    global current_df

    if file is None:
        return "No file uploaded", None, None

    # A plain path must be used as-is: a str has no ``.name`` at all, and a
    # pathlib.Path's ``.name`` is only the final component, not the full path.
    if isinstance(file, (str, os.PathLike)):
        csv_path = file
    else:
        csv_path = getattr(file, "name", file)

    try:
        # Read the CSV file
        current_df = pd.read_csv(csv_path)

        # Create preview (first rows only)
        preview = current_df.head().to_html(classes='table table-striped', table_id='data-preview')

        # Create info; str() guards against non-string column labels
        row_count, column_count = current_df.shape
        column_names = ', '.join(str(label) for label in current_df.columns)
        info = "\n".join([
            "",
            "**File Info:**",
            f"- Shape: {row_count} rows × {column_count} columns",
            f"- Columns: {column_names}",
            "",
        ])

        return "✅ CSV file loaded successfully", preview, info
    except Exception as e:
        return f"❌ Error reading CSV: {str(e)}", None, None
# File extensions treated as "the answer is a saved chart image".
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")


def _save_figure_to_png(figure):
    """Render a figure-like object to a new temporary PNG and return its path."""
    # Reserve a unique file name and close the handle right away, so no file
    # descriptor is leaked and savefig can reopen the path on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        png_path = tmp.name
    try:
        figure.savefig(png_path, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Free the figure's memory even when saving failed. Only real
        # Matplotlib figures are closed; other objects that merely expose
        # savefig() are not valid arguments for plt.close().
        if isinstance(figure, plt.Figure):
            plt.close(figure)
    return png_path


def _format_result(result):
    """Convert a PandasAI answer into a ``(markdown_text, image_path)`` pair."""
    if isinstance(result, str):
        # NOTE(review): PandasAI appears to answer plot requests with the path
        # of the saved image as a plain string — confirm for the installed
        # version. Such a path is shown as a chart rather than echoed as text.
        candidate = result.strip()
        if candidate.lower().endswith(_IMAGE_SUFFIXES) and os.path.isfile(candidate):
            return "📈 **Chart Generated:**", candidate
        # Text response
        return f"📒 **Response:**\n{result}", None

    if hasattr(result, 'savefig'):
        # Matplotlib figure (or compatible object)
        try:
            return "📈 **Chart Generated:**", _save_figure_to_png(result)
        except Exception as chart_error:
            return f"❌ Error saving chart: {str(chart_error)}", None

    if isinstance(result, pd.DataFrame):
        # DataFrame result
        html_table = result.to_html(classes='table table-striped', max_rows=100)
        return f"📊 **Data Result:**\n{html_table}", None

    # Other types of results
    return f"📊 **Result:**\n{str(result)}", None


def _describe_error(error):
    """Build the user-facing error message, with a hint for known failures."""
    details = str(error)
    error_msg = f"❌ Error: {details}"

    # Provide specific error guidance
    if "403" in details:
        # NOTE(review): this icon was restored from mis-encoded text whose
        # final byte was lost; 🔍 is the most plausible original — confirm.
        error_msg += "\n\n🔍 **403 Forbidden Error** - This usually means:\n"
        error_msg += "- Invalid API key\n"
        error_msg += "- API key doesn't have permission for this model\n"
        error_msg += "- Rate limit exceeded\n"
        error_msg += "- Model name is incorrect"
    elif "rate limit" in details.lower():
        error_msg += "\n\n⏰ **Rate Limit** - Please wait a moment before trying again"
    elif "timeout" in details.lower():
        error_msg += "\n\n⏱️ **Timeout** - The query took too long. Try a simpler request"

    return error_msg


def chat_with_data(query):
    """Answer a natural-language question about the uploaded CSV.

    Args:
        query: The user's question or chart request. ``None`` and
            whitespace-only input are rejected with a prompt to enter a query.

    Returns:
        A ``(markdown_text, image_path)`` tuple. ``image_path`` is ``None``
        unless the answer is a chart. Failures are reported in
        ``markdown_text`` rather than raised.
    """
    global current_df, llm

    if current_df is None:
        return "❌ Please upload a CSV file first", None

    if llm is None:
        # Lazy retry: the client may not have been created at startup.
        status, _ = initialize_llm()
        if llm is None:
            return status, None

    # Guard against None as well as blank text before calling .strip().
    if not query or not query.strip():
        return "❌ Please enter a query", None

    try:
        # Create a fresh temporary directory for any charts PandasAI saves.
        # It is not removed here because a returned chart path must stay
        # readable after this function returns.
        temp_dir = tempfile.mkdtemp()

        # Create SmartDataframe
        sdf = SmartDataframe(
            current_df,
            config={
                "llm": llm,
                "verbose": True,
                "save_charts": True,
                "save_charts_path": temp_dir,
                "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"],
            },
        )

        # Process the query and map the answer onto the two output components
        return _format_result(sdf.chat(query))
    except Exception as e:
        return _describe_error(e), None
def get_debug_info():
    """Report whether an API key has been configured.

    Returns:
        A status string. When ``API_KEY`` is set to something other than the
        placeholder, the string includes only the first four characters of
        the key; otherwise it asks the user to replace the placeholder.
    """
    if API_KEY and API_KEY != "gsk_YOUR_ACTUAL_API_KEY_HERE":
        # This text is rendered in the UI, so reveal as little of the secret
        # as possible: four characters is the length of the "gsk_" prefix the
        # placeholder itself carries.
        return f"✅ API Key loaded successfully\nKey starts with: {API_KEY[:4]}..."
    return "❌ Replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual API key"
# Initialize LLM on startup; the resulting message is shown as the initial
# value of the status box below.
init_status, _ = initialize_llm()

# Create Gradio interface
# NOTE(review): the emoji in the titles and headings below were restored from
# mis-encoded text. For "Upload CSV File" (📁) and "Example Queries" (📝) the
# final byte was lost, so those two are the most plausible originals — confirm.
with gr.Blocks(title="📊 CSV Chat with Groq + PandasAI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📊 Chat with Your CSV using PandasAI + Groq")

    with gr.Row():
        with gr.Column(scale=2):
            # File upload section
            gr.Markdown("## 📁 Upload CSV File")
            file_input = gr.File(
                label="Upload your CSV file",
                file_types=[".csv"],
                type="filepath",
            )
            upload_status = gr.Textbox(
                label="Upload Status",
                interactive=False,
                value=init_status,
            )

            # Data preview section
            gr.Markdown("## 📋 Data Preview")
            data_preview = gr.HTML(label="Data Preview")
            data_info = gr.Markdown()

        with gr.Column(scale=1):
            # Debug and help section
            gr.Markdown("## 🔧 Debug Info")
            debug_btn = gr.Button("Show Debug Info")
            debug_info = gr.Textbox(label="Debug Information", interactive=False)

            gr.Markdown("## 📝 Example Queries")
            gr.Markdown("""
            - "Show me the first 10 rows"
            - "What are the column names?"
            - "Create a histogram of [column_name]"
            - "Show me the summary statistics"
            - "Plot the top 5 values in [column_name]"
            - "Create a bar chart showing [column1] vs [column2]"
            """)

    # Chat section
    gr.Markdown("## 💬 Chat with Your Data")
    with gr.Row():
        query_input = gr.Textbox(
            label="Ask a question or request a chart",
            placeholder="What would you like to know about your data?",
            lines=3,
            scale=4,
        )
        submit_btn = gr.Button("Submit Query", variant="primary", scale=1)

    # Results section
    with gr.Row():
        with gr.Column():
            response_output = gr.Markdown(label="Response")
        with gr.Column():
            chart_output = gr.Image(label="Generated Chart", type="filepath")

    # Event handlers
    file_input.change(
        fn=process_csv,
        inputs=[file_input],
        outputs=[upload_status, data_preview, data_info],
    )
    debug_btn.click(
        fn=get_debug_info,
        outputs=[debug_info],
    )

    # The button click and pressing Enter in the textbox run the same query
    # handler, so both are wired from one shared set of arguments.
    query_wiring = {
        "fn": chat_with_data,
        "inputs": [query_input],
        "outputs": [response_output, chart_output],
    }
    submit_btn.click(**query_wiring)
    query_input.submit(**query_wiring)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": False,  # flip to True to get a public link
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)