|
import gradio as gr |
|
import pandas as pd |
|
import os |
|
import tempfile |
|
import matplotlib.pyplot as plt |
|
from pandasai import SmartDataframe |
|
from langchain_groq.chat_models import ChatGroq |
|
from dotenv import load_dotenv |
|
import io |
|
import base64 |
|
|
|
|
|
# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Groq API key, read from the environment / .env file rather than hard-coded:
# a secret committed to source is exposed to everyone who can read the file or
# its history. This is None when the variable is not set.
# NOTE(review): the key that used to be hard-coded here must be treated as
# compromised and revoked/rotated in the Groq console.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
|
|
# File extensions recognised as chart images.
_CHART_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _resolve_upload_path(file):
    """Return the filesystem path for a Gradio upload value.

    Depending on the Gradio version and ``gr.File(type=...)``, the upload is
    either a path (``str`` / ``os.PathLike``) or a tempfile-like object that
    exposes the path as ``.name``.
    """
    return file if isinstance(file, (str, os.PathLike)) else file.name


def _find_chart(result, chart_dir):
    """Return the path of a chart image produced by the current query, or None.

    Only *chart_dir*, which is private to the current request, is searched. An
    image left in the shared temp directory by an earlier query or by another
    process can therefore never be returned.
    """
    # Plot queries may return the saved chart's path as the result itself.
    # Prefer that path when it points at an existing image file.
    if (
        isinstance(result, str)
        and result.lower().endswith(_CHART_EXTENSIONS)
        and os.path.isfile(result)
    ):
        return result
    candidates = [
        os.path.join(chart_dir, name)
        for name in os.listdir(chart_dir)
        if name.lower().endswith(_CHART_EXTENSIONS)
    ]
    return max(candidates, key=os.path.getmtime, default=None)


def _format_result(result):
    """Render a PandasAI result as the text shown in the result box."""
    if isinstance(result, (pd.DataFrame, pd.Series)):
        # to_string() renders every row instead of pandas' truncated repr.
        return f"Query Result:\n\n{result.to_string()}"
    if pd.api.types.is_number(result):
        # Also covers NumPy scalars (e.g. numpy.int64) returned by aggregations.
        return f"Query Result: {result}"
    return f"Query Result:\n{result}"


def process_data(file, query):
    """Answer *query* about the uploaded CSV file using PandasAI backed by Groq.

    Args:
        file: The Gradio upload value: a path string, or a tempfile-like
            object with a ``.name`` attribute. None when nothing was uploaded.
        query: Natural-language question from the user. It may be None or
            blank.

    Returns:
        A ``(result_text, chart_path)`` tuple. ``chart_path`` is the path of
        a chart image produced by this query, or None. Validation problems and
        any exception raised while processing are reported through
        ``result_text`` instead of being raised, so the UI always gets output.
    """
    try:
        if file is None:
            return "Please upload a CSV file.", None

        # Guard against None as well as empty/whitespace-only input.
        if not query or not query.strip():
            return "Please enter a query.", None

        df_data = pd.read_csv(_resolve_upload_path(file))

        if not GROQ_API_KEY:
            return (
                "GROQ_API_KEY is not set. Add it to your environment or .env file.",
                None,
            )

        # NOTE(review): confirm this model id is still offered by Groq.
        llm = ChatGroq(
            model_name="mixtral-8x7b-32768",
            api_key=GROQ_API_KEY,
            temperature=0,
        )

        # A fresh directory per request. Chart detection only looks here, so
        # it cannot pick up a stale chart from an earlier query or another
        # user's image. The directory is deliberately not deleted: the
        # returned image path must still exist when Gradio renders it.
        chart_dir = tempfile.mkdtemp(prefix="pandasai_charts_")

        df = SmartDataframe(df_data, config={
            "llm": llm,
            "save_charts": True,
            "save_charts_path": chart_dir,
            "open_charts": False,
            "enable_cache": False,
        })

        result = df.chat(query)
        if result is None:
            return "No result returned. Please try a different query.", None

        return _format_result(result), _find_chart(result, chart_dir)

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"Error processing query: {e}", None
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI for querying a CSV with PandasAI.

    Layout: an introductory markdown header; a row with an input column (CSV
    upload, query box, submit button) and a wider output column (result text,
    chart image); then a markdown list of example queries. The button's click
    event and the query box's submit event are both wired to ``process_data``.
    """
    with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # π PandasAI Data Analysis with Groq

            Upload a CSV file and ask questions about your data. The AI will analyze and visualize your data accordingly.

            **Instructions:**
            1. Upload your CSV file
            2. Enter your query (e.g., "Show top 5 countries by population", "Create a bar plot of sales by region")
            3. Click Submit to get results
            """
        )

        with gr.Row():
            # Input column.
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
                query_input = gr.Textbox(
                    label="Your Query - Ask questions about your data",
                    placeholder=(
                        "e.g., 'Which are the top 5 countries by population?' "
                        "or 'Create a bar plot of the top 5 countries'"
                    ),
                    lines=3,
                )
                submit_btn = gr.Button("π Submit Query", variant="primary")

            # Output column (twice as wide as the input column).
            with gr.Column(scale=2):
                result_output = gr.Textbox(
                    label="Analysis Result", lines=10, interactive=False
                )
                chart_output = gr.Image(label="Generated Visualization")

        gr.Markdown(
            """
            ### π‘ Example Queries:
            - "Which are the top 5 countries by population?"
            - "Create a bar plot of the top 10 countries by population"
            - "Show me a pie chart of the top 5 countries"
            - "Calculate the total population of the top 3 countries"
            - "What is the average population across all countries?"
            - "Create a scatter plot showing the relationship between two columns"
            """
        )

        # Both triggers run the same handler with the same inputs and outputs.
        for trigger in (submit_btn.click, query_input.submit):
            trigger(
                fn=process_data,
                inputs=[file_input, query_input],
                outputs=[result_output, chart_output],
            )

    return demo
|
|
|
if __name__ == "__main__":
    demo = create_interface()
    # Host and port default to the previous hard-coded values but can be
    # overridden with GRADIO_SERVER_NAME / GRADIO_SERVER_PORT. Passing
    # hard-coded values to launch() would otherwise shadow Gradio's own
    # handling of those variables.
    # NOTE: "0.0.0.0" listens on every network interface. Set
    # GRADIO_SERVER_NAME=127.0.0.1 to keep the app reachable only locally.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
    )