srivatsavdamaraju commited on
Commit
821675b
Β·
verified Β·
1 Parent(s): 0e41979

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -214
app.py CHANGED
@@ -1,247 +1,180 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import matplotlib.pyplot as plt
4
- import matplotlib
5
- matplotlib.use('Agg') # Use non-interactive backend
6
  import os
7
  import tempfile
8
- import base64
9
- from io import BytesIO
10
  from pandasai import SmartDataframe
11
  from langchain_groq.chat_models import ChatGroq
 
 
 
12
 
13
- # HARDCODED API KEY - REPLACE WITH YOUR ACTUAL KEY
14
- API_KEY = "gsk_YOUR_ACTUAL_API_KEY_HERE" # Replace with your real API key
15
-
16
- # Global variables to store data
17
- current_df = None
18
- llm = None
19
 
20
- def initialize_llm():
21
- """Initialize the Groq LLM"""
22
- global llm
 
23
  try:
24
- if API_KEY == "gsk_YOUR_ACTUAL_API_KEY_HERE":
25
- return "❌ Please replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual Groq API key", None
 
 
 
 
 
 
 
 
 
 
26
 
 
27
  llm = ChatGroq(
28
- model_name="mixtral-8x7b-32768",
29
- api_key=API_KEY,
30
  temperature=0
31
  )
32
- return "βœ… Groq LLM initialized successfully", llm
33
- except Exception as e:
34
- return f"❌ Failed to initialize Groq LLM: {str(e)}", None
35
-
36
- def process_csv(file):
37
- """Process uploaded CSV file"""
38
- global current_df
39
-
40
- if file is None:
41
- return "No file uploaded", None, None
42
-
43
- try:
44
- # Read the CSV file
45
- current_df = pd.read_csv(file.name)
46
 
47
- # Create preview
48
- preview = current_df.head().to_html(classes='table table-striped', table_id='data-preview')
 
 
 
 
 
 
49
 
50
- # Create info
51
- info = f"""
52
- **File Info:**
53
- - Shape: {current_df.shape[0]} rows Γ— {current_df.shape[1]} columns
54
- - Columns: {', '.join(current_df.columns.tolist())}
55
- """
56
 
57
- return "βœ… CSV file loaded successfully", preview, info
 
 
58
 
59
- except Exception as e:
60
- return f"❌ Error reading CSV: {str(e)}", None, None
61
-
62
- def chat_with_data(query):
63
- """Process user query and return response"""
64
- global current_df, llm
65
-
66
- if current_df is None:
67
- return "❌ Please upload a CSV file first", None
68
-
69
- if llm is None:
70
- status, _ = initialize_llm()
71
- if llm is None:
72
- return status, None
73
-
74
- if not query.strip():
75
- return "❌ Please enter a query", None
76
-
77
- try:
78
- # Create temporary directory for charts
79
- temp_dir = tempfile.mkdtemp()
80
 
81
- # Create SmartDataframe
82
- sdf = SmartDataframe(
83
- current_df,
84
- config={
85
- "llm": llm,
86
- "verbose": True,
87
- "save_charts": True,
88
- "save_charts_path": temp_dir,
89
- "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"]
90
- }
91
- )
92
 
93
- # Process the query
94
- result = sdf.chat(query)
 
 
 
 
 
 
 
95
 
96
- # Handle different types of results
97
- if isinstance(result, str):
98
- # Text response
99
- return f"πŸ“’ **Response:**\n{result}", None
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- elif hasattr(result, 'savefig'):
102
- # Matplotlib figure
103
- try:
104
- # Save figure to bytes
105
- img_buffer = BytesIO()
106
- result.savefig(img_buffer, format='png', dpi=150, bbox_inches='tight')
107
- img_buffer.seek(0)
 
 
 
 
 
 
 
 
 
 
108
 
109
- # Save to temporary file for Gradio
110
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
111
- with open(temp_file.name, 'wb') as f:
112
- f.write(img_buffer.getvalue())
 
113
 
114
- plt.close(result) # Close the figure to free memory
 
 
 
 
 
115
 
116
- return "πŸ“ˆ **Chart Generated:**", temp_file.name
117
 
118
- except Exception as chart_error:
119
- return f"❌ Error saving chart: {str(chart_error)}", None
 
 
 
 
 
 
120
 
121
- elif isinstance(result, pd.DataFrame):
122
- # DataFrame result
123
- html_table = result.to_html(classes='table table-striped', max_rows=100)
124
- return f"πŸ“Š **Data Result:**\n{html_table}", None
125
-
126
- else:
127
- # Other types of results
128
- return f"πŸ“Š **Result:**\n{str(result)}", None
129
-
130
- except Exception as e:
131
- error_msg = f"❌ Error: {str(e)}"
132
 
133
- # Provide specific error guidance
134
- if "403" in str(e):
135
- error_msg += "\n\nπŸ” **403 Forbidden Error** - This usually means:\n"
136
- error_msg += "- Invalid API key\n"
137
- error_msg += "- API key doesn't have permission for this model\n"
138
- error_msg += "- Rate limit exceeded\n"
139
- error_msg += "- Model name is incorrect"
140
- elif "rate limit" in str(e).lower():
141
- error_msg += "\n\n⏰ **Rate Limit** - Please wait a moment before trying again"
142
- elif "timeout" in str(e).lower():
143
- error_msg += "\n\n⏱️ **Timeout** - The query took too long. Try a simpler request"
144
-
145
- return error_msg, None
146
-
147
- def get_debug_info():
148
- """Get debug information"""
149
- if API_KEY and API_KEY != "gsk_YOUR_ACTUAL_API_KEY_HERE":
150
- return f"βœ… API Key loaded successfully\nKey starts with: {API_KEY[:10]}..."
151
- else:
152
- return "❌ Replace 'gsk_YOUR_ACTUAL_API_KEY_HERE' with your actual API key"
153
-
154
- # Initialize LLM on startup
155
- init_status, _ = initialize_llm()
156
-
157
- # Create Gradio interface
158
- with gr.Blocks(title="πŸ“Š CSV Chat with Groq + PandasAI", theme=gr.themes.Soft()) as demo:
159
- gr.Markdown("# πŸ“Š Chat with Your CSV using PandasAI + Groq")
160
-
161
- with gr.Row():
162
- with gr.Column(scale=2):
163
- # File upload section
164
- gr.Markdown("## πŸ“ Upload CSV File")
165
- file_input = gr.File(
166
- label="Upload your CSV file",
167
- file_types=[".csv"],
168
- type="filepath"
169
- )
170
-
171
- upload_status = gr.Textbox(
172
- label="Upload Status",
173
- interactive=False,
174
- value=init_status
175
- )
176
-
177
- # Data preview section
178
- gr.Markdown("## πŸ“‹ Data Preview")
179
- data_preview = gr.HTML(label="Data Preview")
180
- data_info = gr.Markdown()
181
-
182
- with gr.Column(scale=1):
183
- # Debug and help section
184
- gr.Markdown("## πŸ”§ Debug Info")
185
- debug_btn = gr.Button("Show Debug Info")
186
- debug_info = gr.Textbox(label="Debug Information", interactive=False)
187
-
188
- gr.Markdown("## πŸ“ Example Queries")
189
- gr.Markdown("""
190
- - "Show me the first 10 rows"
191
- - "What are the column names?"
192
- - "Create a histogram of [column_name]"
193
- - "Show me the summary statistics"
194
- - "Plot the top 5 values in [column_name]"
195
- - "Create a bar chart showing [column1] vs [column2]"
196
- """)
197
-
198
- # Chat section
199
- gr.Markdown("## πŸ’¬ Chat with Your Data")
200
-
201
- with gr.Row():
202
- query_input = gr.Textbox(
203
- label="Ask a question or request a chart",
204
- placeholder="What would you like to know about your data?",
205
- lines=3,
206
- scale=4
207
  )
208
- submit_btn = gr.Button("Submit Query", variant="primary", scale=1)
209
-
210
- # Results section
211
- with gr.Row():
212
- with gr.Column():
213
- response_output = gr.Markdown(label="Response")
214
- with gr.Column():
215
- chart_output = gr.Image(label="Generated Chart", type="filepath")
216
-
217
- # Event handlers
218
- file_input.change(
219
- fn=process_csv,
220
- inputs=[file_input],
221
- outputs=[upload_status, data_preview, data_info]
222
- )
223
-
224
- debug_btn.click(
225
- fn=get_debug_info,
226
- outputs=[debug_info]
227
- )
228
-
229
- submit_btn.click(
230
- fn=chat_with_data,
231
- inputs=[query_input],
232
- outputs=[response_output, chart_output]
233
- )
234
 
235
- query_input.submit(
236
- fn=chat_with_data,
237
- inputs=[query_input],
238
- outputs=[response_output, chart_output]
239
- )
240
 
241
- # Launch the app
242
  if __name__ == "__main__":
 
 
243
  demo.launch(
244
- share=False, # Set to True if you want a public link
245
- debug=True,
246
- show_error=True
 
 
 
247
  )
 
1
  import gradio as gr
2
  import pandas as pd
 
 
 
3
  import os
4
  import tempfile
5
+ import matplotlib.pyplot as plt
 
6
  from pandasai import SmartDataframe
7
  from langchain_groq.chat_models import ChatGroq
8
+ from dotenv import load_dotenv
9
+ import io
10
+ import base64
11
 
12
+ # Load environment variables
13
+ load_dotenv()
 
 
 
 
14
 
15
+ def process_data(file, query, api_key):
16
+ """
17
+ Process the uploaded CSV file with the given query using PandasAI
18
+ """
19
  try:
20
+ # Validate inputs
21
+ if file is None:
22
+ return "Please upload a CSV file.", None
23
+
24
+ if not query.strip():
25
+ return "Please enter a query.", None
26
+
27
+ if not api_key.strip():
28
+ return "Please enter your Groq API key.", None
29
+
30
+ # Read the CSV file
31
+ df_data = pd.read_csv(file.name)
32
 
33
+ # Initialize Groq LLM
34
  llm = ChatGroq(
35
+ model_name="mixtral-8x7b-32768", # Using a more stable model
36
+ api_key=api_key.strip(),
37
  temperature=0
38
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # Create SmartDataframe
41
+ df = SmartDataframe(df_data, config={
42
+ "llm": llm,
43
+ "save_charts": True,
44
+ "save_charts_path": tempfile.gettempdir(),
45
+ "open_charts": False,
46
+ "enable_cache": False
47
+ })
48
 
49
+ # Process the query
50
+ result = df.chat(query)
 
 
 
 
51
 
52
+ # Handle different types of results
53
+ if result is None:
54
+ return "No result returned. Please try a different query.", None
55
 
56
+ # Check if result is a plot/chart
57
+ chart_path = None
58
+ chart_files = [f for f in os.listdir(tempfile.gettempdir()) if f.endswith(('.png', '.jpg', '.jpeg'))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ if chart_files:
61
+ # Get the most recent chart file
62
+ chart_files.sort(key=lambda x: os.path.getmtime(os.path.join(tempfile.gettempdir(), x)), reverse=True)
63
+ chart_path = os.path.join(tempfile.gettempdir(), chart_files[0])
 
 
 
 
 
 
 
64
 
65
+ # Format the text result
66
+ if isinstance(result, pd.DataFrame):
67
+ result_text = f"Query Result:\n\n{result.to_string()}"
68
+ elif isinstance(result, (int, float)):
69
+ result_text = f"Query Result: {result}"
70
+ elif isinstance(result, str):
71
+ result_text = f"Query Result:\n{result}"
72
+ else:
73
+ result_text = f"Query Result:\n{str(result)}"
74
 
75
+ return result_text, chart_path
76
+
77
+ except Exception as e:
78
+ error_msg = f"Error processing query: {str(e)}"
79
+ return error_msg, None
80
+
81
+ def create_interface():
82
+ """
83
+ Create the Gradio interface
84
+ """
85
+ with gr.Blocks(title="PandasAI with Groq", theme=gr.themes.Soft()) as demo:
86
+ gr.Markdown(
87
+ """
88
+ # πŸ“Š PandasAI Data Analysis with Groq
89
+
90
+ Upload a CSV file and ask questions about your data. The AI will analyze and visualize your data accordingly.
91
 
92
+ **Instructions:**
93
+ 1. Get your Groq API key from [https://console.groq.com/keys](https://console.groq.com/keys)
94
+ 2. Upload your CSV file
95
+ 3. Enter your query (e.g., "Show top 5 countries by population", "Create a bar plot of sales by region")
96
+ 4. Click Submit to get results
97
+ """
98
+ )
99
+
100
+ with gr.Row():
101
+ with gr.Column(scale=1):
102
+ # Input components
103
+ api_key_input = gr.Textbox(
104
+ label="Groq API Key",
105
+ placeholder="Enter your Groq API key here...",
106
+ type="password",
107
+ info="Your API key is not stored and only used for this session"
108
+ )
109
 
110
+ file_input = gr.File(
111
+ label="Upload CSV File",
112
+ file_types=[".csv"],
113
+ info="Upload your CSV data file"
114
+ )
115
 
116
+ query_input = gr.Textbox(
117
+ label="Your Query",
118
+ placeholder="e.g., 'Which are the top 5 countries by population?' or 'Create a bar plot of the top 5 countries'",
119
+ lines=3,
120
+ info="Ask questions about your data or request visualizations"
121
+ )
122
 
123
+ submit_btn = gr.Button("πŸš€ Submit Query", variant="primary")
124
 
125
+ with gr.Column(scale=2):
126
+ # Output components
127
+ result_output = gr.Textbox(
128
+ label="Analysis Result",
129
+ lines=10,
130
+ interactive=False,
131
+ show_copy_button=True
132
+ )
133
 
134
+ chart_output = gr.Image(
135
+ label="Generated Visualization",
136
+ show_label=True
137
+ )
 
 
 
 
 
 
 
138
 
139
+ # Example queries
140
+ gr.Markdown(
141
+ """
142
+ ### πŸ’‘ Example Queries:
143
+ - "Which are the top 5 countries by population?"
144
+ - "Create a bar plot of the top 10 countries by population"
145
+ - "Show me a pie chart of the top 5 countries"
146
+ - "Calculate the total population of the top 3 countries"
147
+ - "What is the average population across all countries?"
148
+ - "Create a scatter plot showing the relationship between two columns"
149
+ """
150
+ )
151
+
152
+ # Event handlers
153
+ submit_btn.click(
154
+ fn=process_data,
155
+ inputs=[file_input, query_input, api_key_input],
156
+ outputs=[result_output, chart_output],
157
+ show_progress=True
158
+ )
159
+
160
+ # Allow Enter key to submit
161
+ query_input.submit(
162
+ fn=process_data,
163
+ inputs=[file_input, query_input, api_key_input],
164
+ outputs=[result_output, chart_output],
165
+ show_progress=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ return demo
 
 
 
 
169
 
 
170
  if __name__ == "__main__":
171
+ # Create and launch the interface
172
+ demo = create_interface()
173
  demo.launch(
174
+ server_name="0.0.0.0",
175
+ server_port=7860,
176
+ share=False,
177
+ show_error=True,
178
+ show_tips=True,
179
+ enable_queue=True
180
  )