Joshua Sundance Bailey committed on
Commit
b012c56
·
unverified ·
2 Parent(s): fbe3579 c132355

Merge pull request #102 from connorsutton/main

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -16,7 +16,12 @@ from streamlit_feedback import streamlit_feedback
16
 
17
  from defaults import default_values
18
 
19
- from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
 
 
 
 
 
20
 
21
  __version__ = "1.0.3"
22
 
@@ -132,7 +137,7 @@ def get_texts_and_retriever_cacheable_wrapper(
132
  azure_kwargs: Optional[Dict[str, str]] = None,
133
  use_azure: bool = False,
134
  ) -> Tuple[List[Document], BaseRetriever]:
135
- return get_texts_and_retriever(
136
  uploaded_file_bytes=uploaded_file_bytes,
137
  openai_api_key=openai_api_key,
138
  chunk_size=chunk_size,
 
16
 
17
  from defaults import default_values
18
 
19
+ from llm_resources import (
20
+ get_runnable,
21
+ get_llm,
22
+ get_texts_and_multiretriever,
23
+ StreamHandler,
24
+ )
25
 
26
  __version__ = "1.0.3"
27
 
 
137
  azure_kwargs: Optional[Dict[str, str]] = None,
138
  use_azure: bool = False,
139
  ) -> Tuple[List[Document], BaseRetriever]:
140
+ return get_texts_and_multiretriever(
141
  uploaded_file_bytes=uploaded_file_bytes,
142
  openai_api_key=openai_api_key,
143
  chunk_size=chunk_size,
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -11,11 +11,16 @@ from langchain.chat_models import (
11
  )
12
  from langchain.document_loaders import PyPDFLoader
13
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
14
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
15
  from langchain.schema import Document, BaseRetriever
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
18
 
 
 
 
 
 
19
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
20
  from qagen import get_rag_qa_gen_chain
21
  from summarize import get_rag_summarization_chain
@@ -111,7 +116,7 @@ def get_llm(
111
  return None
112
 
113
 
114
- def get_texts_and_retriever(
115
  uploaded_file_bytes: bytes,
116
  openai_api_key: str,
117
  chunk_size: int = DEFAULT_CHUNK_SIZE,
@@ -127,10 +132,23 @@ def get_texts_and_retriever(
127
  loader = PyPDFLoader(temp_file.name)
128
  documents = loader.load()
129
  text_splitter = RecursiveCharacterTextSplitter(
130
- chunk_size=chunk_size,
131
- chunk_overlap=chunk_overlap,
132
  )
 
 
133
  texts = text_splitter.split_documents(documents)
 
 
 
 
 
 
 
 
 
 
 
134
  embeddings_kwargs = {"openai_api_key": openai_api_key}
135
  if use_azure and azure_kwargs:
136
  azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
@@ -138,19 +156,35 @@ def get_texts_and_retriever(
138
  embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
139
  else:
140
  embeddings = OpenAIEmbeddings(**embeddings_kwargs)
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- bm25_retriever = BM25Retriever.from_documents(texts)
143
- bm25_retriever.k = k
144
-
145
- faiss_vectorstore = FAISS.from_documents(texts, embeddings)
146
- faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
 
 
 
 
 
 
147
 
148
  ensemble_retriever = EnsembleRetriever(
149
- retrievers=[bm25_retriever, faiss_retriever],
150
  weights=[0.5, 0.5],
151
  )
152
-
153
- return texts, ensemble_retriever
154
 
155
 
156
  class StreamHandler(BaseCallbackHandler):
 
11
  )
12
  from langchain.document_loaders import PyPDFLoader
13
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
14
+ from langchain.retrievers import EnsembleRetriever
15
  from langchain.schema import Document, BaseRetriever
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
18
 
19
+ from langchain.retrievers.multi_query import MultiQueryRetriever
20
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
21
+ from langchain.storage import InMemoryStore
22
+ import uuid
23
+
24
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
25
  from qagen import get_rag_qa_gen_chain
26
  from summarize import get_rag_summarization_chain
 
116
  return None
117
 
118
 
119
+ def get_texts_and_multiretriever(
120
  uploaded_file_bytes: bytes,
121
  openai_api_key: str,
122
  chunk_size: int = DEFAULT_CHUNK_SIZE,
 
132
  loader = PyPDFLoader(temp_file.name)
133
  documents = loader.load()
134
  text_splitter = RecursiveCharacterTextSplitter(
135
+ chunk_size=10000,
136
+ chunk_overlap=0,
137
  )
138
+ child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
139
+
140
  texts = text_splitter.split_documents(documents)
141
+ id_key = "doc_id"
142
+
143
+ text_ids = [str(uuid.uuid4()) for _ in texts]
144
+ sub_texts = []
145
+ for i, text in enumerate(texts):
146
+ _id = text_ids[i]
147
+ _sub_texts = child_text_splitter.split_documents([text])
148
+ for _text in _sub_texts:
149
+ _text.metadata[id_key] = _id
150
+ sub_texts.extend(_sub_texts)
151
+
152
  embeddings_kwargs = {"openai_api_key": openai_api_key}
153
  if use_azure and azure_kwargs:
154
  azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
 
156
  embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
157
  else:
158
  embeddings = OpenAIEmbeddings(**embeddings_kwargs)
159
+ store = InMemoryStore()
160
+
161
+ # MultiVectorRetriever
162
+ multivectorstore = FAISS.from_documents(sub_texts, embeddings)
163
+ multivector_retriever = MultiVectorRetriever(
164
+ vectorstore=multivectorstore,
165
+ docstore=store,
166
+ id_key=id_key,
167
+ )
168
+ multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
169
+ # multivector_retriever.k = k
170
 
171
+ multiquery_text_splitter = RecursiveCharacterTextSplitter(
172
+ chunk_size=chunk_size,
173
+ chunk_overlap=chunk_overlap,
174
+ )
175
+ # MultiQueryRetriever
176
+ multiquery_texts = multiquery_text_splitter.split_documents(documents)
177
+ multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
178
+ multiquery_retriever = MultiQueryRetriever.from_llm(
179
+ retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
180
+ llm=ChatOpenAI(),
181
+ )
182
 
183
  ensemble_retriever = EnsembleRetriever(
184
+ retrievers=[multiquery_retriever, multivector_retriever],
185
  weights=[0.5, 0.5],
186
  )
187
+ return multiquery_texts, ensemble_retriever
 
188
 
189
 
190
  class StreamHandler(BaseCallbackHandler):
requirements.txt CHANGED
@@ -1,13 +1,12 @@
1
  anthropic==0.7.7
2
  faiss-cpu==1.7.4
3
- langchain==0.0.345
4
  langsmith==0.0.69
5
  numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
6
- openai==1.3.7
7
  pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
8
  pyarrow>=14.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
9
- pypdf==3.17.1
10
- rank_bm25==0.2.2
11
  streamlit==1.29.0
12
  streamlit-feedback==0.1.3
13
  tiktoken==0.5.2
 
1
  anthropic==0.7.7
2
  faiss-cpu==1.7.4
3
+ langchain==0.0.348
4
  langsmith==0.0.69
5
  numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
6
+ openai==1.3.8
7
  pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
8
  pyarrow>=14.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
9
+ pypdf==3.17.2
 
10
  streamlit==1.29.0
11
  streamlit-feedback==0.1.3
12
  tiktoken==0.5.2