Mattral committed on
Commit 2b45954 · verified · 1 Parent(s): 7432727

Update app.py

Files changed (1): app.py +18 -3
app.py CHANGED
@@ -11,7 +11,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationalRetrievalChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from docarray import DocumentArray
+from docarray import Document, DocumentArray
 from sentence_transformers import SentenceTransformer
 
 # StreamHandler to intercept streaming output from the LLM.
@@ -26,6 +26,19 @@ class StreamHandler(BaseCallbackHandler):
         self.text += token
         self.container.markdown(self.text)
 
+from langchain.retrievers import BaseRetriever
+
+class SimpleEmbeddingRetriever(BaseRetriever):
+    def __init__(self, documents):
+        self.documents = documents
+
+    def _get_relevant_documents(self, query: str, num_documents: int = 5):
+        query_doc = Document(text=query)
+        query_embedding = self.documents.embeddings.model.encode([query_doc.text])[0]
+        query_doc.embedding = query_embedding
+        scores = self.documents.match(query_doc, limit=num_documents, metric='cosine', use_scipy=True)
+        return [(doc.text, score) for doc, score in scores]
+
 
 @st.cache_data
 def get_page_urls(url):
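
Note: as committed, this retriever will likely not run. `BaseRetriever` is not exported by `langchain.retrievers` in most LangChain releases (it lives in `langchain.schema`), `DocumentArray.embeddings` is a plain numpy array with no `.model` attribute, and docarray's `match` is called from the query side and fills `.matches` in place rather than returning (doc, score) pairs. A minimal sketch of a working variant, assuming docarray<2 and a LangChain version where `BaseRetriever` is a pydantic model; the `encoder` field, the `LCDocument` alias, and the keyword-only `run_manager` parameter are assumptions, not part of the commit:

from typing import List

from docarray import Document, DocumentArray
from langchain.schema import BaseRetriever, Document as LCDocument
from sentence_transformers import SentenceTransformer

class SimpleEmbeddingRetriever(BaseRetriever):
    """Cosine-similarity retriever over a pre-embedded DocumentArray."""

    documents: DocumentArray      # index docs, each with .embedding already set
    encoder: SentenceTransformer  # must be the same model that embedded the index
    num_documents: int = 5

    class Config:
        arbitrary_types_allowed = True  # pydantic cannot validate these types

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[LCDocument]:
        # Embed the query with the same encoder used for the index.
        query_da = DocumentArray([Document(text=query, embedding=self.encoder.encode(query))])
        # docarray matches from the query side: results land in .matches,
        # sorted by ascending cosine distance (nothing is returned).
        query_da.match(self.documents, metric='cosine', limit=self.num_documents)
        # The chain expects LangChain Documents, not (text, score) tuples.
        return [LCDocument(page_content=m.text, metadata=dict(m.tags))
                for m in query_da[0].matches]

Keeping the encoder on the retriever avoids re-instantiating the SentenceTransformer per query and guarantees the query and the index share one embedding space.
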
@@ -54,11 +67,13 @@ def get_url_content(url):
         text = ' '.join([c.get_text().strip() for c in content if c.get_text().strip() != ''])
 
         # Create a single document with metadata
-        return DocumentArray([{'text': text, 'tags': {'url': url}}])
+        document = Document(text=text, tags={'url': url})
+        return DocumentArray([document])
     except Exception as e:
         st.error(f"Failed to process URL content: {e}")
         return DocumentArray()
 
+
 @st.cache_resource
 def get_retriever(urls):
     documents = DocumentArray()
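
For reference, the docarray<2 API the new lines rely on: `Document` accepts `text` and `tags` keyword arguments, and `tags` rides along as per-document metadata, which is what lets the source URL be recovered from a matched document later. A tiny check, with a placeholder URL:

from docarray import Document

doc = Document(text='page text', tags={'url': 'https://example.com'})
assert doc.tags['url'] == 'https://example.com'  # metadata survives on the doc
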
@@ -72,7 +87,7 @@ def get_retriever(urls):
     for doc, emb in zip(documents, embeddings):
         doc.embedding = emb
 
-    return documents
+    return SimpleEmbeddingRetriever(documents)
 
 
 @st.cache_resource
 
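
With get_retriever() now returning a retriever rather than a bare DocumentArray, it can be handed straight to the conversational chain the imports set up. A hedged sketch, assuming `llm` and `urls` are defined elsewhere in app.py:

retriever = get_retriever(urls)  # cached by @st.cache_resource
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)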