jseims committed on
Commit
b9a7468
·
1 Parent(s): 61088db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -167
app.py CHANGED
@@ -1,185 +1,64 @@
1
- import chainlit as cl
2
- from llama_index import ServiceContext
3
- from llama_index.node_parser.simple import SimpleNodeParser
4
- from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
5
- from llama_index.llms import OpenAI
6
- from llama_index.embeddings.openai import OpenAIEmbedding
7
- from llama_index import VectorStoreIndex
8
- from llama_index.vector_stores import ChromaVectorStore
9
- from llama_index.storage.storage_context import StorageContext
10
- import chromadb
11
- from llama_index.readers.wikipedia import WikipediaReader
12
- from llama_index.tools import FunctionTool
13
- from llama_index.vector_stores.types import (
14
- VectorStoreInfo,
15
- MetadataInfo,
16
- ExactMatchFilter,
17
- MetadataFilters,
18
- )
19
- from llama_index.retrievers import VectorIndexRetriever
20
- from llama_index.query_engine import RetrieverQueryEngine
21
-
22
- from typing import List, Tuple, Any
23
- from pydantic import BaseModel, Field
24
- from llama_index.agent import OpenAIAgent
25
-
26
- embed_model = OpenAIEmbedding()
27
- chunk_size = 1000
28
- llm = OpenAI(
29
- temperature=0,
30
- model="gpt-3.5-turbo",
31
- streaming=True
32
- )
33
-
34
- service_context = ServiceContext.from_defaults(
35
- llm=llm,
36
- chunk_size=chunk_size,
37
- embed_model=embed_model
38
- )
39
-
40
- text_splitter = TokenTextSplitter(
41
- chunk_size=chunk_size
42
- )
43
-
44
- node_parser = SimpleNodeParser(
45
- text_splitter=text_splitter
46
  )
 
 
47
 
48
- chroma_client = chromadb.Client()
49
- chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
50
-
51
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
52
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
53
- wiki_vector_index = VectorStoreIndex([], storage_context=storage_context, service_context=service_context)
54
-
55
- movie_list = ["Barbie (film)", "Oppenheimer (film)"]
56
 
57
- wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)
 
 
58
 
59
- top_k = 3
60
- vector_store_info = VectorStoreInfo(
61
- content_info="semantic information about movies",
62
- metadata_info=[MetadataInfo(
63
- name="title",
64
- type="str",
65
- description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
66
- )]
67
- )
68
 
69
- class AutoRetrieveModel(BaseModel):
70
- query: str = Field(..., description="natural language query string")
71
- filter_key_list: List[str] = Field(
72
- ..., description="List of metadata filter field names"
 
 
 
 
73
  )
74
- filter_value_list: List[str] = Field(
75
- ...,
76
- description=(
77
- "List of metadata filter field values (corresponding to names specified in filter_key_list)"
78
- )
79
  )
80
 
81
- def auto_retrieve_fn(
82
- query: str, filter_key_list: List[str], filter_value_list: List[str]
83
- ):
84
- """Auto retrieval function.
85
-
86
- Performs auto-retrieval from a vector database, and then applies a set of filters.
87
-
88
- """
89
- query = query or "Query"
90
-
91
- exact_match_filters = [
92
- ExactMatchFilter(key=k, value=v)
93
- for k, v in zip(filter_key_list, filter_value_list)
94
- ]
95
- retriever = VectorIndexRetriever(
96
- wiki_vector_index, filters=MetadataFilters(filters=exact_match_filters), top_k=top_k
97
  )
98
- query_engine = RetrieverQueryEngine.from_args(retriever)
99
 
100
- response = query_engine.query(query)
101
- return str(response)
102
-
103
- description = f"""\
104
- Use this tool to look up semantic information about films.
105
- The vector database schema is given below:
106
- {vector_store_info.json()}
107
- """
108
-
109
- auto_retrieve_tool = FunctionTool.from_defaults(
110
- fn=auto_retrieve_fn,
111
- name="auto_retrieve_tool",
112
- description=description,
113
- fn_schema=AutoRetrieveModel,
114
- )
115
-
116
-
117
- agent = OpenAIAgent.from_tools(
118
- [auto_retrieve_tool], llm=llm, verbose=True
119
- )
120
-
121
- @cl.author_rename
122
- def rename(orig_author: str):
123
- rename_dict = {"RetrievalQA": "Consulting The Llamaindex Tools"}
124
- return rename_dict.get(orig_author, orig_author)
125
-
126
- @cl.on_chat_start
127
- async def init():
128
- msg = cl.Message(content=f"Building Index...")
129
- await msg.send()
130
-
131
- for movie, wiki_doc in zip(movie_list, wiki_docs):
132
- nodes = node_parser.get_nodes_from_documents([wiki_doc])
133
- for node in nodes:
134
- node.metadata = {'title' : movie}
135
- wiki_vector_index.insert_nodes(nodes)
136
-
137
-
138
-
139
- chain = RetrievalQA.from_chain_type(
140
- ChatOpenAI(model="gpt-3.5-turbo", temperature=0, streaming=True),
141
- chain_type="stuff",
142
- return_source_documents=True,
143
- retriever=docsearch.as_retriever(),
144
- chain_type_kwargs = {"prompt": prompt}
145
- )
146
-
147
- msg.content = f"Index built!"
148
- await msg.send()
149
-
150
- cl.user_session.set("chain", chain)
151
 
152
 
153
  @cl.on_message
154
  async def main(message):
155
- chain = cl.user_session.get("chain")
156
- cb = cl.AsyncLangchainCallbackHandler(
157
- stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
158
- )
159
- cb.answer_reached = True
160
- res = await chain.acall(message, callbacks=[cb], )
161
-
162
- answer = res["result"]
163
- source_elements = []
164
- visited_sources = set()
165
 
166
- # Get the documents from the user session
167
- docs = res["source_documents"]
168
- metadatas = [doc.metadata for doc in docs]
169
- all_sources = [m["source"] for m in metadatas]
170
 
171
- for source in all_sources:
172
- if source in visited_sources:
173
- continue
174
- visited_sources.add(source)
175
- # Create the text element referenced in the message
176
- source_elements.append(
177
- cl.Text(content="https://www.imdb.com" + source, name="Review URL")
178
- )
179
 
180
- if source_elements:
181
- answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
182
- else:
183
- answer += "\nNo sources found"
184
 
185
- await cl.Message(content=answer, elements=source_elements).send()
 
1
+ import os
2
+ import openai
3
+
4
+ from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
5
+ from llama_index.callbacks.base import CallbackManager
6
+ from llama_index import (
7
+ LLMPredictor,
8
+ ServiceContext,
9
+ StorageContext,
10
+ load_index_from_storage,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  )
12
+ from langchain.chat_models import ChatOpenAI
13
+ import chainlit as cl
14
 
15
# Load a previously persisted vector index from ./storage; on the first run
# (or any load failure) build a fresh index from the documents in ./data and
# persist it so subsequent startups can skip the (paid) embedding step.
try:
    # Rebuild the storage context from the on-disk persistence directory.
    storage_context = StorageContext.from_defaults(persist_dir="./storage")
    index = load_index_from_storage(storage_context)
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed during startup; any real load failure still falls through to
    # the rebuild path below.
    # Deferred import: these names are only needed when no index exists yet.
    from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader

    documents = SimpleDirectoryReader("./data").load_data()
    index = GPTVectorStoreIndex.from_documents(documents)
    # Persist immediately so the next launch takes the fast path above.
    index.storage_context.persist()
26
 
 
 
 
 
 
 
 
 
 
27
 
28
@cl.on_chat_start
async def factory():
    """Initialize a streaming query engine for this chat session.

    Wraps a streaming gpt-3.5-turbo chat model in an LLMPredictor, builds a
    service context that routes LlamaIndex callbacks into the Chainlit UI,
    and stashes the resulting query engine in the user session for the
    message handler to use.
    """
    chat_model = ChatOpenAI(
        temperature=0,
        model_name="gpt-3.5-turbo",
        streaming=True,
    )
    predictor = LLMPredictor(llm=chat_model)

    # Surface LlamaIndex tracing events in the Chainlit interface.
    callbacks = CallbackManager([cl.LlamaIndexCallbackHandler()])
    context = ServiceContext.from_defaults(
        llm_predictor=predictor,
        chunk_size=512,
        callback_manager=callbacks,
    )

    engine = index.as_query_engine(
        service_context=context,
        streaming=True,
    )
    cl.user_session.set("query_engine", engine)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
@cl.on_message
async def main(message):
    """Answer an incoming chat message by querying the index, streaming tokens."""
    engine = cl.user_session.get("query_engine")  # type: RetrieverQueryEngine

    # query() blocks; run it off the event loop so the UI stays responsive.
    response = await cl.make_async(engine.query)(message)

    outgoing = cl.Message(content="")

    # Forward tokens to the UI as the LLM produces them.
    for token in response.response_gen:
        await outgoing.stream_token(token=token)

    # Once streaming finishes, prefer the engine's assembled final text
    # when it is available (non-empty).
    if response.response_txt:
        outgoing.content = response.response_txt

    await outgoing.send()