Yannael_LB commited on
Commit
1c34aa2
·
1 Parent(s): 7c7d5bd

Handling of session states

Browse files
Files changed (1) hide show
  1. app.py +57 -83
app.py CHANGED
@@ -23,50 +23,47 @@ import json
23
  import plotly.graph_objects as go
24
  import time
25
 
26
- #os.environ["OPENAI_API_KEY"] = "..." # Replace with your key
 
 
 
 
27
 
28
  #######################################
29
  # TOOLS SETUP
30
  #######################################
31
 
32
- def update_map_state(latitude, longitude, zoom):
33
  """OpenAI tool to update map in-app
34
  """
35
- session_state[map_state] = {
36
- "latitude": latitude,
37
- "longitude": longitude,
38
- "zoom": zoom,
39
- }
40
- print(session_state[map_state])
 
41
  return "Map updated"
42
 
43
- def add_markers_state(latitudes, longitudes, labels):
44
  """OpenAI tool to update markers in-app
45
  """
46
- session_state[markers_state] = {
47
- "lat": latitudes,
48
- "lon": longitudes,
49
- "text": labels,
50
- }
 
 
51
  return "Markers added"
52
 
53
  tool_to_function = {
54
  "update_map": update_map_state,
55
- "add_markers": add_markers_state,
56
  }
57
 
58
  ## Helpers
59
 
60
- def get_assistant_id():
61
- return session_state[assistant_state].id
62
-
63
- def get_thread_id():
64
- return session_state[thread_state].id
65
-
66
-
67
- def get_run_id():
68
- return session_state[last_openai_run_state].id
69
-
70
  def submit_message(assistant_id, thread_id, user_message):
71
  client.beta.threads.messages.create(
72
  thread_id=thread_id, role="user", content=user_message
@@ -85,43 +82,17 @@ def get_run_info(run_id, thread_id):
85
  return run
86
 
87
  #######################################
88
- # SESSION SETUP
89
  #######################################
90
 
91
- client = openai.OpenAI()
92
- assistant_id = "asst_7OC3NTeyCjEZrApdLRklplE7"
93
-
94
- session_state = {}
95
-
96
- assistant_state = "assistant"
97
- thread_state = "thread"
98
- conversation_state = "conversation"
99
- last_openai_run_state = "last_openai_run"
100
- map_state = "map"
101
- markers_state = "markers"
102
-
103
- if (assistant_state not in session_state) or (thread_state not in session_state):
104
- session_state[assistant_state] = client.beta.assistants.retrieve(assistant_id)
105
- session_state[thread_state] = client.beta.threads.create()
106
-
107
- if conversation_state not in session_state:
108
- session_state[conversation_state] = []
109
-
110
- if last_openai_run_state not in session_state:
111
- session_state[last_openai_run_state] = None
112
-
113
- if map_state not in session_state:
114
- session_state[map_state] = {
115
  "latitude": 48.85,
116
  "longitude": 2.35,
117
  "zoom": 12,
118
- }
119
-
120
- if markers_state not in session_state:
121
- session_state[markers_state] = {
122
  "lat": [],
123
  "lon": [],
124
  "text": [],
 
125
  }
126
 
127
  fig = go.Figure(go.Scattermapbox())
@@ -131,29 +102,28 @@ fig.update_layout(
131
  hovermode='closest',
132
  mapbox=dict(
133
  center=go.layout.mapbox.Center(
134
- lat=session_state[map_state]["latitude"],
135
- lon=session_state[map_state]["longitude"]
136
  ),
137
- zoom=session_state[map_state]["zoom"]
138
  ),
139
  )
140
 
141
- def respond(message, chat_history):
142
 
143
- print(chat_history)
 
144
 
145
- run = submit_message(get_assistant_id(), get_thread_id(), message)
146
 
147
- session_state[last_openai_run_state] = run
148
-
149
- print(run)
150
 
151
  completed = False
152
 
153
  # Polling
154
  while not completed:
155
 
156
- run = get_run_info(get_run_id(), get_thread_id())
157
 
158
  if run.status == "requires_action":
159
 
@@ -165,9 +135,9 @@ def respond(message, chat_history):
165
  f_name = f.name
166
  f_args = json.loads(f.arguments)
167
 
168
- print(f"Launching function {f_name} with args {f_args}")
169
 
170
- tool_result = tool_to_function[f_name](**f_args)
171
 
172
  tools_output.append(
173
  {
@@ -176,11 +146,11 @@ def respond(message, chat_history):
176
  }
177
  )
178
 
179
- print(f"Will submit {tools_output}")
180
 
181
  client.beta.threads.runs.submit_tool_outputs(
182
- thread_id=get_thread_id(),
183
- run_id=get_run_id(),
184
  tool_outputs=tools_output,
185
  )
186
 
@@ -191,12 +161,11 @@ def respond(message, chat_history):
191
  else:
192
  time.sleep(0.1)
193
 
194
- session_state[conversation_state] = [
195
  [m.role, m.content[0].text.value]
196
- for m in client.beta.threads.messages.list(get_thread_id(), order="asc").data
197
  ]
198
 
199
- dialog = session_state[conversation_state]
200
  formatted_dialog = []
201
  for i in range(int(len(dialog)/2)):
202
  formatted_dialog.append([dialog[i*2][1],dialog[i*2+1][1]])
@@ -204,17 +173,18 @@ def respond(message, chat_history):
204
 
205
  chat_history = formatted_dialog
206
 
207
- fig = None
208
 
209
- if session_state[markers_state] is None:
210
 
 
211
  fig = go.Figure(go.Scattermapbox())
212
 
213
  else :
214
  fig = go.Figure(go.Scattermapbox(
215
- customdata=session_state[markers_state]["text"],
216
- lat=session_state[markers_state]["lat"],
217
- lon=session_state[markers_state]["lon"],
218
  mode='markers',
219
  marker=go.scattermapbox.Marker(
220
  size=18
@@ -228,34 +198,38 @@ def respond(message, chat_history):
228
  hovermode='closest',
229
  mapbox=dict(
230
  center=go.layout.mapbox.Center(
231
- lat=session_state[map_state]["latitude"],
232
- lon=session_state[map_state]["longitude"]
233
  ),
234
  zoom=12
235
  ),
236
  )
237
 
238
- return "", chat_history, fig
239
 
240
  with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
241
 
242
  gr.Markdown("# OpenAI assistant Wanderlust")
243
 
 
 
 
 
244
  with gr.Column():
245
  with gr.Row():
246
 
247
  chatbot = gr.Chatbot()
248
  map = gr.Plot(fig)
249
 
250
- msg = gr.Textbox("Move the map to Brussels and locate the best places for waffles")
251
 
252
  with gr.Column():
253
  with gr.Row():
254
  submit = gr.Button("Submit")
255
  clear = gr.ClearButton([msg, chatbot])
256
 
257
- msg.submit(respond, [msg, chatbot], [msg, chatbot, map])
258
- submit.click(respond, [msg, chatbot], [msg, chatbot, map])
259
 
260
  gr.Markdown(
261
  """
 
23
  import plotly.graph_objects as go
24
  import time
25
 
26
+ os.environ["OPENAI_API_KEY"] = "sk-..." # Replace with your key
27
+ assistant_id = "asst_7OC3NTeyCjEZrApdLRklplE7"
28
+
29
+ client = openai.OpenAI()
30
+ assistant = client.beta.assistants.retrieve(assistant_id)
31
 
32
  #######################################
33
  # TOOLS SETUP
34
  #######################################
35
 
36
+ def update_map_state(map_metadata_state, latitude, longitude, zoom):
37
  """OpenAI tool to update map in-app
38
  """
39
+
40
+ map_metadata_state["latitude"] = latitude
41
+ map_metadata_state["longitude"] = longitude
42
+ map_metadata_state["zoom"] = zoom
43
+
44
+ print(map_metadata_state)
45
+
46
  return "Map updated"
47
 
48
+ def add_markers_state(map_metadata_state, latitudes, longitudes, labels):
49
  """OpenAI tool to update markers in-app
50
  """
51
+
52
+ map_metadata_state["lat"] = latitudes
53
+ map_metadata_state["lon"] = longitudes
54
+ map_metadata_state["text"] = labels
55
+
56
+ print(map_metadata_state)
57
+
58
  return "Markers added"
59
 
60
  tool_to_function = {
61
  "update_map": update_map_state,
62
+ "add_markers": add_markers_state
63
  }
64
 
65
  ## Helpers
66
 
 
 
 
 
 
 
 
 
 
 
67
  def submit_message(assistant_id, thread_id, user_message):
68
  client.beta.threads.messages.create(
69
  thread_id=thread_id, role="user", content=user_message
 
82
  return run
83
 
84
  #######################################
85
+ # INITIAL DATA FOR MAP
86
  #######################################
87
 
88
+ map_metadata = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  "latitude": 48.85,
90
  "longitude": 2.35,
91
  "zoom": 12,
 
 
 
 
92
  "lat": [],
93
  "lon": [],
94
  "text": [],
95
+
96
  }
97
 
98
  fig = go.Figure(go.Scattermapbox())
 
102
  hovermode='closest',
103
  mapbox=dict(
104
  center=go.layout.mapbox.Center(
105
+ lat=map_metadata["latitude"],
106
+ lon=map_metadata["longitude"]
107
  ),
108
+ zoom=map_metadata["zoom"]
109
  ),
110
  )
111
 
112
+ def respond(message, chat_history, thread, map_metadata_state):
113
 
114
+ if thread is None:
115
+ thread = client.beta.threads.create()
116
 
117
+ print(map_metadata_state)
118
 
119
+ run = submit_message(assistant.id, thread.id, message)
 
 
120
 
121
  completed = False
122
 
123
  # Polling
124
  while not completed:
125
 
126
+ run = get_run_info(run.id, thread.id)
127
 
128
  if run.status == "requires_action":
129
 
 
135
  f_name = f.name
136
  f_args = json.loads(f.arguments)
137
 
138
+ #print(f"Launching function {f_name} with args {f_args}")
139
 
140
+ tool_result = tool_to_function[f_name](map_metadata_state,**f_args)
141
 
142
  tools_output.append(
143
  {
 
146
  }
147
  )
148
 
149
+ #print(f"Will submit {tools_output}")
150
 
151
  client.beta.threads.runs.submit_tool_outputs(
152
+ thread_id=thread.id,
153
+ run_id=run.id,
154
  tool_outputs=tools_output,
155
  )
156
 
 
161
  else:
162
  time.sleep(0.1)
163
 
164
+ dialog = [
165
  [m.role, m.content[0].text.value]
166
+ for m in client.beta.threads.messages.list(thread.id, order="asc").data
167
  ]
168
 
 
169
  formatted_dialog = []
170
  for i in range(int(len(dialog)/2)):
171
  formatted_dialog.append([dialog[i*2][1],dialog[i*2+1][1]])
 
173
 
174
  chat_history = formatted_dialog
175
 
176
+ print(formatted_dialog)
177
 
178
+ fig = None
179
 
180
+ if len(map_metadata_state["lat"])==0:
181
  fig = go.Figure(go.Scattermapbox())
182
 
183
  else :
184
  fig = go.Figure(go.Scattermapbox(
185
+ customdata=map_metadata_state["text"],
186
+ lat=map_metadata_state["lat"],
187
+ lon=map_metadata_state["lon"],
188
  mode='markers',
189
  marker=go.scattermapbox.Marker(
190
  size=18
 
198
  hovermode='closest',
199
  mapbox=dict(
200
  center=go.layout.mapbox.Center(
201
+ lat=map_metadata_state["latitude"],
202
+ lon=map_metadata_state["longitude"]
203
  ),
204
  zoom=12
205
  ),
206
  )
207
 
208
+ return "", chat_history, fig, thread, map_metadata_state
209
 
210
  with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
211
 
212
  gr.Markdown("# OpenAI assistant Wanderlust")
213
 
214
+ thread = gr.State()
215
+
216
+ map_metadata_state = gr.State(map_metadata)
217
+
218
  with gr.Column():
219
  with gr.Row():
220
 
221
  chatbot = gr.Chatbot()
222
  map = gr.Plot(fig)
223
 
224
+ msg = gr.Textbox("Move the map to Brussels and add markers on shops with the best waffles")
225
 
226
  with gr.Column():
227
  with gr.Row():
228
  submit = gr.Button("Submit")
229
  clear = gr.ClearButton([msg, chatbot])
230
 
231
+ msg.submit(respond, [msg, chatbot, thread, map_metadata_state], [msg, chatbot, map, thread, map_metadata_state])
232
+ submit.click(respond, [msg, chatbot, thread, map_metadata_state], [msg, chatbot, map, thread, map_metadata_state])
233
 
234
  gr.Markdown(
235
  """