Joshua Sundance Bailey committed on
Commit
457889e
·
1 Parent(s): 8aab446

parameterize k

Browse files
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +14 -2
langchain-streamlit-demo/app.py CHANGED
@@ -124,12 +124,15 @@ MIN_CHUNK_OVERLAP = 0
124
  MAX_CHUNK_OVERLAP = 10000
125
  DEFAULT_CHUNK_OVERLAP = 0
126
 
 
 
127
 
128
  @st.cache_data
129
  def get_texts_and_retriever(
130
  uploaded_file_bytes: bytes,
131
  chunk_size: int = DEFAULT_CHUNK_SIZE,
132
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
 
133
  ) -> Tuple[List[Document], BaseRetriever]:
134
  with NamedTemporaryFile() as temp_file:
135
  temp_file.write(uploaded_file_bytes)
@@ -145,10 +148,10 @@ def get_texts_and_retriever(
145
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
146
 
147
  bm25_retriever = BM25Retriever.from_documents(texts)
148
- bm25_retriever.k = 4
149
 
150
  faiss_vectorstore = FAISS.from_documents(texts, embeddings)
151
- faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 4})
152
 
153
  ensemble_retriever = EnsembleRetriever(
154
  retrievers=[bm25_retriever, faiss_retriever],
@@ -200,6 +203,14 @@ with sidebar:
200
  help="Uploaded document will provide context for the chat.",
201
  )
202
 
 
 
 
 
 
 
 
 
203
  chunk_size = st.slider(
204
  label="chunk_size",
205
  help="Size of each chunk of text",
@@ -251,6 +262,7 @@ with sidebar:
251
  uploaded_file_bytes=uploaded_file.getvalue(),
252
  chunk_size=chunk_size,
253
  chunk_overlap=chunk_overlap,
 
254
  )
255
  else:
256
  st.error("Please enter a valid OpenAI API key.", icon="❌")
 
124
  MAX_CHUNK_OVERLAP = 10000
125
  DEFAULT_CHUNK_OVERLAP = 0
126
 
127
+ DEFAULT_RETRIEVER_K = 4
128
+
129
 
130
  @st.cache_data
131
  def get_texts_and_retriever(
132
  uploaded_file_bytes: bytes,
133
  chunk_size: int = DEFAULT_CHUNK_SIZE,
134
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
135
+ k: int = DEFAULT_RETRIEVER_K,
136
  ) -> Tuple[List[Document], BaseRetriever]:
137
  with NamedTemporaryFile() as temp_file:
138
  temp_file.write(uploaded_file_bytes)
 
148
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
149
 
150
  bm25_retriever = BM25Retriever.from_documents(texts)
151
+ bm25_retriever.k = k
152
 
153
  faiss_vectorstore = FAISS.from_documents(texts, embeddings)
154
+ faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
155
 
156
  ensemble_retriever = EnsembleRetriever(
157
  retrievers=[bm25_retriever, faiss_retriever],
 
203
  help="Uploaded document will provide context for the chat.",
204
  )
205
 
206
+ k = st.slider(
207
+ label="Number of Chunks",
208
+ help="How many document chunks will be used for context?",
209
+ value=DEFAULT_RETRIEVER_K,
210
+ min_value=1,
211
+ max_value=10,
212
+ )
213
+
214
  chunk_size = st.slider(
215
  label="chunk_size",
216
  help="Size of each chunk of text",
 
262
  uploaded_file_bytes=uploaded_file.getvalue(),
263
  chunk_size=chunk_size,
264
  chunk_overlap=chunk_overlap,
265
+ k=k,
266
  )
267
  else:
268
  st.error("Please enter a valid OpenAI API key.", icon="❌")