import uuid
from abc import ABC
from typing import Any, Callable, Dict, Generic, Type, TypeVar
from uuid import UUID

import numpy as np
from loguru import logger
from pydantic import UUID4, BaseModel, Field
from qdrant_client.http import exceptions
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import CollectionInfo, PointStruct, Record

from rag_demo.infra.qdrant import connection
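# NOTE: `connection` is assumed here to be a ready-to-use Qdrant client instance
# (e.g. `qdrant_client.QdrantClient`) exposed by the `rag_demo.infra.qdrant`
# module; every call below (`upsert`, `scroll`, `search`, `get_collection`,
# `create_collection`) is a standard QdrantClient method.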
T = TypeVar("T", bound="VectorBaseDocument")

EMBEDDING_SIZE = 1024


class VectorBaseDocument(BaseModel, Generic[T], ABC):
    """Base class for documents stored as points in a Qdrant vector collection."""

    id: UUID4 = Field(default_factory=uuid.uuid4)

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, self.__class__):
            return False

        return self.id == value.id

    def __hash__(self) -> int:
        return hash(self.id)

    @classmethod
    def from_record(cls: Type[T], point: Record) -> T:
        _id = UUID(point.id, version=4)
        payload = point.payload or {}

        attributes = {
            "id": _id,
            **payload,
        }
        if cls._has_class_attribute("embedding"):
            attributes["embedding"] = point.vector or None

        return cls(**attributes)

    def to_point(self: T, **kwargs) -> PointStruct:
        exclude_unset = kwargs.pop("exclude_unset", False)
        by_alias = kwargs.pop("by_alias", True)
        payload = self.model_dump(
            exclude_unset=exclude_unset, by_alias=by_alias, **kwargs
        )

        _id = str(payload.pop("id"))
        vector = payload.pop("embedding", {})
        if isinstance(vector, np.ndarray):
            vector = vector.tolist()

        return PointStruct(id=_id, vector=vector, payload=payload)

    def model_dump(self: T, **kwargs) -> dict:
        dict_ = super().model_dump(**kwargs)
        dict_ = self._uuid_to_str(dict_)

        return dict_

    def _uuid_to_str(self, item: Any) -> Any:
        if isinstance(item, dict):
            for key, value in item.items():
                if isinstance(value, UUID):
                    item[key] = str(value)
                elif isinstance(value, list):
                    item[key] = [self._uuid_to_str(v) for v in value]
                elif isinstance(value, dict):
                    item[key] = {k: self._uuid_to_str(v) for k, v in value.items()}

        return item

    @classmethod
    def bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> bool:
        try:
            cls._bulk_insert(documents)
            logger.info(
                f"Successfully inserted {len(documents)} documents into '{cls.get_collection_name()}'."
            )
        except Exception as e:
            logger.error(f"Error inserting documents: {e}")
            logger.info(
                f"Collection '{cls.get_collection_name()}' does not exist. Trying to create the collection and reinsert the documents."
            )

            cls.create_collection()
            try:
                cls._bulk_insert(documents)
            except Exception as e:
                logger.error(f"Error inserting documents: {e}")
                logger.error(
                    f"Failed to insert documents into '{cls.get_collection_name()}'."
                )

                return False

        return True

    @classmethod
    def _bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> None:
        points = [doc.to_point() for doc in documents]

        connection.upsert(collection_name=cls.get_collection_name(), points=points)

    @classmethod
    def bulk_find(
        cls: Type[T], limit: int = 10, **kwargs
    ) -> tuple[list[T], UUID | None]:
        try:
            documents, next_offset = cls._bulk_find(limit=limit, **kwargs)
        except exceptions.UnexpectedResponse:
            logger.error(
                f"Failed to search documents in '{cls.get_collection_name()}'."
            )

            documents, next_offset = [], None

        return documents, next_offset

    @classmethod
    def _bulk_find(
        cls: Type[T], limit: int = 10, **kwargs
    ) -> tuple[list[T], UUID | None]:
        collection_name = cls.get_collection_name()

        offset = kwargs.pop("offset", None)
        offset = str(offset) if offset else None

        records, next_offset = connection.scroll(
            collection_name=collection_name,
            limit=limit,
            with_payload=kwargs.pop("with_payload", True),
            with_vectors=kwargs.pop("with_vectors", False),
            offset=offset,
            **kwargs,
        )

        documents = [cls.from_record(record) for record in records]
        if next_offset is not None:
            next_offset = UUID(next_offset, version=4)

        return documents, next_offset

    @classmethod
    def search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
        try:
            documents = cls._search(query_vector=query_vector, limit=limit, **kwargs)
        except exceptions.UnexpectedResponse:
            logger.error(
                f"Failed to search documents in '{cls.get_collection_name()}'."
            )

            documents = []

        return documents

    @classmethod
    def _search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
        collection_name = cls.get_collection_name()

        records = connection.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            with_payload=kwargs.pop("with_payload", True),
            with_vectors=kwargs.pop("with_vectors", False),
            **kwargs,
        )

        documents = [cls.from_record(record) for record in records]

        return documents

    @classmethod
    def get_or_create_collection(cls: Type[T]) -> CollectionInfo:
        collection_name = cls.get_collection_name()

        try:
            return connection.get_collection(collection_name=collection_name)
        except exceptions.UnexpectedResponse:
            use_vector_index = cls.get_use_vector_index()

            collection_created = cls._create_collection(
                collection_name=collection_name, use_vector_index=use_vector_index
            )
            if collection_created is False:
                raise RuntimeError(
                    f"Couldn't create collection {collection_name}"
                ) from None

            return connection.get_collection(collection_name=collection_name)

    @classmethod
    def create_collection(cls: Type[T]) -> bool:
        collection_name = cls.get_collection_name()
        use_vector_index = cls.get_use_vector_index()

        logger.info(
            f"Creating collection '{collection_name}' with use_vector_index={use_vector_index}."
        )

        return cls._create_collection(
            collection_name=collection_name, use_vector_index=use_vector_index
        )

    @classmethod
    def _create_collection(
        cls, collection_name: str, use_vector_index: bool = True
    ) -> bool:
        if use_vector_index is True:
            vectors_config = VectorParams(size=EMBEDDING_SIZE, distance=Distance.COSINE)
        else:
            vectors_config = {}

        return connection.create_collection(
            collection_name=collection_name, vectors_config=vectors_config
        )

    @classmethod
    def get_collection_name(cls: Type[T]) -> str:
        if not hasattr(cls, "Config") or not hasattr(cls.Config, "name"):
            raise Exception(
                f"The class {cls} should define a Config class with the 'name' property that reflects the collection's name."
            )

        return cls.Config.name

    @classmethod
    def get_use_vector_index(cls: Type[T]) -> bool:
        if not hasattr(cls, "Config") or not hasattr(cls.Config, "use_vector_index"):
            return True

        return cls.Config.use_vector_index

    @classmethod
    def group_by_class(
        cls: Type["VectorBaseDocument"], documents: list["VectorBaseDocument"]
    ) -> Dict[type["VectorBaseDocument"], list["VectorBaseDocument"]]:
        return cls._group_by(documents, selector=lambda doc: doc.__class__)

    @classmethod
    def _group_by(
        cls: Type[T], documents: list[T], selector: Callable[[T], Any]
    ) -> Dict[Any, list[T]]:
        grouped = {}
        for doc in documents:
            key = selector(doc)
            if key not in grouped:
                grouped[key] = []
            grouped[key].append(doc)

        return grouped

    @classmethod
    def collection_name_to_class(
        cls: Type["VectorBaseDocument"], collection_name: str
    ) -> type["VectorBaseDocument"]:
        for subclass in cls.__subclasses__():
            try:
                if subclass.get_collection_name() == collection_name:
                    return subclass
            except Exception:
                pass

            try:
                return subclass.collection_name_to_class(collection_name)
            except ValueError:
                continue

        raise ValueError(f"No subclass found for collection name: {collection_name}")

    @classmethod
    def _has_class_attribute(cls: Type[T], attribute_name: str) -> bool:
        if attribute_name in cls.__annotations__:
            return True

        for base in cls.__bases__:
            if hasattr(base, "_has_class_attribute") and base._has_class_attribute(
                attribute_name
            ):
                return True

        return False
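

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, hypothetical subclass showing
# how the base class is meant to be used. The class name, collection name and
# fields below are invented for this example; only the base-class API and
# EMBEDDING_SIZE come from this module, and running it assumes a reachable
# Qdrant instance behind `connection`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ExampleChunk(VectorBaseDocument):
        content: str
        embedding: list[float] | None = None

        class Config:
            name = "example_chunks"
            use_vector_index = True

    # Make sure the collection exists before writing to it.
    ExampleChunk.get_or_create_collection()

    # Insert a single document with a dummy embedding of the expected size.
    chunk = ExampleChunk(content="hello world", embedding=[0.1] * EMBEDDING_SIZE)
    ExampleChunk.bulk_insert([chunk])

    # Query the collection with the same dummy vector.
    results = ExampleChunk.search(query_vector=[0.1] * EMBEDDING_SIZE, limit=3)
    logger.info(f"Retrieved {len(results)} documents.")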