Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
"""
# OpenAI assistant Wanderlust

Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API.

This space is inspired by the implementation of Fanilo Andrianasolo using Streamlit - https://www.youtube.com/watch?v=tLeqCDKgEDU
"""
| #!pip install -q -U gradio openai datasets | |
| import gradio as gr | |
| import random | |
| import openai | |
| import os | |
| import json | |
| import plotly.graph_objects as go | |
| import time | |
# The OpenAI client picks up its credentials from the OPENAI_API_KEY
# environment variable; for a local run you can set it first, e.g.
#   os.environ["OPENAI_API_KEY"] = "sk-..."  # Replace with your key

# The assistant to talk to. Defaults to the demo assistant, but can be
# overridden without editing the code through OPENAI_ASSISTANT_ID.
assistant_id = os.environ.get(
    "OPENAI_ASSISTANT_ID", "asst_7OC3NTeyCjEZrApdLRklplE7"
)

client = openai.OpenAI()
# Fetched once at start-up; respond() uses assistant.id when creating runs.
assistant = client.beta.assistants.retrieve(assistant_id)
| ####################################### | |
| # TOOLS SETUP | |
| ####################################### | |
def update_map_state(map_metadata_state, latitude, longitude, zoom):
    """Tool callback: store a new map centre and zoom level.

    The three values are written into ``map_metadata_state`` in place under
    the keys ``"latitude"``, ``"longitude"`` and ``"zoom"``. The updated
    state is printed, and the returned string is the tool output sent back
    to the assistant.
    """
    map_metadata_state.update(
        latitude=latitude,
        longitude=longitude,
        zoom=zoom,
    )
    print(map_metadata_state)
    return "Map updated"
def add_markers_state(map_metadata_state, latitudes, longitudes, labels):
    """Tool callback: replace the markers shown on the map.

    ``latitudes``, ``longitudes`` and ``labels`` are stored as-is in
    ``map_metadata_state`` (mutated in place) under the keys ``"lat"``,
    ``"lon"`` and ``"text"``. The updated state is printed, and the returned
    string is the tool output sent back to the assistant.
    """
    replacements = (
        ("lat", latitudes),
        ("lon", longitudes),
        ("text", labels),
    )
    for key, values in replacements:
        map_metadata_state[key] = values
    print(map_metadata_state)
    return "Markers added"
# Dispatch table: tool name requested by the assistant -> local callback.
# Each callback receives the map state first, then the tool's arguments.
tool_to_function = dict(
    update_map=update_map_state,
    add_markers=add_markers_state,
)
| ## Helpers | |
def submit_message(assistant_id, thread_id, user_message):
    """Post a user message to a thread and start a run on it.

    ``user_message`` is appended to the thread ``thread_id`` with the
    ``"user"`` role, then a run is created for ``assistant_id`` on the same
    thread. The newly created run object is returned.
    """
    threads = client.beta.threads
    threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=user_message,
    )
    return threads.runs.create(thread_id=thread_id, assistant_id=assistant_id)
def get_run_info(run_id, thread_id):
    """Retrieve and return the current run object for ``run_id`` on ``thread_id``."""
    return client.beta.threads.runs.retrieve(run_id=run_id, thread_id=thread_id)
| ####################################### | |
| # INITIAL DATA FOR MAP | |
| ####################################### | |
# Initial map state: the view ("latitude", "longitude", "zoom") plus the
# marker columns ("lat", "lon", "text"), which start out empty.
map_metadata = dict(
    latitude=48.85,
    longitude=2.35,
    zoom=12,
    lat=[],
    lon=[],
    text=[],
)

# Marker-less figure shown before the first request, centred on the
# initial view above.
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
    mapbox_style="open-street-map",
    hovermode="closest",
    mapbox={
        "center": {
            "lat": map_metadata["latitude"],
            "lon": map_metadata["longitude"],
        },
        "zoom": map_metadata["zoom"],
    },
)
# Run statuses after which polling can never reach "completed".
_FAILED_RUN_STATUSES = ("failed", "cancelled", "expired")


def _execute_tool_call(tool_call, map_metadata_state):
    """Run one tool requested by the assistant and return its text output.

    Problems are reported back to the assistant as the tool output instead
    of raising, so one bad call does not abort the whole request.
    """
    name = tool_call.function.name
    handler = tool_to_function.get(name)
    if handler is None:
        return f"Unknown tool: {name}"
    try:
        arguments = json.loads(tool_call.function.arguments)
        return handler(map_metadata_state, **arguments)
    except (json.JSONDecodeError, TypeError) as err:
        # Malformed JSON, or arguments that do not match the callback.
        return f"Invalid arguments for {name}: {err}"


def _message_text(message):
    """Join the text parts of a thread message, skipping non-text parts."""
    texts = []
    for part in message.content:
        text = getattr(part, "text", None)
        if text is not None:
            texts.append(text.value)
    return "\n".join(texts)


def _format_dialog(messages):
    """Turn chronologically ordered messages into [user, assistant] pairs.

    Pairing follows the message roles rather than list positions, so several
    consecutive assistant messages are merged into the same reply, and a
    user message without a reply yet is kept (with ``None`` as the reply).
    """
    pairs = []
    for message in messages:
        text = _message_text(message)
        if message.role == "user":
            pairs.append([text, None])
        elif not pairs:
            # Assistant spoke first: no user text to pair it with.
            pairs.append([None, text])
        elif pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            pairs[-1][1] = f"{pairs[-1][1]}\n\n{text}"
    return pairs


def _build_map_figure(map_metadata_state):
    """Build the map figure (view and markers) described by the state."""
    if map_metadata_state["lat"]:
        trace = go.Scattermapbox(
            customdata=map_metadata_state["text"],
            lat=map_metadata_state["lat"],
            lon=map_metadata_state["lon"],
            mode="markers",
            marker=go.scattermapbox.Marker(size=18),
            hoverinfo="text",
            hovertemplate="<b>Name</b>: %{customdata}",
        )
    else:
        trace = go.Scattermapbox()
    figure = go.Figure(trace)
    figure.update_layout(
        mapbox_style="open-street-map",
        hovermode="closest",
        mapbox=dict(
            center=go.layout.mapbox.Center(
                lat=map_metadata_state["latitude"],
                lon=map_metadata_state["longitude"],
            ),
            # Use the zoom stored by the update_map tool, not a fixed value.
            zoom=map_metadata_state["zoom"],
        ),
    )
    return figure


def respond(message, chat_history, thread, map_metadata_state):
    """Send ``message`` to the assistant and refresh the chat and the map.

    Args:
        message: Text typed by the user.
        chat_history: Current chatbot content. It is not read; the history
            is rebuilt from the thread's messages.
        thread: The conversation thread, or ``None`` to create a new one.
        map_metadata_state: Map state dict, mutated in place by the tools
            the assistant calls.

    Returns:
        A 5-tuple ``("", chat_history, figure, thread, map_metadata_state)``
        matching the Gradio outputs: emptied textbox, list of
        ``[user, assistant]`` pairs, the rebuilt map figure, the thread and
        the map state.

    Raises:
        gr.Error: If the run ends as failed, cancelled or expired.
    """
    if thread is None:
        thread = client.beta.threads.create()

    run = submit_message(assistant.id, thread.id, message)

    # Poll until the run completes, serving tool calls along the way.
    while True:
        run = get_run_info(run.id, thread.id)
        if run.status == "completed":
            break
        if run.status in _FAILED_RUN_STATUSES:
            # Without this check the loop would spin forever.
            last_error = getattr(run, "last_error", None)
            reason = f": {last_error.message}" if last_error is not None else ""
            raise gr.Error(f"Assistant run {run.status}{reason}")
        if run.status == "requires_action":
            tool_calls = run.required_action.submit_tool_outputs.tool_calls
            tool_outputs = [
                {
                    "tool_call_id": tool_call.id,
                    "output": _execute_tool_call(tool_call, map_metadata_state),
                }
                for tool_call in tool_calls
            ]
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=tool_outputs,
            )
        time.sleep(0.1)

    # Iterate the result itself rather than its `.data` attribute: per the
    # openai-python pagination docs this fetches further pages as needed,
    # whereas `.data` only holds the first page.
    messages = client.beta.threads.messages.list(thread.id, order="asc")
    chat_history = _format_dialog(messages)

    return (
        "",
        chat_history,
        _build_map_figure(map_metadata_state),
        thread,
        map_metadata_state,
    )
def _reset_session():
    """Return fresh values for the thread, the map state and the map plot.

    Wired to the clear button so that clearing the chat starts a new
    conversation. Otherwise the next reply would be rebuilt from the old
    thread and bring the cleared history back.
    """
    fresh_map_state = {**map_metadata, "lat": [], "lon": [], "text": []}
    return None, fresh_map_state, fig


# NOTE(review): the nesting below is reconstructed from statement order;
# confirm the layout against the running app.
with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
    gr.Markdown("# OpenAI assistant Wanderlust")

    # Per-session state: the conversation thread and the map view/markers.
    thread = gr.State()
    map_metadata_state = gr.State(map_metadata)

    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot()
            # Named map_plot so the builtin `map` is not shadowed.
            map_plot = gr.Plot(fig)
        msg = gr.Textbox("Move the map to Brussels and add markers on shops with the best waffles")

    with gr.Column():
        with gr.Row():
            submit = gr.Button("Submit")
            clear = gr.ClearButton([msg, chatbot])

    # Pressing Enter in the textbox and clicking Submit do the same thing.
    respond_inputs = [msg, chatbot, thread, map_metadata_state]
    respond_outputs = [msg, chatbot, map_plot, thread, map_metadata_state]
    msg.submit(respond, respond_inputs, respond_outputs)
    submit.click(respond, respond_inputs, respond_outputs)

    # Besides emptying the textbox and chat, also drop the thread and
    # restore the initial map.
    clear.click(_reset_session, None, [thread, map_metadata_state, map_plot])

    gr.Markdown(
        """
# Description
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API. [Github repository](https://github.com/Yannael/openai-assistant-wanderlust)
This space is inspired by the implementation of [Fanilo Andrianasolo using Streamlit](https://www.youtube.com/watch?v=tLeqCDKgEDU)
"""
    )
# Enable request queuing, then start the app in debug mode.
demo.queue().launch(debug=True)