# Captured from a Hugging Face Spaces page (Space status: "Sleeping").
import gradio as gr | |
import pandas as pd | |
import os | |
import tempfile | |
import matplotlib.pyplot as plt | |
from pandasai import SmartDataframe | |
from langchain_groq.chat_models import ChatGroq | |
from dotenv import load_dotenv | |
import io | |
import base64 | |
import re | |
# Configuration.
#
# The Groq API key is read from the environment (e.g. a deployment secret)
# instead of being hard-coded. A real key was previously committed here in a
# comment; it has been removed and must be treated as compromised (revoke and
# rotate it in the Groq console).
# NOTE(review): `load_dotenv` is imported but never called, so a local `.env`
# file is NOT read; call `load_dotenv()` before this line if that is wanted.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Module-level state shared between the "Analyze Query" and "Generate Chart"
# callbacks. Being module globals, these are shared by every session served by
# this process.
current_dataframe = None   # DataFrame read from the uploaded CSV
current_smart_df = None    # SmartDataframe wrapping current_dataframe
last_query_result = None   # raw value returned by the most recent chat() call
# Words that mean the user explicitly asked for a visualization.
_CHART_KEYWORDS = (
    'plot', 'chart', 'graph', 'visualize', 'visualization', 'bar', 'line',
    'pie', 'scatter', 'histogram', 'heatmap', 'boxplot', 'distribution',
)

# Words describing statistical questions that usually benefit from a chart.
_STAT_KEYWORDS = (
    'top', 'bottom', 'highest', 'lowest', 'compare', 'comparison',
    'trend', 'relationship', 'correlation', 'by category', 'group by',
)


def _mentions(text, keyword):
    """Return True if *keyword* occurs in *text* starting at a word boundary.

    A plain substring test matched keywords buried inside other words
    ('line' in "online", 'top' in "laptop", 'bar' in "sidebar"). Anchoring
    only the start keeps inflected forms such as "plots", "lines" or
    "plotted" matching.
    """
    return re.search(r"\b" + re.escape(keyword), text) is not None


def analyze_chart_feasibility(query, df_data):
    """
    Analyze if the query can generate a meaningful chart.

    Args:
        query: The user's natural-language question.
        df_data: The pandas DataFrame the question is about.

    Returns:
        A tuple ``(can_create_chart, reasoning, chart_recommendation)`` where
        the first item is a bool and the other two are display strings.
    """
    query_lower = query.lower()

    explicit_chart = any(_mentions(query_lower, kw) for kw in _CHART_KEYWORDS)
    statistical_nature = any(_mentions(query_lower, kw) for kw in _STAT_KEYWORDS)
    numeric_columns = df_data.select_dtypes(include=['number']).columns.tolist()

    if explicit_chart:
        # 'distribution' is itself a chart keyword, so distribution queries
        # always land here (the old histogram branch below was unreachable).
        return (
            True,
            "Query explicitly requests a chart/visualization.",
            "Chart will be generated as requested.",
        )

    if statistical_nature and numeric_columns:
        reasoning = f"Query involves statistical analysis with {len(numeric_columns)} numeric columns available for visualization."
        # Suggest an appropriate chart type.
        if _mentions(query_lower, 'top') or _mentions(query_lower, 'bottom'):
            recommendation = "Recommended: Bar chart to show rankings/comparisons."
        elif _mentions(query_lower, 'relationship') or _mentions(query_lower, 'correlation'):
            recommendation = "Recommended: Scatter plot to show relationships."
        else:
            recommendation = "Recommended: Bar chart or line chart based on data nature."
        return True, reasoning, recommendation

    return (
        False,
        "Query appears to be asking for specific values, calculations, or text-based information that doesn't require visualization.",
        "Chart generation not recommended for this type of query.",
    )
# Path of the CSV currently held in ``current_dataframe``. It is used to detect
# a newly uploaded file, so that file is re-read instead of being ignored.
_current_file_path = None


def _resolve_upload_path(file):
    """Return the filesystem path of a Gradio upload.

    Depending on the Gradio version and ``gr.File(type=...)``, an upload is
    either a path string or a tempfile-like object exposing ``.name``.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _format_query_result(result):
    """Render a non-None PandasAI chat result as display text."""
    import numbers  # local: keeps this block self-contained

    if isinstance(result, pd.DataFrame):
        return f"Query Result:\n\n{result.to_string()}"
    # numbers.Number also covers numpy scalars such as np.int64, which are not
    # subclasses of int and were previously formatted like arbitrary objects.
    if isinstance(result, numbers.Number):
        return f"Query Result: {result}"
    return f"Query Result:\n{str(result)}"


def process_query_only(file, query):
    """
    Process the query without generating charts.

    Args:
        file: The Gradio upload (path string or object with ``.name``), or None.
        query: The user's natural-language question.

    Returns:
        ``(result_text, reasoning, recommendation)``. When validation fails or
        an exception occurs, the message is returned with two empty strings.
    """
    global current_dataframe, current_smart_df, last_query_result, _current_file_path
    try:
        # Validate inputs
        if file is None:
            return "Please upload a CSV file.", "", ""
        if not query or not query.strip():
            return "Please enter a query.", "", ""

        # Read the CSV if nothing is loaded yet or a different file was uploaded.
        file_path = _resolve_upload_path(file)
        if current_dataframe is None or file_path != _current_file_path:
            current_dataframe = pd.read_csv(file_path)
            _current_file_path = file_path

        # Initialize Groq LLM
        llm = ChatGroq(
            model_name="llama-3.3-70b-versatile",
            api_key=GROQ_API_KEY,
            temperature=0
        )
        # Create SmartDataframe (chart saving disabled for query-only mode)
        current_smart_df = SmartDataframe(current_dataframe, config={
            "llm": llm,
            "save_charts": False,
            "enable_cache": False
        })

        _, reasoning, recommendation = analyze_chart_feasibility(query, current_dataframe)

        result = current_smart_df.chat(query)
        last_query_result = result

        if result is None:
            return "No result returned. Please try a different query.", reasoning, recommendation
        return _format_query_result(result), reasoning, recommendation
    except Exception as e:
        # UI boundary: surface the error text to the user instead of crashing.
        error_msg = f"Error processing query: {str(e)}"
        return error_msg, "", ""
# Dedicated directory for PandasAI chart images. Cleaning and scanning a
# private subdirectory (rather than the shared system temp dir) means cleanup
# can never delete unrelated images, and the "newest image" lookup cannot pick
# up files written by other programs.
_CHART_DIR = os.path.join(tempfile.gettempdir(), "pandasai_charts")
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _list_chart_files(directory):
    """Return full paths of image files directly inside *directory* ([] if it is missing)."""
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [
        os.path.join(directory, name)
        for name in names
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


def _clear_chart_dir(directory):
    """Best-effort removal of previously generated chart images in *directory*."""
    for path in _list_chart_files(directory):
        try:
            os.remove(path)
        except OSError:
            # Already gone or not removable; a stale chart is harmless because
            # the newest file is selected below.
            pass


def _find_latest_chart(directory):
    """Return the most recently modified image in *directory*, or None."""
    charts = _list_chart_files(directory)
    if not charts:
        return None
    return max(charts, key=os.path.getmtime)


def generate_chart(query):
    """
    Generate chart based on the query and last result.

    Args:
        query: The user's question; if it names no chart word, it is prefixed
            with an explicit visualization instruction.

    Returns:
        ``(status_message, chart_path)``, where ``chart_path`` is None when no
        chart image was produced.
    """
    try:
        if current_smart_df is None:
            return "Please run a query first before generating charts.", None
        if not query or not query.strip():
            return "Please enter a query for chart generation.", None

        # Start from an empty private chart directory.
        os.makedirs(_CHART_DIR, exist_ok=True)
        _clear_chart_dir(_CHART_DIR)

        # Create a chart-focused version of the query
        chart_query = query
        if not any(keyword in query.lower() for keyword in ['plot', 'chart', 'graph', 'visualize']):
            chart_query = f"Create a chart or visualization for: {query}"

        # Reconfigure SmartDataframe for chart generation
        llm = ChatGroq(
            model_name="llama-3.3-70b-versatile",
            api_key=GROQ_API_KEY,
            temperature=0
        )
        chart_smart_df = SmartDataframe(current_dataframe, config={
            "llm": llm,
            "save_charts": True,
            "save_charts_path": _CHART_DIR,
            "open_charts": False,
            "enable_cache": False
        })

        result = chart_smart_df.chat(chart_query)

        # NOTE(review): PandasAI may return the saved chart's path directly;
        # use it when it points at an existing image, else scan our directory.
        if (
            isinstance(result, str)
            and result.lower().endswith(_IMAGE_EXTENSIONS)
            and os.path.isfile(result)
        ):
            chart_path = result
        else:
            chart_path = _find_latest_chart(_CHART_DIR)

        if chart_path:
            return "Chart generated successfully!", chart_path
        return "Chart could not be generated. The query might not be suitable for visualization or there might be an issue with the data.", None
    except Exception as e:
        # UI boundary: report the failure instead of crashing the callback.
        error_msg = f"Error generating chart: {str(e)}"
        return error_msg, None
def reset_data():
    """
    Forget the uploaded CSV and all cached PandasAI state.

    Returns:
        Five values for the components cleared by the reset button: a status
        message, two empty strings and two ``None`` values.
    """
    global current_dataframe, current_smart_df, last_query_result
    current_dataframe = current_smart_df = last_query_result = None
    status_message = "Data reset. Please upload a new file."
    return status_message, "", "", None, None
def create_interface():
    """
    Create the Gradio interface.

    Builds two columns. The left column holds the CSV upload, the query box and
    the Analyze / Generate Chart / Reset buttons. The right column holds the
    text result, the chart-feasibility analysis, the chart status and the
    generated image. Pressing Enter in the query box does the same as
    "Analyze Query".

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Enhanced PandasAI with Groq", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 📊 Enhanced PandasAI Data Analysis with Groq

            Upload a CSV file and analyze your data with separate query and chart generation capabilities.

            **Instructions:**

            1. Upload your CSV file
            2. Enter your query and click "Analyze Query" to get text results and chart feasibility analysis
            3. If chart is recommended, click "Generate Chart" to create visualizations
            4. Use "Reset Data" to load a new file
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                file_input = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"]
                )
                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'Which are the top 5 countries by population?' or 'Show relationship between two columns'",
                    lines=3
                )
                with gr.Row():
                    analyze_btn = gr.Button("🔍 Analyze Query", variant="primary")
                    chart_btn = gr.Button("📈 Generate Chart", variant="secondary")
                reset_btn = gr.Button("🔄 Reset Data", variant="stop")

            with gr.Column(scale=2):
                # Output components
                result_output = gr.Textbox(
                    label="Analysis Result",
                    lines=8,
                    interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        feasibility_output = gr.Textbox(
                            label="Chart Feasibility Analysis",
                            lines=3,
                            interactive=False
                        )
                    with gr.Column():
                        recommendation_output = gr.Textbox(
                            label="Chart Recommendation",
                            lines=3,
                            interactive=False
                        )
                chart_status = gr.Textbox(
                    label="Chart Generation Status",
                    lines=2,
                    interactive=False
                )
                chart_output = gr.Image(
                    label="Generated Visualization"
                )

        # Example section
        gr.Markdown(
            """
            ### 💡 Example Workflow:

            **Step 1 - Data Analysis Queries:**

            - "What are the top 10 countries by population?"
            - "Calculate the average population of all countries"
            - "Which country has the highest GDP?"

            **Step 2 - Chart Generation:**

            - After running a query, click "Generate Chart" to visualize the results
            - The system will analyze if your query can be effectively visualized
            - Charts work best with comparative, ranking, or relationship-based queries

            **Query Types that work well for charts:**

            - Ranking queries (top/bottom N items)
            - Comparisons between categories
            - Relationships between variables
            - Distribution analysis
            """
        )

        # Event handlers
        analyze_btn.click(
            fn=process_query_only,
            inputs=[file_input, query_input],
            outputs=[result_output, feasibility_output, recommendation_output]
        )
        chart_btn.click(
            fn=generate_chart,
            inputs=[query_input],
            outputs=[chart_status, chart_output]
        )
        # Output order matches reset_data's 5-tuple return.
        reset_btn.click(
            fn=reset_data,
            outputs=[chart_status, feasibility_output, recommendation_output, chart_output, result_output]
        )
        # Allow Enter key to analyze query
        query_input.submit(
            fn=process_query_only,
            inputs=[file_input, query_input],
            outputs=[result_output, feasibility_output, recommendation_output]
        )

    return demo
if __name__ == "__main__":
    # Build the UI and serve it on every interface at port 7860, with no public
    # share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)