"""Gradio front-end for asking natural-language questions about a CSV file.

The uploaded CSV is loaded into pandas, wrapped in a PandasAI
``SmartDataframe`` backed by a Groq-hosted chat model, and the answer
(text and, when one is produced, a chart image) is shown in the browser.
"""

import base64
import io
import os
import shutil
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from langchain_groq.chat_models import ChatGroq
from pandasai import SmartDataframe

# Load environment variables
load_dotenv()

# Model used when the GROQ_MODEL environment variable is unset or empty.
DEFAULT_GROQ_MODEL = "mixtral-8x7b-32768"

# File extensions treated as chart images.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Markdown is kept flush-left so it renders the same whether or not the
# Markdown component strips indentation.
INTRO_MARKDOWN = """\
# 📊 PandasAI Data Analysis with Groq

Upload a CSV file and ask questions about your data. The AI will analyze and visualize your data accordingly.

**Instructions:**
1. Get your Groq API key from [https://console.groq.com/keys](https://console.groq.com/keys)
2. Upload your CSV file
3. Enter your query (e.g., "Show top 5 countries by population", "Create a bar plot of sales by region")
4. Click Submit to get results
"""

EXAMPLES_MARKDOWN = """\
### 💡 Example Queries:
- "Which are the top 5 countries by population?"
- "Create a bar plot of the top 10 countries by population"
- "Show me a pie chart of the top 5 countries"
- "Calculate the total population of the top 3 countries"
- "What is the average population across all countries?"
- "Create a scatter plot showing the relationship between two columns"
"""


def _resolve_upload_path(file):
    """Return the filesystem path of an uploaded file.

    Accepts either a plain path (``str`` / ``os.PathLike``) or an object
    exposing the path as ``.name``, so both shapes of upload value work.
    """
    if isinstance(file, (str, os.PathLike)):
        return file
    return file.name


def _is_image_path(value):
    """Return True if *value* is a string naming an existing image file."""
    return (
        isinstance(value, str)
        and value.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(value)
    )


def _find_latest_chart(directory):
    """Return the most recently modified image under *directory*, or None.

    The search is recursive so charts written into a sub-folder of the
    configured chart directory are still found.
    """
    candidates = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    return max(candidates, key=os.path.getmtime, default=None)


def _format_result(result):
    """Render a PandasAI answer as the text shown in the result box."""
    if isinstance(result, pd.DataFrame):
        return f"Query Result:\n\n{result.to_string()}"
    if isinstance(result, (int, float)):
        return f"Query Result: {result}"
    # Strings and any other object share the same layout.
    return f"Query Result:\n{str(result)}"


def process_data(file, query, api_key):
    """Process the uploaded CSV file with the given query using PandasAI.

    Args:
        file: Upload value from ``gr.File``: a path string or an object
            whose ``.name`` is the path. ``None`` when nothing is uploaded.
        query: Natural-language question or plotting request.
        api_key: Groq API key used for this single request.

    Returns:
        A ``(text, chart_path)`` tuple. ``text`` is the formatted answer or
        a user-facing validation/error message; ``chart_path`` is the path
        of a chart image produced by this query, or ``None``.
    """
    # Validate inputs
    if file is None:
        return "Please upload a CSV file.", None
    if not query or not query.strip():
        return "Please enter a query.", None
    if not api_key or not api_key.strip():
        return "Please enter your Groq API key.", None

    chart_dir = None
    keep_chart_dir = False
    try:
        # Read the CSV file
        df_data = pd.read_csv(_resolve_upload_path(file))

        # Initialize Groq LLM
        llm = ChatGroq(
            model_name=os.getenv("GROQ_MODEL") or DEFAULT_GROQ_MODEL,
            api_key=api_key.strip(),
            temperature=0,
        )

        # A private directory per request guarantees that any image found
        # afterwards was produced by *this* query, not by an earlier one or
        # by an unrelated program writing into the shared temp directory.
        chart_dir = tempfile.mkdtemp(
            prefix="pandasai_charts_", dir=tempfile.gettempdir()
        )

        # Create SmartDataframe
        df = SmartDataframe(
            df_data,
            config={
                "llm": llm,
                "save_charts": True,
                "save_charts_path": chart_dir,
                "open_charts": False,
                "enable_cache": False,
            },
        )

        # Process the query
        result = df.chat(query)

        # Handle different types of results
        if result is None:
            return "No result returned. Please try a different query.", None

        result_text = _format_result(result)

        chart_path = _find_latest_chart(chart_dir)
        keep_chart_dir = chart_path is not None
        if chart_path is None and _is_image_path(result):
            # NOTE(review): PandasAI appears to return the chart's file path
            # for plot queries; use it if the chart was saved elsewhere.
            chart_path = result

        return result_text, chart_path

    except Exception as e:
        # UI boundary: surface the failure in the result box instead of
        # crashing the request.
        return f"Error processing query: {str(e)}", None

    finally:
        # Only keep the directory when it holds the chart being returned.
        if chart_dir is not None and not keep_chart_dir:
            shutil.rmtree(chart_dir, ignore_errors=True)


def create_interface():
    """Create the Gradio interface.

    Returns:
        The ``gr.Blocks`` app with the inputs (API key, CSV upload, query),
        the outputs (text result, chart image) and the submit handlers wired
        to :func:`process_data`.
    """
    with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as demo:
        gr.Markdown(INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                api_key_input = gr.Textbox(
                    label="Groq API Key",
                    placeholder="Enter your Groq API key here...",
                    type="password",
                    info="Your API key is not stored and only used for this session",
                )

                # gr.File has no `info` parameter, and newer Gradio releases
                # reject unknown keyword arguments, so none is passed here.
                file_input = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"],
                )

                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'Which are the top 5 countries by population?' or 'Create a bar plot of the top 5 countries'",
                    lines=3,
                    info="Ask questions about your data or request visualizations",
                )

                submit_btn = gr.Button("🚀 Submit Query", variant="primary")

            with gr.Column(scale=2):
                # Output components
                result_output = gr.Textbox(
                    label="Analysis Result",
                    lines=10,
                    interactive=False,
                    show_copy_button=True,
                )

                chart_output = gr.Image(
                    label="Generated Visualization",
                    show_label=True,
                )

        # Example queries
        gr.Markdown(EXAMPLES_MARKDOWN)

        # Event handlers: the button click and pressing Enter in the query
        # box run the same handler with the same wiring.
        for register in (submit_btn.click, query_input.submit):
            register(
                fn=process_data,
                inputs=[file_input, query_input, api_key_input],
                outputs=[result_output, chart_output],
                show_progress=True,
            )

    return demo


def main():
    """Create and launch the interface on 0.0.0.0:7860."""
    demo = create_interface()
    # `enable_queue` and `show_tips` are not accepted by newer Gradio
    # `launch()` signatures; `queue()` is the supported way to enable
    # queuing.
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()