Update app.py
app.py CHANGED
@@ -11,7 +11,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationalRetrievalChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from docarray import DocumentArray
+from docarray import Document, DocumentArray
 from sentence_transformers import SentenceTransformer
 
 # StreamHandler to intercept streaming output from the LLM.
@@ -26,6 +26,19 @@ class StreamHandler(BaseCallbackHandler):
         self.text += token
         self.container.markdown(self.text)
 
+from langchain.retrievers import BaseRetriever
+
+class SimpleEmbeddingRetriever(BaseRetriever):
+    def __init__(self, documents):
+        self.documents = documents
+
+    def _get_relevant_documents(self, query: str, num_documents: int = 5):
+        query_doc = Document(text=query)
+        query_embedding = self.documents.embeddings.model.encode([query_doc.text])[0]
+        query_doc.embedding = query_embedding
+        scores = self.documents.match(query_doc, limit=num_documents, metric='cosine', use_scipy=True)
+        return [(doc.text, score) for doc, score in scores]
+
 
 @st.cache_data
 def get_page_urls(url):
@@ -54,11 +67,13 @@ def get_url_content(url):
         text = ' '.join([c.get_text().strip() for c in content if c.get_text().strip() != ''])
 
         # Create a single document with metadata
-
+        document = Document(text=text, tags={'url': url})
+        return DocumentArray([document])
     except Exception as e:
         st.error(f"Failed to process URL content: {e}")
         return DocumentArray()
 
+
 @st.cache_resource
 def get_retriever(urls):
     documents = DocumentArray()
@@ -72,7 +87,7 @@ def get_retriever(urls):
     for doc, emb in zip(documents, embeddings):
         doc.embedding = emb
 
-    return documents
+    return SimpleEmbeddingRetriever(documents)
 
 
 @st.cache_resource
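
For context, here is a minimal standalone sketch of the cosine-similarity lookup the new SimpleEmbeddingRetriever performs. This is an illustration only, not code from app.py: the model name ('all-MiniLM-L6-v2') and the sample texts are assumptions, and it scores embeddings with numpy directly rather than through docarray.

import numpy as np
from sentence_transformers import SentenceTransformer

# Assumed model for illustration; app.py may load a different one.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Hypothetical sample corpus standing in for the scraped page texts.
texts = ["Streamlit renders markdown in the browser.",
         "LangChain wires retrievers into conversational chains."]
doc_embeddings = model.encode(texts)

def top_k(query, k=5):
    # Embed the query, score it against every document embedding with
    # cosine similarity, and return the k best-scoring (text, score) pairs.
    q = model.encode([query])[0]
    scores = doc_embeddings @ q / (
        np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(q))
    order = np.argsort(scores)[::-1][:k]
    return [(texts[i], float(scores[i])) for i in order]

print(top_k("How is markdown rendered?", k=1))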