"""Gradio chat UI backed by the Azure OpenAI Assistants API (file-search RAG)."""

import os
import time

import gradio as gr
from dotenv import load_dotenv
from openai import AzureOpenAI


def load_environment():
    """Load environment variables from .env, overriding any already set."""
    load_dotenv(override=True)


def initialize_openai_client():
    """Build an Azure OpenAI client from environment configuration."""
    return AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-10-01-preview",
    )


def create_assistant(client, vector_store_id):
    """Create a file-search assistant bound to *vector_store_id*.

    The instruction string asks the model to answer in Japanese unless told
    otherwise; temperature 0 keeps answers deterministic.
    """
    return client.beta.assistants.create(
        model="gpt-4o",
        instructions="指示がない限り、日本語で回答してください。",
        tools=[{
            "type": "file_search",
            "file_search": {
                "ranking_options": {
                    "ranker": "default_2024_08_21",
                    "score_threshold": 0,
                }
            },
        }],
        tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
        temperature=0,
    )


def create_thread():
    """Create a new conversation thread (uses the module-level client)."""
    return client.beta.threads.create()


def clear_thread(state):
    """Reset the session: start a fresh thread and clear the chat UI.

    Returns (history, textbox_text, new_state). The new state must be
    returned and wired into the gr.State output — reassigning the local
    ``state`` alone does not persist across callbacks, which previously
    left the old thread (and its history) active after "Clear".
    """
    new_state = initialize_session()
    return [], "", new_state


def get_annotations(msg):
    """Collect the distinct cited file ids from a message's annotations.

    Best-effort: any failure to retrieve a cited file is logged and
    skipped so a missing file never breaks the chat response.
    """
    annotations = msg.content[0].text.annotations
    file_ids = []
    for annotation in annotations:
        file_id = annotation.file_citation.file_id
        if file_id in file_ids:
            # De-duplicate repeated citations of the same file.
            continue
        print("file_id", file_id)
        try:
            cited_file = client.files.retrieve(file_id)
            print("filename", cited_file.filename)
            client.files.content(file_id)
        except Exception as e:
            # Deliberately non-fatal: log and keep going.
            print(e)
        file_ids.append(file_id)
    return file_ids


def get_chatbot_response(client, thread_id, assistant_id, message):
    """Send *message* on the thread, run the assistant, and poll until done.

    Returns the text of the first listed message on success, or a
    fallback string when the run fails or requires unsupported actions.
    """
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
    )
    # Poll once per second until the run leaves any transient state.
    while run.status in ("queued", "in_progress", "cancelling"):
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )
    if run.status == "completed":
        messages = client.beta.threads.messages.list(thread_id=thread_id)
        # The list is presumably newest-first (API default ordering), so
        # the first entry is the assistant's reply — TODO confirm.
        for msg in messages:
            return msg.content[0].text.value
    elif run.status == "requires_action":
        # Tool-call handling is not implemented; fall through to fallback.
        pass
    return "Unable to retrieve a response."  # Fallback response


def chatbot_response(history, message):
    """Legacy wrapper; the UI is wired to ``respond`` instead.

    NOTE(review): relies on a module-level ``thread`` that is never
    created anywhere in this file — calling this as-is raises NameError.
    Kept only for interface compatibility; confirm before use or remove.
    """
    global thread
    assistant_response = get_chatbot_response(client, thread.id, assistant.id, message)
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": assistant_response})
    return history, ""


# --- Module-level setup (runs at import time) ---
load_environment()
client = initialize_openai_client()
vector_store_id = os.getenv("AZURE_OPENAI_VECTOR_STORE_ID")
assistant = create_assistant(client, vector_store_id)


def respond(message, chat_history, state):
    """Append the user message and the assistant's reply to the history."""
    thread_id = state["thread_id"]
    bot_message = get_chatbot_response(client, thread_id, assistant.id, message)
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": bot_message})
    return "", chat_history


def initialize_session():
    """Create an independent thread for each browser session."""
    thread = create_thread()
    return {"thread_id": thread.id}


with gr.Blocks() as demo:
    gr.Markdown("""
    # Azure OpenAI Assistants API x Gradio x Zenn
    This is a Gradio demo of Retrieval-Augmented Generation (RAG) using the Azure OpenAI Assistants API, applied to [Zenn articles](https://zenn.dev/nakamura196).
    """)
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="ここにメッセージを入力してください...")
    state = gr.State(initialize_session)  # callable: evaluated per session
    clear = gr.Button("Clear")

    msg.submit(respond, [msg, chatbot, state], [msg, chatbot])
    # `state` is included in the outputs so the fresh thread created by
    # clear_thread actually replaces the old one for this session.
    clear.click(clear_thread, inputs=[state], outputs=[chatbot, msg, state])

if __name__ == "__main__":
    demo.launch()