Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,247 +1,180 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
-
import matplotlib.pyplot as plt
|
4 |
-
import matplotlib
|
5 |
-
matplotlib.use('Agg') # Use non-interactive backend
|
6 |
import os
|
7 |
import tempfile
|
8 |
-
import
|
9 |
-
from io import BytesIO
|
10 |
from pandasai import SmartDataframe
|
11 |
from langchain_groq.chat_models import ChatGroq
|
|
|
|
|
|
|
12 |
|
13 |
-
#
|
14 |
-
|
15 |
-
|
16 |
-
# Global variables to store data
|
17 |
-
current_df = None
|
18 |
-
llm = None
|
19 |
|
20 |
-
def
|
21 |
-
"""
|
22 |
-
|
|
|
23 |
try:
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
|
|
27 |
llm = ChatGroq(
|
28 |
-
model_name="mixtral-8x7b-32768",
|
29 |
-
api_key=
|
30 |
temperature=0
|
31 |
)
|
32 |
-
return "β
Groq LLM initialized successfully", llm
|
33 |
-
except Exception as e:
|
34 |
-
return f"β Failed to initialize Groq LLM: {str(e)}", None
|
35 |
-
|
36 |
-
def process_csv(file):
|
37 |
-
"""Process uploaded CSV file"""
|
38 |
-
global current_df
|
39 |
-
|
40 |
-
if file is None:
|
41 |
-
return "No file uploaded", None, None
|
42 |
-
|
43 |
-
try:
|
44 |
-
# Read the CSV file
|
45 |
-
current_df = pd.read_csv(file.name)
|
46 |
|
47 |
-
# Create
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
-
**File Info:**
|
53 |
-
- Shape: {current_df.shape[0]} rows Γ {current_df.shape[1]} columns
|
54 |
-
- Columns: {', '.join(current_df.columns.tolist())}
|
55 |
-
"""
|
56 |
|
57 |
-
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
def chat_with_data(query):
|
63 |
-
"""Process user query and return response"""
|
64 |
-
global current_df, llm
|
65 |
-
|
66 |
-
if current_df is None:
|
67 |
-
return "β Please upload a CSV file first", None
|
68 |
-
|
69 |
-
if llm is None:
|
70 |
-
status, _ = initialize_llm()
|
71 |
-
if llm is None:
|
72 |
-
return status, None
|
73 |
-
|
74 |
-
if not query.strip():
|
75 |
-
return "β Please enter a query", None
|
76 |
-
|
77 |
-
try:
|
78 |
-
# Create temporary directory for charts
|
79 |
-
temp_dir = tempfile.mkdtemp()
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
"llm": llm,
|
86 |
-
"verbose": True,
|
87 |
-
"save_charts": True,
|
88 |
-
"save_charts_path": temp_dir,
|
89 |
-
"custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"]
|
90 |
-
}
|
91 |
-
)
|
92 |
|
93 |
-
#
|
94 |
-
result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
113 |
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
-
|
117 |
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
else:
|
127 |
-
# Other types of results
|
128 |
-
return f"π **Result:**\n{str(result)}", None
|
129 |
-
|
130 |
-
except Exception as e:
|
131 |
-
error_msg = f"β Error: {str(e)}"
|
132 |
|
133 |
-
#
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
#
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
with gr.Row():
|
162 |
-
with gr.Column(scale=2):
|
163 |
-
# File upload section
|
164 |
-
gr.Markdown("## π Upload CSV File")
|
165 |
-
file_input = gr.File(
|
166 |
-
label="Upload your CSV file",
|
167 |
-
file_types=[".csv"],
|
168 |
-
type="filepath"
|
169 |
-
)
|
170 |
-
|
171 |
-
upload_status = gr.Textbox(
|
172 |
-
label="Upload Status",
|
173 |
-
interactive=False,
|
174 |
-
value=init_status
|
175 |
-
)
|
176 |
-
|
177 |
-
# Data preview section
|
178 |
-
gr.Markdown("## π Data Preview")
|
179 |
-
data_preview = gr.HTML(label="Data Preview")
|
180 |
-
data_info = gr.Markdown()
|
181 |
-
|
182 |
-
with gr.Column(scale=1):
|
183 |
-
# Debug and help section
|
184 |
-
gr.Markdown("## π§ Debug Info")
|
185 |
-
debug_btn = gr.Button("Show Debug Info")
|
186 |
-
debug_info = gr.Textbox(label="Debug Information", interactive=False)
|
187 |
-
|
188 |
-
gr.Markdown("## π Example Queries")
|
189 |
-
gr.Markdown("""
|
190 |
-
- "Show me the first 10 rows"
|
191 |
-
- "What are the column names?"
|
192 |
-
- "Create a histogram of [column_name]"
|
193 |
-
- "Show me the summary statistics"
|
194 |
-
- "Plot the top 5 values in [column_name]"
|
195 |
-
- "Create a bar chart showing [column1] vs [column2]"
|
196 |
-
""")
|
197 |
-
|
198 |
-
# Chat section
|
199 |
-
gr.Markdown("## π¬ Chat with Your Data")
|
200 |
-
|
201 |
-
with gr.Row():
|
202 |
-
query_input = gr.Textbox(
|
203 |
-
label="Ask a question or request a chart",
|
204 |
-
placeholder="What would you like to know about your data?",
|
205 |
-
lines=3,
|
206 |
-
scale=4
|
207 |
)
|
208 |
-
submit_btn = gr.Button("Submit Query", variant="primary", scale=1)
|
209 |
-
|
210 |
-
# Results section
|
211 |
-
with gr.Row():
|
212 |
-
with gr.Column():
|
213 |
-
response_output = gr.Markdown(label="Response")
|
214 |
-
with gr.Column():
|
215 |
-
chart_output = gr.Image(label="Generated Chart", type="filepath")
|
216 |
-
|
217 |
-
# Event handlers
|
218 |
-
file_input.change(
|
219 |
-
fn=process_csv,
|
220 |
-
inputs=[file_input],
|
221 |
-
outputs=[upload_status, data_preview, data_info]
|
222 |
-
)
|
223 |
-
|
224 |
-
debug_btn.click(
|
225 |
-
fn=get_debug_info,
|
226 |
-
outputs=[debug_info]
|
227 |
-
)
|
228 |
-
|
229 |
-
submit_btn.click(
|
230 |
-
fn=chat_with_data,
|
231 |
-
inputs=[query_input],
|
232 |
-
outputs=[response_output, chart_output]
|
233 |
-
)
|
234 |
|
235 |
-
|
236 |
-
fn=chat_with_data,
|
237 |
-
inputs=[query_input],
|
238 |
-
outputs=[response_output, chart_output]
|
239 |
-
)
|
240 |
|
241 |
-
# Launch the app
|
242 |
if __name__ == "__main__":
|
|
|
|
|
243 |
demo.launch(
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
|
|
|
|
247 |
)
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
|
|
|
|
|
|
3 |
import os
|
4 |
import tempfile
|
5 |
+
import matplotlib.pyplot as plt
|
|
|
6 |
from pandasai import SmartDataframe
|
7 |
from langchain_groq.chat_models import ChatGroq
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
import io
|
10 |
+
import base64
|
11 |
|
12 |
+
# Load environment variables
|
13 |
+
load_dotenv()
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
def process_data(file, query, api_key):
    """
    Answer a natural-language query about an uploaded CSV using PandasAI.

    Args:
        file: The upload handed over by the Gradio file component. Either a
            filesystem path (``str`` / ``os.PathLike``) or an object that
            exposes the path as ``.name``.
        query: The question or chart request typed by the user.
        api_key: Groq API key used to construct the LLM client.

    Returns:
        A ``(message, chart_path)`` tuple. ``message`` is the formatted query
        result, or a validation / error message. ``chart_path`` is the path of
        an image produced while answering *this* query, or ``None``.

    This function never raises: any failure is reported through ``message``.
    """
    charts_dir = None
    try:
        # Validate inputs. ``or ""`` makes a None textbox value produce the
        # friendly prompt instead of an AttributeError.
        if file is None:
            return "Please upload a CSV file.", None

        if not (query or "").strip():
            return "Please enter a query.", None

        if not (api_key or "").strip():
            return "Please enter your Groq API key.", None

        # The upload may be a plain path or a tempfile-like wrapper.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
        df_data = pd.read_csv(csv_path)

        # Initialize Groq LLM.
        # NOTE(review): confirm this model id is still served by Groq.
        llm = ChatGroq(
            model_name="mixtral-8x7b-32768",
            api_key=api_key.strip(),
            temperature=0
        )

        # Give every request its own chart directory. Scanning the shared
        # system temp directory would pick up stale charts from earlier
        # queries (or unrelated images) and show them for non-chart answers.
        charts_dir = tempfile.mkdtemp(prefix="pandasai_charts_")

        df = SmartDataframe(df_data, config={
            "llm": llm,
            "save_charts": True,
            "save_charts_path": charts_dir,
            "open_charts": False,
            "enable_cache": False
        })

        result = df.chat(query)

        if result is None:
            return "No result returned. Please try a different query.", None

        return _format_result(result), _locate_chart(result, charts_dir)

    except Exception as e:
        # Top-level UI boundary: report the failure in the result box.
        return f"Error processing query: {str(e)}", None

    finally:
        # os.rmdir only removes an empty directory, so this discards the
        # per-request directory when no chart was written and keeps it (for
        # Gradio to serve) when one was.
        if charts_dir is not None:
            try:
                os.rmdir(charts_dir)
            except OSError:
                pass


def _locate_chart(result, charts_dir):
    """Return the image produced for this query, or None if there is none."""
    image_exts = ('.png', '.jpg', '.jpeg')

    # Prefer an explicit image path returned as the query result.
    if (
        isinstance(result, str)
        and result.lower().endswith(image_exts)
        and os.path.isfile(result)
    ):
        return result

    # Otherwise take the newest image written under this request's directory.
    candidates = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(charts_dir)
        for name in names
        if name.lower().endswith(image_exts)
    ]
    return max(candidates, key=os.path.getmtime) if candidates else None


def _format_result(result):
    """Render a PandasAI result as the text shown in the result box."""
    if isinstance(result, pd.DataFrame):
        return f"Query Result:\n\n{result.to_string()}"
    if isinstance(result, (int, float)):
        return f"Query Result: {result}"
    return f"Query Result:\n{result}"
|
80 |
+
|
81 |
+
def create_interface():
    """
    Build the Gradio Blocks UI and return it without launching.

    The layout has an intro banner, a row with the inputs (API key, CSV
    upload, query, submit button) on the left and the outputs (text result,
    chart image) on the right, and a list of example queries. Both the submit
    button and pressing Enter in the query box call ``process_data``.
    """
    with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as app:
        gr.Markdown(
            """
            # π PandasAI Data Analysis with Groq

            Upload a CSV file and ask questions about your data. The AI will analyze and visualize your data accordingly.

            **Instructions:**
            1. Get your Groq API key from [https://console.groq.com/keys](https://console.groq.com/keys)
            2. Upload your CSV file
            3. Enter your query (e.g., "Show top 5 countries by population", "Create a bar plot of sales by region")
            4. Click Submit to get results
            """
        )

        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column(scale=1):
                key_box = gr.Textbox(
                    label="Groq API Key",
                    placeholder="Enter your Groq API key here...",
                    type="password",
                    info="Your API key is not stored and only used for this session"
                )

                csv_upload = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"],
                    info="Upload your CSV data file"
                )

                query_box = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., 'Which are the top 5 countries by population?' or 'Create a bar plot of the top 5 countries'",
                    lines=3,
                    info="Ask questions about your data or request visualizations"
                )

                run_button = gr.Button("π Submit Query", variant="primary")

            # Right column: what comes back from process_data.
            with gr.Column(scale=2):
                text_result = gr.Textbox(
                    label="Analysis Result",
                    lines=10,
                    interactive=False,
                    show_copy_button=True
                )

                chart_view = gr.Image(
                    label="Generated Visualization",
                    show_label=True
                )

        gr.Markdown(
            """
            ### π‘ Example Queries:
            - "Which are the top 5 countries by population?"
            - "Create a bar plot of the top 10 countries by population"
            - "Show me a pie chart of the top 5 countries"
            - "Calculate the total population of the top 3 countries"
            - "What is the average population across all countries?"
            - "Create a scatter plot showing the relationship between two columns"
            """
        )

        # One wiring spec shared by both triggers so they cannot drift apart.
        wiring = dict(
            fn=process_data,
            inputs=[csv_upload, query_box, key_box],
            outputs=[text_result, chart_view],
            show_progress=True,
        )
        run_button.click(**wiring)
        # Allow Enter key to submit
        query_box.submit(**wiring)

    return app
|
|
|
|
|
|
|
|
|
169 |
|
|
|
170 |
if __name__ == "__main__":
    # Create the interface and serve it on all network interfaces, port 7860.
    demo = create_interface()

    # Request queuing is switched on through Blocks.queue(). The launch()
    # keywords `enable_queue` and `show_tips` are deprecated and rejected by
    # newer Gradio releases, so they are no longer passed.
    demo.queue()

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
|