Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import tempfile | |
import matplotlib.pyplot as plt | |
from pandasai import SmartDataframe | |
from langchain_groq.chat_models import ChatGroq | |
from dotenv import load_dotenv | |
import io | |
import base64 | |
# Load variables from a local .env file (if present) into the process
# environment so the lookup below can see them.
load_dotenv()

# SECURITY: the Groq API key is read from the environment (or .env) instead of
# being embedded in source. A key that has been committed to a repository is
# visible to every reader and must be treated as compromised -- revoke the
# previously committed key and set GROQ_API_KEY as a secret/environment
# variable. Falls back to an empty string when the variable is not set.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
# File suffixes treated as generated charts (compared case-insensitively).
_CHART_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _snapshot_charts(directory):
    """Map each image file directly inside *directory* to ``(mtime_ns, size)``."""
    snapshot = {}
    with os.scandir(directory) as entries:
        for entry in entries:
            if not entry.name.lower().endswith(_CHART_EXTENSIONS):
                continue
            try:
                if not entry.is_file():
                    continue
                info = entry.stat()
            except OSError:
                # The file vanished between listing and stat; ignore it.
                continue
            snapshot[entry.path] = (info.st_mtime_ns, info.st_size)
    return snapshot


def _find_new_chart(result, directory, before):
    """Return the image created or rewritten since *before*, or ``None``."""
    # NOTE(review): PandasAI appears to return the saved chart's path as the
    # chat result for plot queries -- confirm against the installed version.
    if (
        isinstance(result, str)
        and result.lower().endswith(_CHART_EXTENSIONS)
        and os.path.isfile(result)
    ):
        return result

    after = _snapshot_charts(directory)
    # Only files that are new or changed since the pre-query snapshot count;
    # older images in the shared temp directory belong to earlier queries or
    # to unrelated programs.
    fresh = [path for path, stamp in after.items() if before.get(path) != stamp]
    if not fresh:
        return None
    return max(fresh, key=lambda path: after[path][0])


def process_data(file, query):
    """Run a natural-language query against an uploaded CSV using PandasAI.

    Args:
        file: The uploaded CSV. Either an object exposing the on-disk path as
            ``.name`` or the path itself as a string; ``None`` means nothing
            was uploaded.
        query: The question to ask about the data.

    Returns:
        A ``(text, chart_path)`` tuple. ``text`` is the formatted answer, a
        validation message, or an error message. ``chart_path`` is the path
        of an image produced by this query, or ``None`` when there is none.
    """
    # Validate inputs before doing any work.
    if file is None:
        return "Please upload a CSV file.", None
    if not query or not query.strip():
        return "Please enter a query.", None

    try:
        # Accept both file-like uploads (path in ``.name``) and plain paths.
        csv_path = getattr(file, "name", file)
        df_data = pd.read_csv(csv_path)

        llm = ChatGroq(
            model_name="mistral-saba-24b",  # Using a more stable model
            api_key=GROQ_API_KEY,
            temperature=0,
        )

        chart_dir = tempfile.gettempdir()
        df = SmartDataframe(
            df_data,
            config={
                "llm": llm,
                "save_charts": True,
                "save_charts_path": chart_dir,
                "open_charts": False,
                "enable_cache": False,
            },
        )

        # Snapshot existing images first so only charts written by *this*
        # query are reported afterwards.
        charts_before = _snapshot_charts(chart_dir)
        result = df.chat(query)

        if result is None:
            return "No result returned. Please try a different query.", None

        chart_path = _find_new_chart(result, chart_dir, charts_before)

        # Format the text result.
        if isinstance(result, pd.DataFrame):
            result_text = f"Query Result:\n\n{result.to_string()}"
        elif isinstance(result, (int, float)):
            result_text = f"Query Result: {result}"
        else:
            result_text = f"Query Result:\n{result}"

        return result_text, chart_path
    except Exception as e:
        # UI boundary: report any failure as text instead of letting the
        # exception escape the handler.
        return f"Error processing query: {str(e)}", None
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    Layout: an introductory Markdown panel, then a row with the inputs (CSV
    upload, query box, submit button) in a narrow column and the outputs
    (result text, chart image) in a wider one, followed by example queries.
    Both the button click and pressing Enter in the query box invoke
    ``process_data``.
    """
    intro_md = (
        "\n"
        "# π PandasAI Data Analysis with Groq\n"
        "Upload a CSV file and ask questions about your data. "
        "The AI will analyze and visualize your data accordingly.\n"
        "**Instructions:**\n"
        "1. Upload your CSV file\n"
        '2. Enter your query (e.g., "Show top 5 countries by population", '
        '"Create a bar plot of sales by region")\n'
        "3. Click Submit to get results\n"
    )
    examples_md = (
        "\n"
        "### π‘ Example Queries:\n"
        '- "Which are the top 5 countries by population?"\n'
        '- "Create a bar plot of the top 10 countries by population"\n'
        '- "Show me a pie chart of the top 5 countries"\n'
        '- "Calculate the total population of the top 3 countries"\n'
        '- "What is the average population across all countries?"\n'
        '- "Create a scatter plot showing the relationship between two columns"\n'
    )

    with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as app:
        gr.Markdown(intro_md)

        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column(scale=1):
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                question_box = gr.Textbox(
                    label="Your Query - Ask questions about your data",
                    placeholder="e.g., 'Which are the top 5 countries by population?' or 'Create a bar plot of the top 5 countries'",
                    lines=3,
                )
                run_button = gr.Button("π Submit Query", variant="primary")

            # Right column: what the analysis produces.
            with gr.Column(scale=2):
                answer_box = gr.Textbox(
                    label="Analysis Result",
                    lines=10,
                    interactive=False,
                )
                chart_view = gr.Image(label="Generated Visualization")

        gr.Markdown(examples_md)

        # Wire the button click first, then the Enter-key submit, to the same
        # handler with the same inputs and outputs.
        for trigger in (run_button.click, question_box.submit):
            trigger(
                fn=process_data,
                inputs=[csv_upload, question_box],
                outputs=[answer_box, chart_view],
            )

    return app
if __name__ == "__main__":
    # Build the UI, then serve it on address 0.0.0.0, port 7860, without
    # requesting a Gradio share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)