# pandasai_chart / app.py_mark_1_fully-function_defect_every_time_generating charts.txt
# srivatsavdamaraju's picture
# Rename app.py to app.py_mark_1_fully-function_defect_every_time_generating charts.txt
# 1cfe2c6 verified
import gradio as gr
import pandas as pd
import os
import tempfile
import matplotlib.pyplot as plt
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
from dotenv import load_dotenv
import io
import base64
# Load environment variables (python-dotenv also picks up a local .env file).
load_dotenv()

# The Groq API key is read from the environment instead of being embedded in
# the source, so the secret never ends up in version control. Provide it via
# the GROQ_API_KEY environment variable or a .env file next to this script.
# NOTE(review): a key was previously committed in this file; it should be
# treated as leaked and revoked in the Groq console.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# File extensions treated as chart images when scanning the charts directory.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _snapshot_charts(directory):
    """Return a ``{path: mtime}`` mapping for the image files in *directory*."""
    snapshot = {}
    try:
        names = os.listdir(directory)
    except OSError:
        return snapshot
    for name in names:
        if not name.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        path = os.path.join(directory, name)
        try:
            snapshot[path] = os.path.getmtime(path)
        except OSError:
            # The file disappeared between listing and stat; skip it.
            continue
    return snapshot


def _find_new_chart(result, directory, before):
    """Return the chart produced by the current query, or ``None``.

    A chart only counts when it belongs to this query: either ``result`` is
    itself a path to an existing image file, or an image in *directory* was
    created/modified since the *before* snapshot was taken.
    """
    if (
        isinstance(result, str)
        and result.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(result)
    ):
        return result

    after = _snapshot_charts(directory)
    fresh = {path: mtime for path, mtime in after.items() if before.get(path) != mtime}
    if not fresh:
        return None
    # Several files may have changed; keep the most recently written one.
    return max(fresh, key=fresh.get)


def process_data(file, query):
    """
    Process the uploaded CSV file with the given query using PandasAI.

    Args:
        file: The uploaded file, either an object exposing the path as
            ``.name`` or a plain path string. ``None`` means nothing uploaded.
        query: Natural-language question or plotting request.

    Returns:
        A ``(text, chart_path)`` tuple. ``text`` is the formatted answer, a
        validation message, or an error message. ``chart_path`` is the path
        of an image generated by *this* query, or ``None`` when the query
        produced no chart (or failed).
    """
    # Validate inputs
    if file is None:
        return "Please upload a CSV file.", None
    if not query or not query.strip():
        return "Please enter a query.", None

    try:
        # Read the CSV file (accept both file-like wrappers and plain paths)
        csv_path = getattr(file, "name", file)
        df_data = pd.read_csv(csv_path)

        # Initialize Groq LLM
        llm = ChatGroq(
            model_name="mistral-saba-24b",  # Using a more stable model
            api_key=GROQ_API_KEY,
            temperature=0
        )

        # Create SmartDataframe
        charts_dir = tempfile.gettempdir()
        df = SmartDataframe(df_data, config={
            "llm": llm,
            "save_charts": True,
            "save_charts_path": charts_dir,
            "open_charts": False,
            "enable_cache": False
        })

        # Remember which images already exist so that stale charts from
        # earlier queries (or unrelated files) are not shown for this one.
        charts_before = _snapshot_charts(charts_dir)

        # Process the query
        result = df.chat(query)

        # Handle different types of results
        if result is None:
            return "No result returned. Please try a different query.", None

        # Only report a chart that this query actually produced
        chart_path = _find_new_chart(result, charts_dir, charts_before)

        # Format the text result
        if isinstance(result, pd.DataFrame):
            result_text = f"Query Result:\n\n{result.to_string()}"
        elif isinstance(result, (int, float)):
            result_text = f"Query Result: {result}"
        else:
            result_text = f"Query Result:\n{str(result)}"
        return result_text, chart_path
    except Exception as e:
        # UI boundary: surface any failure as a message instead of crashing.
        error_msg = f"Error processing query: {str(e)}"
        return error_msg, None
# Markdown shown at the top of the page. Kept at column 0 so no line is
# accidentally rendered as an indented code block.
_INTRO_MARKDOWN = """
# 📊 PandasAI Data Analysis with Groq
Upload a CSV file and ask questions about your data. The AI will analyze and visualize your data accordingly.
**Instructions:**
1. Upload your CSV file
2. Enter your query (e.g., "Show top 5 countries by population", "Create a bar plot of sales by region")
3. Click Submit to get results
"""

# Markdown listing sample queries, shown below the input/output row.
_EXAMPLES_MARKDOWN = """
### 💡 Example Queries:
- "Which are the top 5 countries by population?"
- "Create a bar plot of the top 10 countries by population"
- "Show me a pie chart of the top 5 countries"
- "Calculate the total population of the top 3 countries"
- "What is the average population across all countries?"
- "Create a scatter plot showing the relationship between two columns"
"""


def create_interface():
    """
    Create the Gradio interface.

    Builds a Blocks layout with a CSV upload, a query textbox and a submit
    button on the left, and the text result plus generated chart on the
    right. Both the button click and the textbox submit event call
    ``process_data``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                file_input = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"]
                )
                query_input = gr.Textbox(
                    label="Your Query - Ask questions about your data",
                    placeholder="e.g., 'Which are the top 5 countries by population?' or 'Create a bar plot of the top 5 countries'",
                    lines=3
                )
                submit_btn = gr.Button("🚀 Submit Query", variant="primary")

            with gr.Column(scale=2):
                # Output components
                result_output = gr.Textbox(
                    label="Analysis Result",
                    lines=10,
                    interactive=False
                )
                chart_output = gr.Image(
                    label="Generated Visualization"
                )

        # Example queries
        gr.Markdown(_EXAMPLES_MARKDOWN)

        # Event handlers: the button and the textbox submit event share one
        # handler specification so the two cannot drift apart.
        handler = {
            "fn": process_data,
            "inputs": [file_input, query_input],
            "outputs": [result_output, chart_output],
        }
        submit_btn.click(**handler)
        query_input.submit(**handler)

    return demo
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 without
    # creating a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)