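# Gradio chat UI for the Towards AI tutor bot: a LlamaIndex OpenAIAgent answers
# AI questions over a custom retriever, with logfire tracing and per-session
# chat-summary memory. Sources picked in the UI become metadata filters.
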
import gradio as gr
import logfire
from custom_retriever import CustomRetriever
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.llms import MessageRole
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.core.tools import RetrieverTool, ToolMetadata
from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from llama_index.llms.openai import OpenAI
from prompts import system_message_openai_agent
from setup import (
    AVAILABLE_SOURCES,
    AVAILABLE_SOURCES_UI,
    CONCURRENCY_COUNT,
    custom_retriever_all_sources,
)
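

# Build one RetrieverTool per selected source. Currently only the combined
# "All Sources" retriever is wired up; per-source retrievers would slot into
# source_mapping the same way.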
def update_query_engine_tools(selected_sources) -> list[RetrieverTool]:
    tools = []
    source_mapping: dict[str, tuple[CustomRetriever, str, str]] = {
        "All Sources": (
            custom_retriever_all_sources,
            "all_sources_info",
            """Useful tool that contains general information about the field of AI.""",
        ),
    }

    for source in selected_sources:
        if source in source_mapping:
            custom_retriever, name, description = source_mapping[source]
            tools.append(
                RetrieverTool(
                    retriever=custom_retriever,
                    metadata=ToolMetadata(
                        name=name,
                        description=description,
                    ),
                )
            )
    return tools
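

# Streaming chat handler for gr.ChatInterface. Yields the partial answer as
# tokens arrive, then yields once more with the formatted sources appended.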
def generate_completion(
    query,
    history,
    sources,
    model,
    memory,
):
    llm = OpenAI(temperature=1, model=model, max_tokens=None)
    client = llm._get_client()
    logfire.instrument_openai(client)

    with logfire.span(f"Running query: {query}"):
        logfire.info(f"User chosen sources: {sources}")
        memory_chat_list = memory.get()
        if len(memory_chat_list) != 0:
            user_index_memory = [
                i
                for i, msg in enumerate(memory_chat_list)
                if msg.role == MessageRole.USER
            ]
            user_index_history = [
                i for i, msg in enumerate(history) if msg["role"] == "user"
            ]
            if len(user_index_memory) > len(user_index_history):
                logfire.warn("There are more user messages in memory than in history")
                user_index_to_remove = user_index_memory[len(user_index_history)]
                memory_chat_list = memory_chat_list[:user_index_to_remove]
                memory.set(memory_chat_list)

        logfire.info(f"chat_history: {len(memory.get())} {memory.get()}")
        logfire.info(f"gradio_history: {len(history)} {history}")
        query_engine_tools: list[RetrieverTool] = update_query_engine_tools(
            ["All Sources"]
        )

        filter_list = []
        source_mapping = {
            "Transformers Docs": "transformers",
            "PEFT Docs": "peft",
            "TRL Docs": "trl",
            "LlamaIndex Docs": "llama_index",
            "LangChain Docs": "langchain",
            "OpenAI Cookbooks": "openai_cookbooks",
            "Towards AI Blog": "tai_blog",
            "8 Hour Primer": "8-hour_primer",
            "Advanced LLM Developer": "llm_developer",
            "Python Primer": "python_primer",
        }

        for source in sources:
            if source in source_mapping:
                filter_list.append(
                    MetadataFilter(
                        key="source",
                        operator=FilterOperator.EQ,
                        value=source_mapping[source],
                    )
                )

        filters = MetadataFilters(
            filters=filter_list,
            condition=FilterCondition.OR,
        )
        logfire.info(f"Filters: {filters}")
        query_engine_tools[0].retriever._vector_retriever._filters = filters
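
        # A fresh agent is built per request; the gr.State memory carries the
        # conversation across requests.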
        agent = OpenAIAgent.from_tools(
            llm=llm,
            memory=memory,
            tools=query_engine_tools,
            system_prompt=system_message_openai_agent,
        )

        completion = agent.stream_chat(query)
        answer_str = ""
        for token in completion.response_gen:
            answer_str += token
            yield answer_str

        for answer_str in add_sources(answer_str, completion):
            yield answer_str
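

# Append the formatted source list to the final streamed answer, yielding the
# unchanged answer when there is nothing to cite.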
def add_sources(answer_str, completion):
    if completion is None:
        yield answer_str
        return

    formatted_sources = format_sources(completion)
    if formatted_sources == "":
        yield answer_str
        return

    answer_str += "\n\n" + formatted_sources
    yield answer_str
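

# Render completion.sources (a list[ToolOutput], each wrapping a
# list[NodeWithScore]) as a markdown list of
# "[source: title](url), relevance: score" lines, mapping stored source keys
# back to their UI display names.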
def format_sources(completion) -> str:
    if len(completion.sources) == 0:
        return ""
    # logfire.info(f"Formatting sources: {completion.sources}")

    display_source_to_ui = {
        src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
    }

    documents_answer_template: str = (
        "📝 Here are the sources I used to answer your question:\n{documents}"
    )
    document_template: str = "[🔗 {source}: {title}]({url}), relevance: {score:2.2f}"

    all_documents = []
    for source in completion.sources:  # looping over list[ToolOutput]
        if isinstance(source.raw_output, Exception):
            logfire.error(f"Error in source output: {source.raw_output}")
            continue
        if not isinstance(source.raw_output, list):
            logfire.warn(f"Unexpected source output type: {type(source.raw_output)}")
            continue
        for src in source.raw_output:  # looping over list[NodeWithScore]
            document = document_template.format(
                title=src.metadata["title"],
                score=src.score,
                source=display_source_to_ui.get(
                    src.metadata["source"], src.metadata["source"]
                ),
                url=src.metadata["url"],
            )
            all_documents.append(document)

    if len(all_documents) == 0:
        return ""

    documents = "\n".join(all_documents)
    return documents_answer_template.format(documents=documents)
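

# Placeholder callbacks: completion logging and like/dislike votes are not
# persisted yet.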
def save_completion(completion, history):
    pass


def vote(data: gr.LikeData):
    pass
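

# UI controls shown in the "Customize Sources" accordion. The model dropdown
# is locked to gpt-4o-mini.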
accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
sources = gr.CheckboxGroup(
    AVAILABLE_SOURCES_UI,
    label="Sources",
    value=[
        "Advanced LLM Developer",
        "8 Hour Primer",
        "Python Primer",
        "Towards AI Blog",
        "Transformers Docs",
        "PEFT Docs",
        "TRL Docs",
        "LlamaIndex Docs",
        "LangChain Docs",
        "OpenAI Cookbooks",
    ],
    interactive=True,
)
model = gr.Dropdown(
    [
        "gpt-4o-mini",
    ],
    label="Model",
    value="gpt-4o-mini",
    interactive=False,
)
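
# Page layout: one ChatInterface backed by a per-session
# ChatSummaryMemoryBuffer (held in gr.State so each visitor gets an isolated
# memory).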
with gr.Blocks(
    title="Towards AI 🤖",
    analytics_enabled=True,
    fill_height=True,
) as demo:
    memory = gr.State(
        lambda: ChatSummaryMemoryBuffer.from_defaults(
            token_limit=120000,
        )
    )
    chatbot = gr.Chatbot(
        type="messages",
        scale=20,
        placeholder="<strong>Towards AI 🤖: A Question-Answering Bot for anything AI-related</strong><br>",
        show_label=False,
        show_copy_button=True,
    )
    chatbot.like(vote, None, None)
    gr.ChatInterface(
        fn=generate_completion,
        type="messages",
        chatbot=chatbot,
        additional_inputs=[sources, model, memory],
        additional_inputs_accordion=accordion,
        # fill_height=True,
        # fill_width=True,
        analytics_enabled=True,
    )
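
# Queue limits concurrent chat workers to CONCURRENCY_COUNT (from setup).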
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=CONCURRENCY_COUNT)
    demo.launch(debug=False, share=False)