Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
"""OpenAI_Assistant_WanderLust.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1DLYjk07uQwF_Oqav1NtqMMCz_S2f0q8B

# OpenAI assistant Wanderlust

Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API.

This Space is inspired by Fanilo Andrianasolo's Streamlit implementation: https://www.youtube.com/watch?v=tLeqCDKgEDU
"""
#!pip install -q -U gradio openai datasets | |
import gradio as gr | |
import random | |
import openai | |
import os | |
import json | |
import plotly.graph_objects as go | |
import time | |
#os.environ["OPENAI_API_KEY"] = "..." # Replace with your key | |
####################################### | |
# TOOLS SETUP | |
####################################### | |
def update_map_state(latitude, longitude, zoom):
    """Tool handler for ``update_map``: record a new map centre and zoom level.

    The values are stored under ``session_state[map_state]``; the map figure is
    rebuilt from that entry when the chat handler refreshes the UI.

    Returns:
        The confirmation text handed back to the assistant as the tool output.
    """
    new_view = dict(latitude=latitude, longitude=longitude, zoom=zoom)
    session_state[map_state] = new_view
    print(new_view)  # debug trace of the view requested by the assistant
    return "Map updated"
def add_markers_state(latitudes, longitudes, labels):
    """Tool handler for ``add_markers``: record the markers to plot on the map.

    The three sequences are stored side by side under
    ``session_state[markers_state]`` (keys ``lat``, ``lon`` and ``text``),
    replacing whatever markers were stored before rather than appending.
    NOTE(review): the sequences are presumably index-aligned (one entry per
    marker); nothing here checks that their lengths match.

    Returns:
        The confirmation text handed back to the assistant as the tool output.
    """
    markers = dict(lat=latitudes, lon=longitudes, text=labels)
    session_state[markers_state] = markers
    return "Markers added"
# Dispatch table: the function name the assistant puts in a tool call -> the
# local Python implementation that services it.
tool_to_function = dict(
    update_map=update_map_state,
    add_markers=add_markers_state,
)
## Helpers | |
def get_assistant_id():
    """Return the id of the assistant object cached in ``session_state``."""
    assistant = session_state[assistant_state]
    return assistant.id
def get_thread_id():
    """Return the id of the conversation thread cached in ``session_state``."""
    thread = session_state[thread_state]
    return thread.id
def get_run_id():
    """Return the id of the most recently submitted run.

    Raises ``AttributeError`` while no run has been submitted yet, because the
    session entry still holds ``None`` at that point.
    """
    last_run = session_state[last_openai_run_state]
    return last_run.id
def submit_message(assistant_id, thread_id, user_message):
    """Append *user_message* to the thread and start an assistant run on it.

    Args:
        assistant_id: id of the assistant that should answer.
        thread_id: id of the thread the message is posted to.
        user_message: text of the user's message.

    Returns:
        The run object returned by ``runs.create``; the caller polls it until
        the run finishes.
    """
    threads_api = client.beta.threads
    threads_api.messages.create(
        thread_id=thread_id, role="user", content=user_message
    )
    new_run = threads_api.runs.create(
        assistant_id=assistant_id,
        thread_id=thread_id,
    )
    return new_run
def get_run_info(run_id, thread_id):
    """Fetch the current state of run *run_id* belonging to thread *thread_id*."""
    runs_api = client.beta.threads.runs
    return runs_api.retrieve(run_id=run_id, thread_id=thread_id)
####################################### | |
# SESSION SETUP | |
####################################### | |
# NOTE(review): all state below lives in one module-level dict, so every
# visitor of the app shares a single assistant thread, conversation and map.
# Per-user isolation would need gr.State — TODO confirm this is acceptable.
client = openai.OpenAI()
# Overridable so the app can point at another assistant without code edits;
# the fallback is the id this demo originally shipped with.
assistant_id = os.environ.get(
    "OPENAI_ASSISTANT_ID", "asst_7OC3NTeyCjEZrApdLRklplE7"
)

session_state = {}
# Keys used to index session_state.
assistant_state = "assistant"
thread_state = "thread"
conversation_state = "conversation"
last_openai_run_state = "last_openai_run"
map_state = "map"
markers_state = "markers"

# The assistant and its thread are (re)created together whenever either is missing.
if assistant_state not in session_state or thread_state not in session_state:
    session_state[assistant_state] = client.beta.assistants.retrieve(assistant_id)
    session_state[thread_state] = client.beta.threads.create()

# The remaining entries only need defaults; setdefault keeps existing values.
session_state.setdefault(conversation_state, [])
session_state.setdefault(last_openai_run_state, None)
# Initial view centred at roughly Paris (48.85 N, 2.35 E).
session_state.setdefault(
    map_state,
    {
        "latitude": 48.85,
        "longitude": 2.35,
        "zoom": 12,
    },
)
session_state.setdefault(
    markers_state,
    {
        "lat": [],
        "lon": [],
        "text": [],
    },
)

# Marker-less map shown by the Plot component before the first reply.
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
    mapbox_style="open-street-map",
    hovermode='closest',
    mapbox=dict(
        center=go.layout.mapbox.Center(
            lat=session_state[map_state]["latitude"],
            lon=session_state[map_state]["longitude"],
        ),
        zoom=session_state[map_state]["zoom"],
    ),
)
# Run statuses from which a run can never reach "completed"; polling must stop
# on them, otherwise the wait loop would spin forever.
_TERMINAL_FAILURE_STATUSES = frozenset({"cancelled", "failed", "expired", "incomplete"})

# Seconds to wait between two polls of the run status.
_POLL_INTERVAL_S = 0.1


def _execute_tool_calls(tool_calls):
    """Execute the tool calls requested by the assistant.

    Args:
        tool_calls: the ``required_action.submit_tool_outputs.tool_calls`` list
            of a run in the ``requires_action`` status.

    Returns:
        A list of ``{"tool_call_id": ..., "output": ...}`` dicts, one per call,
        ready for ``runs.submit_tool_outputs``. Unknown tools and malformed
        arguments yield an error string as output instead of raising, so the
        assistant is told what went wrong and the run can continue.
    """
    tools_output = []
    for tool_call in tool_calls:
        f_name = tool_call.function.name
        function = tool_to_function.get(f_name)
        if function is None:
            output = f"Unknown tool: {f_name}"
        else:
            try:
                f_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError as err:
                output = f"Invalid JSON arguments for {f_name}: {err}"
            else:
                print(f"Launching function {f_name} with args {f_args}")
                try:
                    output = function(**f_args)
                except TypeError as err:
                    # Most likely missing/unexpected keyword arguments, or a
                    # non-object JSON payload, sent by the model.
                    output = f"Invalid arguments for {f_name}: {err}"
        tools_output.append({"tool_call_id": tool_call.id, "output": output})
    return tools_output


def _wait_for_run():
    """Poll the last submitted run until it completes, servicing tool calls.

    Returns:
        The completed run object.

    Raises:
        gr.Error: if the run ends in a status from which it cannot complete.
    """
    thread_id = get_thread_id()
    run_id = get_run_id()
    while True:
        run = get_run_info(run_id, thread_id)
        if run.status == "completed":
            return run
        if run.status in _TERMINAL_FAILURE_STATUSES:
            last_error = getattr(run, "last_error", None)
            detail = f": {getattr(last_error, 'message', '')}" if last_error is not None else ""
            raise gr.Error(f"Assistant run {run.status}{detail}")
        if run.status == "requires_action":
            tools_output = _execute_tool_calls(
                run.required_action.submit_tool_outputs.tool_calls
            )
            print(f"Will submit {tools_output}")
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread_id,
                run_id=run_id,
                tool_outputs=tools_output,
            )
        time.sleep(_POLL_INTERVAL_S)


def _message_text(message):
    """Join the text parts of a thread message, skipping non-text parts (e.g. images)."""
    return "\n".join(
        part.text.value
        for part in message.content
        if getattr(part, "type", None) == "text"
    )


def _pair_dialog(dialog):
    """Group ``[role, text]`` entries into ``[user_text, assistant_text]`` rows for gr.Chatbot.

    Consecutive assistant messages are merged into one reply, so a turn in
    which the assistant posted several messages does not shift later rows.
    A user message still awaiting a reply gets ``None`` as its answer.
    """
    pairs = []
    for role, text in dialog:
        if role == "user":
            pairs.append([text, None])
        elif not pairs:
            pairs.append([None, text])
        elif pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            pairs[-1][1] += "\n\n" + text
    return pairs


def _build_map_figure():
    """Build the Plotly map from the view and markers held in ``session_state``."""
    markers = session_state.get(markers_state)
    if markers is None:
        trace = go.Scattermapbox()
    else:
        trace = go.Scattermapbox(
            customdata=markers["text"],
            lat=markers["lat"],
            lon=markers["lon"],
            mode='markers',
            marker=go.scattermapbox.Marker(size=18),
            hoverinfo="text",
            hovertemplate='<b>Name</b>: %{customdata}',
        )
    view = session_state[map_state]
    fig = go.Figure(trace)
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            center=go.layout.mapbox.Center(
                lat=view["latitude"],
                lon=view["longitude"],
            ),
            # Honour the zoom chosen by the update_map tool (was fixed at 12).
            zoom=view["zoom"],
        ),
    )
    return fig


def respond(message, chat_history):
    """Gradio handler: send *message* to the assistant and refresh chat and map.

    Args:
        message: the text typed by the user.
        chat_history: the Chatbot's current value. It is not used; the history
            is rebuilt from the assistant thread instead.

    Returns:
        ``("", rows, fig)``: an empty string to clear the textbox, the chat rows
        as ``[user_text, assistant_text]`` pairs, and the refreshed map figure.

    Raises:
        gr.Error: if the assistant run fails, is cancelled or expires.
    """
    run = submit_message(get_assistant_id(), get_thread_id(), message)
    session_state[last_openai_run_state] = run
    _wait_for_run()

    # Iterating the page object (instead of reading ``.data``) follows
    # pagination, so conversations longer than one page are not truncated.
    messages = client.beta.threads.messages.list(
        thread_id=get_thread_id(), order="asc", limit=100
    )
    session_state[conversation_state] = [
        [m.role, _message_text(m)] for m in messages
    ]
    chat_history = _pair_dialog(session_state[conversation_state])
    return "", chat_history, _build_map_figure()
def _reset_conversation():
    """Start a fresh assistant thread when the chat is cleared.

    ``respond`` rebuilds the chat from the thread, so without a new thread every
    cleared message would reappear on the next submission.
    """
    session_state[thread_state] = client.beta.threads.create()
    session_state[conversation_state] = []
    session_state[last_openai_run_state] = None


with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
    gr.Markdown("# OpenAI assistant Wanderlust")
    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot()
            # Named map_plot so the builtin ``map`` is not shadowed.
            map_plot = gr.Plot(fig)
        msg = gr.Textbox("Move the map to Brussels and add markers for five major attractions")
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.ClearButton([msg, chatbot])
        msg.submit(respond, [msg, chatbot], [msg, chatbot, map_plot])
        submit.click(respond, [msg, chatbot], [msg, chatbot, map_plot])
        # ClearButton only empties the widgets; also drop the server-side thread.
        clear.click(_reset_conversation)
    gr.Markdown(
"""
# Description
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API. [Github repository](https://github.com/Yannael/openai-assistant-wanderlust)
This space is inspired by the implementation of [Fanilo Andrianasolo using Streamlit](https://www.youtube.com/watch?v=tLeqCDKgEDU)
"""
    )

demo.launch(debug=True)