Entz commited on
Commit
2c0711c
Β·
verified Β·
1 Parent(s): d3a543b

Upload 2 files

Browse files
Files changed (2) hide show
  1. backend.py +9 -46
  2. frontend.py +5 -21
backend.py CHANGED
@@ -8,32 +8,27 @@ import google.generativeai as genai
8
 
9
  app = FastAPI()
10
 
11
- # Load environment variables and configure Genai
12
  load_dotenv()
13
  genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
14
 
15
- # Define the schema for the incoming request
16
  class Query(BaseModel):
17
  question: str
18
  data_source: str
19
 
20
  def get_gemini_response(question, prompt):
21
- model = genai.GenerativeModel('gemini-1.5-pro') # https://ai.google.dev/pricing?authuser=1#1_5pro
22
  response = model.generate_content([prompt, question])
23
  return response.text
24
 
25
- # Update column and table names for the new dataset
26
  sql_cols_human = 'REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode'
27
  csv_columns_human = ['REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode']
28
  sql_cols = 'REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode'
29
- # csv_columns = ["REQUESTID", "DATETIMEINIT", "SOURCE", "DESCRIPTION", "REQCATEGORY", "STATUS", "REFERREDTO", "DATETIMECLOSED", "PROBADDRESS" "City", "State", "Ward", "Postcode"]
30
 
31
  def get_csv_columns():
32
  df = pd.read_csv('wandsworth_callcenter_sampled.csv')
33
  return df.columns.tolist()
34
 
35
  csv_columns = get_csv_columns()
36
- print(csv_columns)
37
 
38
  sql_prompt = f"""
39
  You are an expert in converting English questions to SQLite code!
@@ -54,8 +49,6 @@ Also, the SQL code should not have ''' in the beginning or at the end, and SQL w
54
  Ensure that you only generate valid SQLite database queries, not pandas or Python code.
55
  """
56
 
57
-
58
-
59
  csv_prompt = f"""
60
  You are an expert in analyzing CSV data and converting English questions to pandas query syntax.
61
  The CSV file is named 'wandsworth_callcenter_sampled.csv' and contains residents' call information in Wandsworth Council.
@@ -78,7 +71,6 @@ Please ensure:
78
  3. Provide only the pandas query syntax without any additional explanation or markdown formatting.
79
  Make sure to use only the columns that are available in the CSV file.
80
  Ensure that you only generate valid pandas queries. NO SQL or other types of code/syntax.
81
-
82
  """
83
 
84
  def execute_sql_query(query):
@@ -89,9 +81,7 @@ def execute_sql_query(query):
89
  result = cursor.fetchall()
90
  return result
91
  except sqlite3.Error as e:
92
- # Capture and explain SQL errors
93
  sql_error_message = str(e)
94
- # Send the error message back to Gemini for explanation
95
  error_prompt = f"""
96
  You are an expert SQL debugger and an assistant of the director. An error occurred while executing the following query:
97
  {query}
@@ -107,56 +97,30 @@ def execute_sql_query(query):
107
  finally:
108
  conn.close()
109
 
110
-
111
-
112
-
113
  def execute_pandas_query(query):
114
  df = pd.read_csv('wandsworth_callcenter_sampled.csv')
115
- df.columns = df.columns.str.upper() # Normalize column names to uppercase
116
- print(f"df is loaded. The first line is: {df.head(1)}")
117
-
118
- # Remove code block indicators (e.g., ```python and ```)
119
  query = query.replace("```python", "").replace("```", "").strip()
120
-
121
- # Split query into lines
122
- query_lines = query.split("\n") # Split into individual statements
123
  try:
124
  result = None
125
- exec_context = {'df': df, 'pd': pd} # Execution context for exec()
126
  for line in query_lines:
127
- line = line.strip() # Remove extra spaces
128
- if line: # Skip empty lines
129
- print(f"Executing line: {line}")
130
- exec(line, exec_context) # Execute each line in the context
131
-
132
- # Retrieve the final result if the last line is a statement
133
- result = eval(query_lines[-1].strip(), exec_context) # Evaluate the last line for the result
134
-
135
- print(f"Query Result Before Serialization: {result}")
136
-
137
- # Handle DataFrame results
138
  if isinstance(result, pd.DataFrame):
139
- # Replace NaN and infinite values with JSON-compliant values
140
  result = result.replace([float('inf'), -float('inf')], None).fillna(value="N/A")
141
  return result.to_dict(orient='records')
142
-
143
- # Handle Series results
144
  elif isinstance(result, pd.Series):
145
  result = result.replace([float('inf'), -float('inf')], None).fillna(value="N/A")
146
  return result.to_dict()
147
-
148
- # Handle scalar results
149
  else:
150
  return result
151
-
152
  except Exception as e:
153
- print(f"Error: {e}")
154
  raise HTTPException(status_code=400, detail=f"Pandas Error: {str(e)}")
155
 
156
-
157
-
158
-
159
-
160
  @app.post("/query")
161
  async def process_query(query: Query):
162
  if query.data_source == "SQL Database":
@@ -167,9 +131,8 @@ async def process_query(query: Query):
167
  except HTTPException as e:
168
  error_detail = e.detail
169
  return {"query": ai_response, "error": error_detail["error"], "explanation": error_detail["explanation"]}
170
- else: # CSV Data
171
  ai_response = get_gemini_response(query.question, csv_prompt)
172
- print(f"\n\nai_response: {ai_response}")
173
  try:
174
  result = execute_pandas_query(ai_response)
175
  return {"query": ai_response, "result": result, "columns": csv_columns}
 
8
 
9
  app = FastAPI()
10
 
 
11
  load_dotenv()
12
  genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
13
 
 
14
  class Query(BaseModel):
15
  question: str
16
  data_source: str
17
 
18
  def get_gemini_response(question, prompt):
19
+ model = genai.GenerativeModel('gemini-1.5-pro')
20
  response = model.generate_content([prompt, question])
21
  return response.text
22
 
 
23
  sql_cols_human = 'REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode'
24
  csv_columns_human = ['REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode']
25
  sql_cols = 'REQUESTID', 'DATETIMEINIT', 'SOURCE', 'DESCRIPTION', 'REQCATEGORY', 'STATUS', 'REFERREDTO', 'DATETIMECLOSED', 'City', 'State', 'Ward', 'Postcode'
 
26
 
27
  def get_csv_columns():
28
  df = pd.read_csv('wandsworth_callcenter_sampled.csv')
29
  return df.columns.tolist()
30
 
31
  csv_columns = get_csv_columns()
 
32
 
33
  sql_prompt = f"""
34
  You are an expert in converting English questions to SQLite code!
 
49
  Ensure that you only generate valid SQLite database queries, not pandas or Python code.
50
  """
51
 
 
 
52
  csv_prompt = f"""
53
  You are an expert in analyzing CSV data and converting English questions to pandas query syntax.
54
  The CSV file is named 'wandsworth_callcenter_sampled.csv' and contains residents' call information in Wandsworth Council.
 
71
  3. Provide only the pandas query syntax without any additional explanation or markdown formatting.
72
  Make sure to use only the columns that are available in the CSV file.
73
  Ensure that you only generate valid pandas queries. NO SQL or other types of code/syntax.
 
74
  """
75
 
76
  def execute_sql_query(query):
 
81
  result = cursor.fetchall()
82
  return result
83
  except sqlite3.Error as e:
 
84
  sql_error_message = str(e)
 
85
  error_prompt = f"""
86
  You are an expert SQL debugger and an assistant of the director. An error occurred while executing the following query:
87
  {query}
 
97
  finally:
98
  conn.close()
99
 
 
 
 
100
  def execute_pandas_query(query):
101
  df = pd.read_csv('wandsworth_callcenter_sampled.csv')
102
+ df.columns = df.columns.str.upper()
 
 
 
103
  query = query.replace("```python", "").replace("```", "").strip()
104
+ query_lines = query.split("\n")
 
 
105
  try:
106
  result = None
107
+ exec_context = {'df': df, 'pd': pd}
108
  for line in query_lines:
109
+ line = line.strip()
110
+ if line:
111
+ exec(line, exec_context)
112
+ result = eval(query_lines[-1].strip(), exec_context)
 
 
 
 
 
 
 
113
  if isinstance(result, pd.DataFrame):
 
114
  result = result.replace([float('inf'), -float('inf')], None).fillna(value="N/A")
115
  return result.to_dict(orient='records')
 
 
116
  elif isinstance(result, pd.Series):
117
  result = result.replace([float('inf'), -float('inf')], None).fillna(value="N/A")
118
  return result.to_dict()
 
 
119
  else:
120
  return result
 
121
  except Exception as e:
 
122
  raise HTTPException(status_code=400, detail=f"Pandas Error: {str(e)}")
123
 
 
 
 
 
124
  @app.post("/query")
125
  async def process_query(query: Query):
126
  if query.data_source == "SQL Database":
 
131
  except HTTPException as e:
132
  error_detail = e.detail
133
  return {"query": ai_response, "error": error_detail["error"], "explanation": error_detail["explanation"]}
134
+ else:
135
  ai_response = get_gemini_response(query.question, csv_prompt)
 
136
  try:
137
  result = execute_pandas_query(ai_response)
138
  return {"query": ai_response, "result": result, "columns": csv_columns}
frontend.py CHANGED
@@ -2,30 +2,26 @@ import streamlit as st
2
  import requests
3
  import pandas as pd
4
 
5
- # Page Configuration
6
  st.set_page_config(
7
- page_title="CallDataAI - Wandsworth Council Call Center Analysis",
8
  page_icon="πŸ“ž",
9
  layout="wide",
10
  initial_sidebar_state="expanded",
11
  )
12
 
13
- # Sidebar
14
  st.sidebar.title("πŸ“ž CallDataAI")
15
  st.sidebar.markdown(
16
  """
17
- **Welcome to CallDataAI**, your AI-powered assistant for analyzing Wandsworth Council's Call Center data. Use the menu below to:
18
  - Select the data source (SQL/CSV)
19
  - Run pre-defined or custom queries
20
  - Gain actionable insights
21
  """
22
  )
23
 
24
- # Data source selection
25
  st.sidebar.markdown("### Select Data Source:")
26
  data_source = st.sidebar.radio("", ('SQL Database', 'CSV Database'))
27
 
28
- # Common queries section
29
  st.sidebar.markdown("### Common Queries:")
30
  common_queries = {
31
  'SQL Database': [
@@ -45,22 +41,17 @@ common_queries = {
45
  }
46
 
47
  for idx, query in enumerate(common_queries[data_source]):
48
- if st.sidebar.button(query, key=f"query_button_{idx}"): # Add unique key
49
  st.session_state["common_query"] = query
50
 
51
-
52
-
53
-
54
- # Title and Description
55
- st.title("πŸ“ž CallDataAI - Wandsworth Council Call Center Analysis")
56
  st.markdown(
57
  """
58
- **CallDataAI** is an AI-powered chatbot designed for analyzing Wandsworth Council's Call Center data.
59
  Input natural language queries to explore the data and gain actionable insights.
60
  """
61
  )
62
 
63
- # Input Section
64
  with st.container():
65
  st.markdown("### Enter Your Question")
66
  question = st.text_input(
@@ -68,24 +59,19 @@ with st.container():
68
  )
69
  submit = st.button("Submit", type="primary")
70
 
71
- # Main Content
72
  if submit:
73
- # Send request to FastAPI backend
74
  with st.spinner("Processing your request..."):
75
  response = requests.post(
76
  "http://localhost:8000/query", json={"question": question, "data_source": data_source}
77
  )
78
 
79
- # Handle response
80
  if response.status_code == 200:
81
  data = response.json()
82
 
83
- # Error Handling
84
  if "error" in data:
85
  with st.expander("Error Explanation"):
86
  st.error(data["explanation"])
87
 
88
- # Display Results
89
  else:
90
  col1, col2 = st.columns(2)
91
 
@@ -117,7 +103,6 @@ if submit:
117
  st.markdown("### Available CSV Columns")
118
  st.write(data["columns"])
119
 
120
- # Update chat history in session state
121
  if "chat_history" not in st.session_state:
122
  st.session_state["chat_history"] = []
123
 
@@ -127,7 +112,6 @@ if submit:
127
  else:
128
  st.error(f"Error processing your request: {response.text}")
129
 
130
- # Chat History Section
131
  with st.container():
132
  st.markdown("### Chat History")
133
  if "chat_history" in st.session_state:
 
2
  import requests
3
  import pandas as pd
4
 
 
5
  st.set_page_config(
6
+ page_title="CallDataAI - Wandsworth Council NetCall Analysis",
7
  page_icon="πŸ“ž",
8
  layout="wide",
9
  initial_sidebar_state="expanded",
10
  )
11
 
 
12
  st.sidebar.title("πŸ“ž CallDataAI")
13
  st.sidebar.markdown(
14
  """
15
+ **Welcome to CallDataAI**, your AI-powered assistant for analyzing Wandsworth Council's NetCall data. Use the menu below to:
16
  - Select the data source (SQL/CSV)
17
  - Run pre-defined or custom queries
18
  - Gain actionable insights
19
  """
20
  )
21
 
 
22
  st.sidebar.markdown("### Select Data Source:")
23
  data_source = st.sidebar.radio("", ('SQL Database', 'CSV Database'))
24
 
 
25
  st.sidebar.markdown("### Common Queries:")
26
  common_queries = {
27
  'SQL Database': [
 
41
  }
42
 
43
  for idx, query in enumerate(common_queries[data_source]):
44
+ if st.sidebar.button(query, key=f"query_button_{idx}"):
45
  st.session_state["common_query"] = query
46
 
47
+ st.title("πŸ“ž CallDataAI - Wandsworth Council NetCall Analysis")
 
 
 
 
48
  st.markdown(
49
  """
50
+ **CallDataAI** is an AI-powered chatbot designed for analyzing Wandsworth Council's NetCall data.
51
  Input natural language queries to explore the data and gain actionable insights.
52
  """
53
  )
54
 
 
55
  with st.container():
56
  st.markdown("### Enter Your Question")
57
  question = st.text_input(
 
59
  )
60
  submit = st.button("Submit", type="primary")
61
 
 
62
  if submit:
 
63
  with st.spinner("Processing your request..."):
64
  response = requests.post(
65
  "http://localhost:8000/query", json={"question": question, "data_source": data_source}
66
  )
67
 
 
68
  if response.status_code == 200:
69
  data = response.json()
70
 
 
71
  if "error" in data:
72
  with st.expander("Error Explanation"):
73
  st.error(data["explanation"])
74
 
 
75
  else:
76
  col1, col2 = st.columns(2)
77
 
 
103
  st.markdown("### Available CSV Columns")
104
  st.write(data["columns"])
105
 
 
106
  if "chat_history" not in st.session_state:
107
  st.session_state["chat_history"] = []
108
 
 
112
  else:
113
  st.error(f"Error processing your request: {response.text}")
114
 
 
115
  with st.container():
116
  st.markdown("### Chat History")
117
  if "chat_history" in st.session_state: