# zenn/app.py — Gradio chat UI backed by an Azure OpenAI assistant with file search.
import gradio as gr
from openai import AzureOpenAI
import os
from dotenv import load_dotenv
import time
def load_environment():
    """Populate os.environ from a local .env file, overriding existing values."""
    load_dotenv(override=True)
def initialize_openai_client():
    """Build the Azure OpenAI client from environment configuration.

    Reads AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY from the
    environment and pins the API version used throughout this app.
    """
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    return AzureOpenAI(
        azure_endpoint=endpoint,
        api_key=api_key,
        api_version="2024-10-01-preview",
    )
def create_assistant(client, vector_store_id):
    """Create a file-search assistant bound to *vector_store_id*.

    The assistant is configured with deterministic output (temperature 0)
    and an instruction string telling it to answer in Japanese unless
    directed otherwise.
    """
    file_search_tool = {
        "type": "file_search",
        "file_search": {"ranking_options": {"ranker": "default_2024_08_21", "score_threshold": 0}},
    }
    return client.beta.assistants.create(
        model="gpt-4o",
        instructions="ζŒ‡η€ΊγŒγͺγ„ι™γ‚Šγ€ζ—₯本θͺžγ§ε›žη­”してください。",
        tools=[file_search_tool],
        tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
        temperature=0,
    )
def create_thread(client):
    """Start a fresh conversation thread on the Assistants API."""
    return client.beta.threads.create()
def clear_thread(_):
    """Reset the conversation: replace the global thread with a new one.

    The single (ignored) parameter exists so this can be used directly as a
    Gradio event handler. Returns an empty history and an empty textbox
    value for the UI.
    """
    global thread
    thread = create_thread(client)
    empty_history, empty_textbox = [], ""
    return empty_history, empty_textbox
def get_annotations(msg):
    """Return the unique cited file ids from a message's first content part.

    For each citation, the cited file's metadata is retrieved and its
    content is pre-fetched best-effort — both purely for debug logging.
    The return value is the de-duplicated list of file ids in first-seen
    order.

    Relies on the module-level ``client`` global.
    """
    annotations = msg.content[0].text.annotations
    file_ids = []
    seen = set()  # O(1) membership instead of scanning the list each time
    for annotation in annotations:
        file_id = annotation.file_citation.file_id
        if file_id in seen:
            continue
        print("file_id", file_id)
        cited_file = client.files.retrieve(file_id)
        print("filename", cited_file.filename)
        try:
            # Best-effort fetch; the result is only useful as a warm-up /
            # debug probe, so failures are logged and ignored.
            client.files.content(file_id)
        except Exception as e:
            print(e)
        seen.add(file_id)
        file_ids.append(file_id)
    return file_ids
def get_chatbot_response(client, thread_id, assistant_id, message):
    """Post *message* on the thread, run the assistant, and return its reply text.

    Polls the run once per second until it leaves an active state. On
    completion, returns the newest assistant message's text; in every
    other case (failure, expiry, unhandled tool calls) a fallback string
    is returned.
    """
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message,  # plain string content is accepted by the API
    )
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
    )
    # Simple polling loop; exits when the run reaches a terminal state
    # (completed / failed / expired / cancelled) or requires_action.
    while run.status in ["queued", "in_progress", "cancelling"]:
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )
    if run.status == "completed":
        # messages.list returns newest-first; take the latest *assistant*
        # message rather than blindly returning the first entry, and guard
        # against messages with no content parts.
        messages = client.beta.threads.messages.list(thread_id=thread_id)
        for msg in messages:
            if msg.role == "assistant" and msg.content:
                return msg.content[0].text.value
    elif run.status == "requires_action":
        # Tool-call handling is not implemented; fall through to fallback.
        pass
    return "Unable to retrieve a response."  # Fallback response
def chatbot_response(history, message):
    """Append a user turn and the assistant's reply to *history*.

    Uses the module-level ``client``, ``thread`` and ``assistant``.
    Returns the updated history plus an empty string so the caller can
    clear its input box.
    """
    global thread
    reply = get_chatbot_response(client, thread.id, assistant.id, message)
    history += [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
    return history, ""
# --- Module-level setup (runs at import time) ---------------------------
# Load .env values into the environment before anything reads them.
load_environment()
# Shared Azure OpenAI client used by every handler below.
client = initialize_openai_client()
# Vector store backing the assistant's file_search tool.
vector_store_id = os.getenv("AZURE_OPENAI_VECTOR_STORE_ID")
# One assistant and one conversation thread shared by all users of this
# process; clear_thread() swaps in a fresh thread.
assistant = create_assistant(client, vector_store_id)
thread = create_thread(client)
with gr.Blocks() as demo:
    # Chat UI: message-dict history, a text input, and a clear button.
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Send the user's message to the assistant and append both turns."""
        bot_message = get_chatbot_response(client, thread.id, assistant.id, message)
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": bot_message})
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    # BUG FIX: the original call registered no outputs, so clear_thread's
    # return value ([], "") was discarded. Wiring [chatbot, msg] as outputs
    # resets the visible history and textbox along with the server thread.
    clear.click(clear_thread, [chatbot], [chatbot, msg])

if __name__ == "__main__":
    demo.launch()