Update app.py
app.py CHANGED
@@ -0,0 +1,247 @@
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import os
import tempfile
import base64
from io import BytesIO
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq

# HARDCODED API KEY - REPLACE WITH YOUR ACTUAL KEY
API_KEY = "gsk_YOUR_ACTUAL_API_KEY_HERE"  # Replace with your real API key

# Global variables to store data
current_df = None
llm = None

def initialize_llm():
    """Initialize the Groq LLM"""
    global llm
    try:
        if API_KEY == "gsk_YOUR_ACTUAL_API_KEY_HERE":
            return "❌ Please replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual Groq API key", None

        llm = ChatGroq(
            model_name="mixtral-8x7b-32768",
            api_key=API_KEY,
            temperature=0
        )
        return "✅ Groq LLM initialized successfully", llm
    except Exception as e:
        return f"❌ Failed to initialize Groq LLM: {str(e)}", None

def process_csv(file):
    """Process uploaded CSV file"""
    global current_df

    if file is None:
        return "No file uploaded", None, None

    try:
        # Read the CSV file (gr.File with type="filepath" passes a path string;
        # fall back to .name for file-like objects from older Gradio versions)
        file_path = file.name if hasattr(file, "name") else file
        current_df = pd.read_csv(file_path)

        # Create preview
        preview = current_df.head().to_html(classes='table table-striped', table_id='data-preview')

        # Create info
        info = f"""
**File Info:**
- Shape: {current_df.shape[0]} rows × {current_df.shape[1]} columns
- Columns: {', '.join(current_df.columns.tolist())}
"""

        return "✅ CSV file loaded successfully", preview, info

    except Exception as e:
        return f"❌ Error reading CSV: {str(e)}", None, None

def chat_with_data(query):
    """Process user query and return response"""
    global current_df, llm

    if current_df is None:
        return "❌ Please upload a CSV file first", None

    if llm is None:
        status, _ = initialize_llm()
        if llm is None:
            return status, None

    if not query.strip():
        return "❌ Please enter a query", None

    try:
        # Create temporary directory for charts
        temp_dir = tempfile.mkdtemp()

        # Create SmartDataframe
        sdf = SmartDataframe(
            current_df,
            config={
                "llm": llm,
                "verbose": True,
                "save_charts": True,
                "save_charts_path": temp_dir,
                "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"]
            }
        )

        # Process the query
        result = sdf.chat(query)

        # Handle different types of results
        if isinstance(result, str):
            # Text response
            return f"💬 **Response:**\n{result}", None

        elif hasattr(result, 'savefig'):
            # Matplotlib figure
            try:
                # Save figure to bytes
                img_buffer = BytesIO()
                result.savefig(img_buffer, format='png', dpi=150, bbox_inches='tight')
                img_buffer.seek(0)

                # Save to temporary file for Gradio
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                with open(temp_file.name, 'wb') as f:
                    f.write(img_buffer.getvalue())

                plt.close(result)  # Close the figure to free memory

                return "📊 **Chart Generated:**", temp_file.name

            except Exception as chart_error:
                return f"❌ Error saving chart: {str(chart_error)}", None

        elif isinstance(result, pd.DataFrame):
            # DataFrame result
            html_table = result.to_html(classes='table table-striped', max_rows=100)
            return f"📊 **Data Result:**\n{html_table}", None

        else:
            # Other types of results
            return f"📋 **Result:**\n{str(result)}", None

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"

        # Provide specific error guidance
        if "403" in str(e):
            error_msg += "\n\n🔒 **403 Forbidden Error** - This usually means:\n"
            error_msg += "- Invalid API key\n"
            error_msg += "- API key doesn't have permission for this model\n"
            error_msg += "- Rate limit exceeded\n"
            error_msg += "- Model name is incorrect"
        elif "rate limit" in str(e).lower():
            error_msg += "\n\n⏰ **Rate Limit** - Please wait a moment before trying again"
        elif "timeout" in str(e).lower():
            error_msg += "\n\n⏱️ **Timeout** - The query took too long. Try a simpler request"

        return error_msg, None

def get_debug_info():
    """Get debug information"""
    if API_KEY and API_KEY != "gsk_YOUR_ACTUAL_API_KEY_HERE":
        return f"✅ API Key loaded successfully\nKey starts with: {API_KEY[:10]}..."
    else:
        return "❌ Replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual API key"

# Initialize LLM on startup
init_status, _ = initialize_llm()

# Create Gradio interface
with gr.Blocks(title="📊 CSV Chat with Groq + PandasAI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📊 Chat with Your CSV using PandasAI + Groq")

    with gr.Row():
        with gr.Column(scale=2):
            # File upload section
            gr.Markdown("## 📁 Upload CSV File")
            file_input = gr.File(
                label="Upload your CSV file",
                file_types=[".csv"],
                type="filepath"
            )

            upload_status = gr.Textbox(
                label="Upload Status",
                interactive=False,
                value=init_status
            )

            # Data preview section
            gr.Markdown("## 👀 Data Preview")
            data_preview = gr.HTML(label="Data Preview")
            data_info = gr.Markdown()

        with gr.Column(scale=1):
            # Debug and help section
            gr.Markdown("## 🔧 Debug Info")
            debug_btn = gr.Button("Show Debug Info")
            debug_info = gr.Textbox(label="Debug Information", interactive=False)

            gr.Markdown("## 📝 Example Queries")
            gr.Markdown("""
            - "Show me the first 10 rows"
            - "What are the column names?"
            - "Create a histogram of [column_name]"
            - "Show me the summary statistics"
            - "Plot the top 5 values in [column_name]"
            - "Create a bar chart showing [column1] vs [column2]"
            """)

    # Chat section
    gr.Markdown("## 💬 Chat with Your Data")

    with gr.Row():
        query_input = gr.Textbox(
            label="Ask a question or request a chart",
            placeholder="What would you like to know about your data?",
            lines=3,
            scale=4
        )
        submit_btn = gr.Button("Submit Query", variant="primary", scale=1)

    # Results section
    with gr.Row():
        with gr.Column():
            response_output = gr.Markdown(label="Response")
        with gr.Column():
            chart_output = gr.Image(label="Generated Chart", type="filepath")

    # Event handlers
    file_input.change(
        fn=process_csv,
        inputs=[file_input],
        outputs=[upload_status, data_preview, data_info]
    )

    debug_btn.click(
        fn=get_debug_info,
        outputs=[debug_info]
    )

    submit_btn.click(
        fn=chat_with_data,
        inputs=[query_input],
        outputs=[response_output, chart_output]
    )

    query_input.submit(
        fn=chat_with_data,
        inputs=[query_input],
        outputs=[response_output, chart_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        share=False,  # Set to True if you want a public link
        debug=True,
        show_error=True
    )
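
Note on the hardcoded key: since this app runs as a Hugging Face Space, a safer pattern than embedding the key in app.py is to read it from a Space secret exposed as an environment variable. A minimal sketch, assuming a secret named GROQ_API_KEY (that name and the startup message are illustrative, not part of the committed app):

import os

# Assumed setup: a Space secret (or local environment variable) named GROQ_API_KEY.
API_KEY = os.environ.get("GROQ_API_KEY", "")
if not API_KEY:
    # Surface the problem at startup instead of failing later inside ChatGroq.
    print("GROQ_API_KEY is not set; add it as a Space secret or export it locally.")

The Space will also need a requirements.txt alongside app.py; based on the imports above that presumably includes gradio, pandas, matplotlib, pandasai, and langchain-groq, with exact versions depending on what the Space was built against.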