# NOTE: page-capture banner, not program code: "Spaces: Sleeping".
import gradio as gr | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use('Agg') # Use non-interactive backend | |
import os | |
import tempfile | |
import base64 | |
from io import BytesIO | |
from pandasai import SmartDataframe | |
from langchain_groq.chat_models import ChatGroq | |
# Groq API key.
# Prefer the GROQ_API_KEY environment variable so the secret never has to be
# committed with the source. The placeholder is only a fallback; the checks in
# initialize_llm() and get_debug_info() treat it as "not configured".
# NOTE(review): if this runs on a hosted platform, store the key as a
# secret/environment variable there rather than editing this line.
API_KEY = (
    os.environ.get("GROQ_API_KEY", "").strip()
    or "gsk_YOUR_ACTUAL_API_KEY_HERE"  # Replace with your real API key
)

# Global variables shared by the UI callbacks
current_df = None  # DataFrame of the most recently loaded CSV (set by process_csv)
llm = None  # Groq chat model (set by initialize_llm)
def initialize_llm():
    """Create the Groq chat model and store it in the module-level ``llm``.

    Returns:
        A ``(status, model)`` tuple. ``status`` is a message meant for the UI.
        ``model`` is the ``ChatGroq`` instance on success, or ``None`` when no
        usable API key is configured or construction raised.
    """
    global llm
    placeholder = "gsk_YOUR_ACTUAL_API_KEY_HERE"

    # An empty/None key is as unusable as the placeholder. get_debug_info()
    # already reports both as "not configured", so reject both here as well.
    if not API_KEY or API_KEY == placeholder:
        return (
            f"❌ Please replace '{placeholder}' with your actual Groq API key",
            None,
        )

    try:
        llm = ChatGroq(
            # NOTE(review): confirm this model id is still served by Groq;
            # hosted model ids get retired over time.
            model_name="mixtral-8x7b-32768",
            api_key=API_KEY,
            temperature=0,
        )
    except Exception as e:
        # Surface the failure as a status string instead of crashing startup.
        return f"❌ Failed to initialize Groq LLM: {e}", None

    return "✅ Groq LLM initialized successfully", llm
def process_csv(file):
    """Load an uploaded CSV into the module-level ``current_df``.

    Args:
        file: The upload to read. Either a filesystem path (``str`` or
            ``os.PathLike``) or an object exposing the path as ``.name``.
            ``None`` means nothing was uploaded.

    Returns:
        A ``(status, preview_html, info_markdown)`` tuple. ``preview_html`` is
        the first rows rendered as an HTML table and ``info_markdown`` lists
        the shape and column names. Both are ``None`` when no file was given
        or the file could not be read.
    """
    global current_df
    if file is None:
        return "No file uploaded", None, None

    try:
        # Plain paths are used directly; any other object must carry `.name`.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name

        # Read the CSV file
        current_df = pd.read_csv(csv_path)

        # Create preview
        preview = current_df.head().to_html(
            classes="table table-striped",
            table_id="data-preview",
        )

        # Create info (str() keeps the join safe for non-string column labels)
        n_rows, n_cols = current_df.shape
        column_list = ", ".join(map(str, current_df.columns))
        info = (
            "\n**File Info:**\n"
            f"- Shape: {n_rows} rows × {n_cols} columns\n"
            f"- Columns: {column_list}\n"
        )
        return "✅ CSV file loaded successfully", preview, info
    except Exception as e:
        return f"❌ Error reading CSV: {e}", None, None
def chat_with_data(query):
    """Answer a natural-language question about the loaded CSV via PandasAI.

    Args:
        query: The user's question or chart request. ``None`` and
            whitespace-only text are rejected with a status message.

    Returns:
        A ``(markdown_message, chart_path)`` tuple. ``chart_path`` is the path
        of an image file when the answer is a chart, otherwise ``None``.
        Failures are reported through ``markdown_message``; nothing is raised.
    """
    if current_df is None:
        return "❌ Please upload a CSV file first", None

    if llm is None:
        # initialize_llm() stores the model in the module-level `llm`.
        status, _ = initialize_llm()
        if llm is None:
            return status, None

    # Check for None before calling .strip() so a missing value cannot raise.
    if not query or not query.strip():
        return "❌ Please enter a query", None

    try:
        # Create temporary directory for charts. It is left in place because
        # the chart path handed back to the UI may point into it.
        chart_dir = tempfile.mkdtemp()

        # Create SmartDataframe
        sdf = SmartDataframe(
            current_df,
            config={
                "llm": llm,
                "verbose": True,
                "save_charts": True,
                "save_charts_path": chart_dir,
                "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"],
            },
        )

        # Process the query
        return _format_chat_result(sdf.chat(query))
    except Exception as e:
        return _describe_chat_error(e), None


def _is_image_file(text):
    """Return True when *text* is the path of an existing PNG/JPEG file."""
    return text.lower().endswith((".png", ".jpg", ".jpeg")) and os.path.isfile(text)


def _format_chat_result(result):
    """Convert a PandasAI answer into the (markdown, chart_path) pair."""
    if isinstance(result, str):
        # NOTE(review): PandasAI appears to answer plot requests with the saved
        # chart's file path as a plain string; confirm against the installed
        # version. Routing such paths to the image output is safe either way.
        if _is_image_file(result):
            return "📊 **Chart Generated:**", result
        return f"📢 **Response:**\n{result}", None

    if hasattr(result, "savefig"):
        # Matplotlib figure: write it to a PNG that outlives this call.
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as png:
                result.savefig(png, format="png", dpi=150, bbox_inches="tight")
            plt.close(result)  # Close the figure to free memory
            return "📊 **Chart Generated:**", png.name
        except Exception as chart_error:
            return f"❌ Error saving chart: {chart_error}", None

    if isinstance(result, pd.DataFrame):
        html_table = result.to_html(classes="table table-striped", max_rows=100)
        return f"📋 **Data Result:**\n{html_table}", None

    # Other types of results (numbers, lists, ...)
    return f"📝 **Result:**\n{result}", None


def _describe_chat_error(error):
    """Build the user-facing error text, with hints for common API failures."""
    text = str(error)
    lowered = text.lower()
    message = f"❌ Error: {text}"

    if "403" in text:
        message += (
            "\n\n🔒 **403 Forbidden Error** - This usually means:\n"
            "- Invalid API key\n"
            "- API key doesn't have permission for this model\n"
            "- Rate limit exceeded\n"
            "- Model name is incorrect"
        )
    elif "rate limit" in lowered:
        message += "\n\n⏰ **Rate Limit** - Please wait a moment before trying again"
    elif "timeout" in lowered:
        message += "\n\n⏱️ **Timeout** - The query took too long. Try a simpler request"
    return message
def get_debug_info():
    """Report whether a real Groq API key is configured.

    Returns:
        A status string for the debug textbox. When a key is present only its
        first four characters are echoed, so the secret part is not shown to
        whoever clicks the button.
    """
    placeholder = "gsk_YOUR_ACTUAL_API_KEY_HERE"
    if not API_KEY or API_KEY == placeholder:
        return f"❌ Replace '{placeholder}' with your actual API key"

    # Four characters cover the "gsk_" prefix that the placeholder also
    # carries; anything longer would leak part of the secret into the UI.
    return f"✅ API Key loaded successfully\nKey starts with: {API_KEY[:4]}..."
# Initialize LLM on startup; the status text seeds the first status box so a
# missing or invalid key is visible as soon as the page loads.
init_status, _ = initialize_llm()

# Example prompts shown in the side column. Built from a list so the Markdown
# bullets never pick up source indentation.
_EXAMPLE_QUERIES_MD = "\n".join(
    [
        "",
        '- "Show me the first 10 rows"',
        '- "What are the column names?"',
        '- "Create a histogram of [column_name]"',
        '- "Show me the summary statistics"',
        '- "Plot the top 5 values in [column_name]"',
        '- "Create a bar chart showing [column1] vs [column2]"',
        "",
    ]
)

# Create Gradio interface
with gr.Blocks(title="📊 CSV Chat with Groq + PandasAI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📊 Chat with Your CSV using PandasAI + Groq")

    with gr.Row():
        with gr.Column(scale=2):
            # File upload section
            gr.Markdown("## 📁 Upload CSV File")
            file_input = gr.File(
                label="Upload your CSV file",
                file_types=[".csv"],
                type="filepath",
            )
            upload_status = gr.Textbox(
                label="Upload Status",
                interactive=False,
                value=init_status,
            )

            # Data preview section
            gr.Markdown("## 📋 Data Preview")
            data_preview = gr.HTML(label="Data Preview")
            data_info = gr.Markdown()

        with gr.Column(scale=1):
            # Debug and help section
            gr.Markdown("## 🔧 Debug Info")
            debug_btn = gr.Button("Show Debug Info")
            debug_info = gr.Textbox(label="Debug Information", interactive=False)

            gr.Markdown("## 📝 Example Queries")
            gr.Markdown(_EXAMPLE_QUERIES_MD)

    # Chat section
    gr.Markdown("## 💬 Chat with Your Data")
    with gr.Row():
        query_input = gr.Textbox(
            label="Ask a question or request a chart",
            placeholder="What would you like to know about your data?",
            lines=3,
            scale=4,
        )
        submit_btn = gr.Button("Submit Query", variant="primary", scale=1)

    # Results section: text answer on the left, chart (if any) on the right
    with gr.Row():
        with gr.Column():
            response_output = gr.Markdown(label="Response")
        with gr.Column():
            chart_output = gr.Image(label="Generated Chart", type="filepath")

    # Event handlers
    file_input.change(
        fn=process_csv,
        inputs=[file_input],
        outputs=[upload_status, data_preview, data_info],
    )
    debug_btn.click(
        fn=get_debug_info,
        outputs=[debug_info],
    )
    # The button click and the textbox submit event run the same query handler.
    for trigger in (submit_btn.click, query_input.submit):
        trigger(
            fn=chat_with_data,
            inputs=[query_input],
            outputs=[response_output, chart_output],
        )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": False,  # Set to True if you want a public link
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)