Spaces:
Running
Running
"""Tests for the time-weighted retriever class.""" | |
from datetime import datetime | |
from typing import Any, Iterable, List, Optional, Tuple, Type | |
import pytest | |
from langchain.embeddings.base import Embeddings | |
from langchain.retrievers.time_weighted_retriever import ( | |
TimeWeightedVectorStoreRetriever, | |
_get_hours_passed, | |
) | |
from langchain.schema import Document | |
from langchain.vectorstores.base import VectorStore | |
def _get_example_memories(k: int = 4) -> List[Document]:
    """Build ``k`` sample documents for use as a retriever memory stream.

    Every document has the page content ``"foo"`` and the same fixed
    ``last_accessed_at`` timestamp (2023-04-14 12:00); only ``buffer_idx``
    differs, running from ``0`` to ``k - 1``.

    Args:
        k: Number of documents to create.

    Returns:
        The list of generated documents, in ``buffer_idx`` order.
    """
    memories: List[Document] = []
    for position in range(k):
        meta = {
            "buffer_idx": position,
            "last_accessed_at": datetime(2023, 4, 14, 12, 0),
        }
        memories.append(Document(page_content="foo", metadata=meta))
    return memories
class MockVectorStore(VectorStore):
    """Mock vector store that stores nothing and returns canned results."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Pretend to add texts to the store.

        Nothing is embedded or stored; the texts themselves are echoed back
        in place of generated ids.

        Args:
            texts: Iterable of strings to "add".
            metadatas: Optional list of metadatas; ignored by this mock.
            kwargs: Ignored.

        Returns:
            The input texts materialized as a list.
        """
        return list(texts)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Async variant of ``add_texts``; not supported by this mock.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query (always empty for this mock)."""
        return []

    # ``cls`` is the first parameter and the body calls ``cls.from_texts``,
    # so this has to be bound as a classmethod to be callable on the class.
    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return a store built from documents by delegating to ``from_texts``.

        Args:
            documents: Documents whose ``page_content`` and ``metadata`` are
                forwarded to ``from_texts``.
            embedding: Embeddings object, forwarded unchanged.
            kwargs: Forwarded to ``from_texts``.

        Returns:
            Whatever ``cls.from_texts`` returns.
        """
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    # Same reasoning as ``from_documents``: the body instantiates ``cls``.
    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return a fresh, empty store; every argument is ignored."""
        return cls()

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return the example memories, each paired with a fixed score of 0.5.

        ``query`` and ``k`` are ignored: the result is always built from
        ``_get_example_memories()`` called with its default size.
        """
        return [(doc, 0.5) for doc in _get_example_memories()]
# The tests in this module request ``time_weighted_retriever`` by parameter
# name, which pytest only resolves for functions registered as fixtures.
@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    """Provide a retriever backed by the mock store and example memories.

    Returns:
        A ``TimeWeightedVectorStoreRetriever`` whose vector store is a new
        ``MockVectorStore`` and whose memory stream is the list returned by
        ``_get_example_memories()``.
    """
    vectorstore = MockVectorStore()
    return TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore, memory_stream=_get_example_memories()
    )
def test__get_hours_passed() -> None:
    """``_get_hours_passed`` reports 2.5 for 14:30 versus 12:00 the same day."""
    later = datetime(2023, 4, 14, 14, 30)
    earlier = datetime(2023, 4, 14, 12, 0)
    elapsed = _get_hours_passed(later, earlier)
    assert elapsed == 2.5
def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Combined score equals ``(1 - decay_rate) ** hours + vector salience``.

    The document was last accessed at 12:00 and is scored at 14:30, so the
    exponent used for the expected value is 2.5 hours.
    """
    salience = 0.7
    hours_elapsed = 2.5
    now = datetime(2023, 4, 14, 14, 30)
    doc = Document(
        page_content="Test document",
        metadata={"last_accessed_at": datetime(2023, 4, 14, 12, 0)},
    )

    actual = time_weighted_retriever._get_combined_score(doc, salience, now)

    recency = (1.0 - time_weighted_retriever.decay_rate) ** hours_elapsed
    assert actual == pytest.approx(recency + salience)
def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """``get_salient_docs`` returns a dict for a plain text query."""
    result = time_weighted_retriever.get_salient_docs("Test query")
    assert isinstance(result, dict)
def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """``get_relevant_documents`` returns a list for a plain text query."""
    result = time_weighted_retriever.get_relevant_documents("Test query")
    assert isinstance(result, list)
def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Adding one document returns a one-item list and extends the stream.

    The last entry of ``memory_stream`` must carry the same page content as
    the document that was just added.
    """
    batch = [Document(page_content="test_add_documents document")]

    result = time_weighted_retriever.add_documents(batch)

    assert isinstance(result, list)
    assert len(result) == 1
    newest = time_weighted_retriever.memory_stream[-1]
    assert newest.page_content == batch[0].page_content