"""Chainlit app: chat with Wikipedia pages for Barbie/Oppenheimer via a
LlamaIndex OpenAI agent with metadata-filtered auto-retrieval."""

import chainlit as cl
from llama_index import ServiceContext
from llama_index.node_parser.simple import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.llms import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index import VectorStoreIndex
from llama_index.vector_stores import ChromaVectorStore
from llama_index.storage.storage_context import StorageContext
import chromadb
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index.agent import OpenAIAgent

# --- LLM / embedding / parsing configuration -------------------------------
embed_model = OpenAIEmbedding()
chunk_size = 1000
llm = OpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)

service_context = ServiceContext.from_defaults(
    llm=llm, chunk_size=chunk_size, embed_model=embed_model
)
text_splitter = TokenTextSplitter(chunk_size=chunk_size)
node_parser = SimpleNodeParser(text_splitter=text_splitter)

# --- Vector store: in-memory Chroma collection -----------------------------
chroma_client = chromadb.Client()
chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
wiki_vector_index = VectorStoreIndex(
    [], storage_context=storage_context, service_context=service_context
)

movie_list = ["Barbie (film)", "Oppenheimer (film)"]
# auto_suggest=False keeps Wikipedia from redirecting these exact titles.
wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)

top_k = 3
vector_store_info = VectorStoreInfo(
    content_info="semantic information about movies",
    metadata_info=[
        MetadataInfo(
            name="title",
            type="str",
            description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
        )
    ],
)


class AutoRetrieveModel(BaseModel):
    """Argument schema the LLM fills in when it calls the auto-retrieve tool."""

    query: str = Field(..., description="natural language query string")
    filter_key_list: List[str] = Field(
        ..., description="List of metadata filter field names"
    )
    filter_value_list: List[str] = Field(
        ...,
        description=(
            "List of metadata filter field values (corresponding to names specified in filter_key_list)"
        ),
    )


def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
) -> str:
    """Auto retrieval function.

    Performs auto-retrieval from a vector database, and then applies a set of filters.
    """
    query = query or "Query"
    exact_match_filters = [
        ExactMatchFilter(key=k, value=v)
        for k, v in zip(filter_key_list, filter_value_list)
    ]
    retriever = VectorIndexRetriever(
        wiki_vector_index,
        filters=MetadataFilters(filters=exact_match_filters),
        # BUG FIX: the keyword is `similarity_top_k`; the original passed
        # `top_k=`, which VectorIndexRetriever does not accept.
        similarity_top_k=top_k,
    )
    query_engine = RetrieverQueryEngine.from_args(retriever)
    response = query_engine.query(query)
    return str(response)


description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""

auto_retrieve_tool = FunctionTool.from_defaults(
    fn=auto_retrieve_fn,
    name="auto_retrieve_tool",
    description=description,
    fn_schema=AutoRetrieveModel,
)

agent = OpenAIAgent.from_tools([auto_retrieve_tool], llm=llm, verbose=True)


@cl.author_rename
def rename(orig_author: str) -> str:
    """Map internal author names to friendlier labels in the Chainlit UI."""
    rename_dict = {"RetrievalQA": "Consulting The Llamaindex Tools"}
    return rename_dict.get(orig_author, orig_author)


@cl.on_chat_start
async def init():
    """Build the vector index on chat start and stash the agent in the session.

    Each node is tagged with its movie title so the agent's metadata
    filters (`title` exact-match) have something to match against.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    for movie, wiki_doc in zip(movie_list, wiki_docs):
        nodes = node_parser.get_nodes_from_documents([wiki_doc])
        for node in nodes:
            node.metadata = {"title": movie}
        wiki_vector_index.insert_nodes(nodes)

    # BUG FIX: the original constructed a LangChain RetrievalQA chain from
    # names that are never defined in this file (RetrievalQA, ChatOpenAI,
    # docsearch, prompt), raising NameError on every chat start. The
    # LlamaIndex OpenAIAgent built above is what this app actually uses,
    # so store it for the message handler instead.
    msg.content = "Index built!"
    await msg.send()
    cl.user_session.set("agent", agent)


@cl.on_message
async def main(message):
    """Forward each user message to the agent and send back its answer."""
    session_agent = cl.user_session.get("agent")
    # BUG FIX: the original awaited `chain.acall(...)` on the nonexistent
    # LangChain chain and post-processed `source_documents` that the
    # LlamaIndex agent never produces. `agent.chat` is synchronous, so run
    # it off the event loop via cl.make_async to keep the UI responsive.
    response = await cl.make_async(session_agent.chat)(message)
    await cl.Message(content=str(response)).send()