# Source: Yannael_LB, commit f39fdae ("Initial commit"), 7.14 kB
# -*- coding: utf-8 -*-
"""OpenAI_Assistant_WanderLust.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1DLYjk07uQwF_Oqav1NtqMMCz_S2f0q8B
# OpenAI assistant Wanderlust
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API.
This space is inspired by the implementation of Fanilo Andrianasolo using Streamlit - https://www.youtube.com/watch?v=tLeqCDKgEDU
"""
#!pip install -q -U gradio openai datasets
import gradio as gr
import random
import openai
import os
import json
import plotly.graph_objects as go
import time
#os.environ["OPENAI_API_KEY"] = "..." # Replace with your key
#######################################
# TOOLS SETUP
#######################################
def update_map_state(latitude, longitude, zoom):
    """Assistant tool ``update_map``: remember the requested map view.

    Stores the centre coordinates and zoom level in ``session_state`` under the
    ``map_state`` key (replacing the previous view), prints the new view, and
    returns the confirmation string that is sent back to the assistant as the
    tool output.
    """
    new_view = dict(latitude=latitude, longitude=longitude, zoom=zoom)
    session_state[map_state] = new_view
    print(session_state[map_state])
    return "Map updated"
def add_markers_state(latitudes, longitudes, labels):
    """Assistant tool ``add_markers``: remember the markers to draw on the map.

    The three sequences are stored as given (no validation) under the
    ``markers_state`` key, using the ``lat``/``lon``/``text`` keys that the map
    rendering reads. Any previously stored markers are replaced, not extended.
    Returns the confirmation string sent back to the assistant.
    """
    markers = dict(lat=latitudes, lon=longitudes, text=labels)
    session_state[markers_state] = markers
    return "Markers added"
# Dispatch table: function name found in an assistant tool call -> the local
# handler that implements it.
tool_to_function = dict(
    update_map=update_map_state,
    add_markers=add_markers_state,
)
## Helpers
def get_assistant_id():
    """Return the id of the assistant object cached in ``session_state``."""
    assistant = session_state[assistant_state]
    return assistant.id
def get_thread_id():
    """Return the id of the conversation thread cached in ``session_state``."""
    thread = session_state[thread_state]
    return thread.id
def get_run_id():
    """Return the id of the most recently submitted run.

    The slot starts out as ``None``, so calling this before any run has been
    submitted raises ``AttributeError``.
    """
    last_run = session_state[last_openai_run_state]
    return last_run.id
def submit_message(assistant_id, thread_id, user_message):
    """Append ``user_message`` to the thread and start an assistant run on it.

    Returns the newly created run object. The caller is expected to poll it
    until it finishes.
    """
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=user_message,
    )
    return client.beta.threads.runs.create(
        assistant_id=assistant_id,
        thread_id=thread_id,
    )
def get_run_info(run_id, thread_id):
    """Fetch a fresh snapshot of run ``run_id`` on thread ``thread_id``."""
    return client.beta.threads.runs.retrieve(run_id=run_id, thread_id=thread_id)
#######################################
# SESSION SETUP
#######################################
# Shared OpenAI client used by every helper in this file (reads the API key
# from the environment).
client = openai.OpenAI()
# Assistant to drive. It can be overridden with the OPENAI_ASSISTANT_ID
# environment variable so the app can target another assistant without code
# changes; it defaults to the original demo assistant.
assistant_id = os.environ.get("OPENAI_ASSISTANT_ID", "asst_7OC3NTeyCjEZrApdLRklplE7")

# Single state dict mirroring Streamlit's session_state.
# NOTE(review): this dict is module-global, so every browser session connected
# to this Gradio process shares the same thread, map view and markers.
session_state = {}

# Keys into session_state.
assistant_state = "assistant"
thread_state = "thread"
conversation_state = "conversation"
last_openai_run_state = "last_openai_run"
map_state = "map"
markers_state = "markers"

if assistant_state not in session_state or thread_state not in session_state:
    # Network calls: load the assistant definition and open a new thread.
    session_state[assistant_state] = client.beta.assistants.retrieve(assistant_id)
    session_state[thread_state] = client.beta.threads.create()
session_state.setdefault(conversation_state, [])
session_state.setdefault(last_openai_run_state, None)
# Initial view: central Paris (48.85 N, 2.35 E) at zoom 12.
session_state.setdefault(map_state, {"latitude": 48.85, "longitude": 2.35, "zoom": 12})
session_state.setdefault(markers_state, {"lat": [], "lon": [], "text": []})

# Empty map shown before the first assistant reply, using the initial view.
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
    mapbox_style="open-street-map",
    hovermode="closest",
    mapbox=dict(
        center=go.layout.mapbox.Center(
            lat=session_state[map_state]["latitude"],
            lon=session_state[map_state]["longitude"],
        ),
        zoom=session_state[map_state]["zoom"],
    ),
)
# Run statuses after which the run can never reach "completed".
_FAILED_RUN_STATUSES = frozenset({"failed", "cancelled", "expired", "incomplete"})


def _run_tool_calls(tool_calls):
    """Execute the assistant's requested tool calls and collect their outputs.

    Each call's JSON ``arguments`` are decoded and passed as keyword arguments
    to the matching handler in ``tool_to_function``. An unknown tool name is
    reported back to the assistant as the output rather than raising, so the
    run is not left waiting for outputs that never arrive.
    """
    tools_output = []
    for tool_call in tool_calls:
        f_name = tool_call.function.name
        f_args = json.loads(tool_call.function.arguments)
        print(f"Launching function {f_name} with args {f_args}")
        handler = tool_to_function.get(f_name)
        if handler is None:
            tool_result = f"Unknown tool: {f_name}"
        else:
            tool_result = handler(**f_args)
        tools_output.append({"tool_call_id": tool_call.id, "output": tool_result})
    return tools_output


def _wait_for_run():
    """Poll the last submitted run until it completes, serving its tool calls.

    Returns the completed run. Raises ``gr.Error`` (displayed by Gradio) when
    the run ends in a failure status, so the UI does not hang forever.
    """
    while True:
        run = get_run_info(get_run_id(), get_thread_id())
        if run.status == "completed":
            return run
        if run.status in _FAILED_RUN_STATUSES:
            raise gr.Error(f"Assistant run ended with status '{run.status}'")
        if run.status == "requires_action":
            tools_output = _run_tool_calls(
                run.required_action.submit_tool_outputs.tool_calls
            )
            print(f"Will submit {tools_output}")
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=get_thread_id(),
                run_id=get_run_id(),
                tool_outputs=tools_output,
            )
        time.sleep(0.1)


def _message_text(message):
    """Return the concatenated text parts of a thread message; other parts are skipped."""
    return "".join(
        part.text.value for part in message.content if part.type == "text"
    )


def _pair_dialog(dialog):
    """Group ``[role, text]`` messages into Gradio ``[user, assistant]`` pairs.

    A user message opens a new pair. An assistant message fills the reply of
    the latest pair, or is appended to it when the assistant replied more than
    once. A user message that has no reply yet keeps ``None`` as its reply.
    """
    pairs = []
    for role, text in dialog:
        if role == "user":
            pairs.append([text, None])
        elif not pairs:
            pairs.append([None, text])
        elif pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            pairs[-1][1] += "\n\n" + text
    return pairs


def _build_map_figure():
    """Render the stored markers on a map using the stored centre and zoom."""
    markers = session_state[markers_state]
    if markers is None:
        fig = go.Figure(go.Scattermapbox())
    else:
        fig = go.Figure(go.Scattermapbox(
            customdata=markers["text"],
            lat=markers["lat"],
            lon=markers["lon"],
            mode="markers",
            marker=go.scattermapbox.Marker(size=18),
            hoverinfo="text",
            hovertemplate="<b>Name</b>: %{customdata}",
        ))
    view = session_state[map_state]
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode="closest",
        mapbox=dict(
            center=go.layout.mapbox.Center(
                lat=view["latitude"],
                lon=view["longitude"],
            ),
            # Use the zoom stored by the update_map tool, not a fixed level.
            zoom=view["zoom"],
        ),
    )
    return fig


def respond(message, chat_history):
    """Gradio handler: send ``message`` to the assistant and refresh the UI.

    Posts the message and waits for the run to finish, executing any map tool
    calls it requests. It then reloads the whole thread and returns a triple
    for the textbox, chatbot and map outputs: an empty string (clears the
    textbox), the chat as ``[user, assistant]`` pairs, and the map figure.

    Raises:
        gr.Error: if the assistant run fails, is cancelled, expires, or ends
            incomplete.
    """
    print(chat_history)
    if not message or not message.strip():
        # Nothing to send; leave the conversation as it is.
        return "", chat_history, _build_map_figure()

    run = submit_message(get_assistant_id(), get_thread_id(), message)
    session_state[last_openai_run_state] = run
    print(run)
    _wait_for_run()

    # Iterating the page (rather than reading `.data`) auto-paginates, so
    # history beyond the first page is not silently dropped.
    session_state[conversation_state] = [
        [m.role, _message_text(m)]
        for m in client.beta.threads.messages.list(
            get_thread_id(), order="asc", limit=100
        )
    ]
    chat_history = _pair_dialog(session_state[conversation_state])
    return "", chat_history, _build_map_figure()
def _reset_conversation():
    """Start a fresh Assistants thread when the chat is cleared.

    ``respond`` rebuilds the chat from the whole server-side thread, so
    without a new thread the cleared messages would reappear on the next reply.
    """
    session_state[thread_state] = client.beta.threads.create()
    session_state[conversation_state] = []
    session_state[last_openai_run_state] = None


with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
    gr.Markdown("# OpenAI assistant Wanderlust")
    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot()
            # Named map_plot so the built-in `map` is not shadowed.
            map_plot = gr.Plot(fig)
        msg = gr.Textbox("Move the map to Brussels and add markers for five major attractions")
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.ClearButton([msg, chatbot])
    # Enter in the textbox and the Submit button trigger the same handler.
    msg.submit(respond, [msg, chatbot], [msg, chatbot, map_plot])
    submit.click(respond, [msg, chatbot], [msg, chatbot, map_plot])
    # In addition to clearing the widgets, drop the server-side history.
    clear.click(_reset_conversation)
    gr.Markdown(
"""
# Description
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API. [Github repository](https://github.com/Yannael/openai-assistant-wanderlust)
This space is inspired by the implementation of [Fanilo Andrianasolo using Streamlit](https://www.youtube.com/watch?v=tLeqCDKgEDU)
"""
    )
demo.launch(debug=True)