Spaces:
Sleeping
Sleeping
from qdrant_client import QdrantClient | |
from qdrant_client.http import models | |
from tqdm import tqdm | |
import os | |
import time | |
import numpy as np | |
from loguru import logger | |
import stamina | |
from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict | |
class MyQdrantClient:
    """Wrapper around a path-based ``QdrantClient``.

    Provides helpers to create/delete collections, upsert batches of
    (multi-)vectors and query them, logging through ``logger``.
    """

    def __init__(self, path: str):
        """Create the underlying Qdrant client.

        Args:
            path: Storage location passed to ``QdrantClient(path=...)``.
        """
        self.qdrant_client = QdrantClient(path=path)
        logger.debug(f"Qdrant client created at {path}")

    def create_collection(
        self,
        collection_name: str,
        vector_dim: int = 128,
        vector_type: Literal["colbert", "dense"] = "colbert",
    ):
        """Create a collection whose vectors and payloads are stored on disk.

        Args:
            collection_name: Name of the collection to create.
            vector_dim: Dimension of each (sub-)vector.
            vector_type: ``"colbert"`` for multivector points compared with
                MAX_SIM, or ``"dense"`` for a single vector per point. Both
                use cosine distance.

        Raises:
            ValueError: If ``vector_type`` is not ``"colbert"`` or ``"dense"``.
        """
        if vector_type == "colbert":
            vectors_config = models.VectorParams(
                size=vector_dim,
                distance=models.Distance.COSINE,
                on_disk=True,  # move original vectors to disk
                # Each point holds a list of vectors, scored with MaxSim.
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
                # Optional: quantization_config=models.BinaryQuantization(
                #   binary=models.BinaryQuantizationConfig(always_ram=True))
                # keeps only quantized vectors in RAM.
            )
        elif vector_type == "dense":
            vectors_config = models.VectorParams(
                size=vector_dim,
                distance=models.Distance.COSINE,
                on_disk=True,  # move original vectors to disk
            )
        else:
            raise ValueError(f"Vector type {vector_type} not supported")

        self.qdrant_client.create_collection(
            collection_name=collection_name,
            on_disk_payload=True,  # store the payload on disk
            vectors_config=vectors_config,
        )
        logger.debug(f"Qdrant collection of type {vector_type} : {collection_name} created")

    def delete_collection(self, collection_name: str):
        """Delete ``collection_name`` and return the Qdrant client's result."""
        return self.qdrant_client.delete_collection(collection_name=collection_name)

    def upsert_to_qdrant(self, batch, collection_name: str, *, attempts: int = 3) -> bool:
        """Upsert a batch of points, retrying with backoff if the call raises.

        Args:
            batch: Points to upsert (e.g. a list of ``models.PointStruct``).
            collection_name: Target collection.
            attempts: Total number of tries, including the first (minimum 1).
                Pass ``attempts=1`` to disable retrying.

        Returns:
            True on success, False if every attempt failed (the last error
            is logged).
        """
        try:
            # stamina re-raises the last exception once attempts are exhausted.
            for attempt in stamina.retry_context(on=Exception, attempts=max(1, attempts)):
                with attempt:
                    self.qdrant_client.upsert(
                        collection_name=collection_name,
                        points=batch,
                        wait=False,  # don't block until the change is applied
                    )
        except Exception as e:
            logger.error(f"Error during upsert: {e}")
            return False
        return True

    def upsert_multivector(self, index: int, multivector_input_list: list[Any], collection_name: str) -> bool:
        """Upsert one point per multivector using consecutive integer IDs.

        Element ``j`` of ``multivector_input_list`` is stored with ID
        ``index + j`` and payload ``{"source": "user uploaded data"}``.

        Args:
            index: ID assigned to the first element.
            multivector_input_list: Vectors to store. For a ``"colbert"``
                collection, each element is a list of vectors.
            collection_name: Target collection.

        Returns:
            True if the points were upserted (or there was nothing to
            upsert), False if building or upserting the points failed.
        """
        try:
            points = [
                models.PointStruct(
                    id=index + offset,  # we just use the running index as the ID
                    vector=multivector,
                    payload={
                        "source": "user uploaded data"
                    },  # can also add other metadata/data
                )
                for offset, multivector in enumerate(multivector_input_list)
            ]
        except Exception as e:
            logger.error(f"Vector DB client - error during upsert: {e}")
            return False
        if not points:
            return True
        # upsert_to_qdrant logs its own errors; propagate its status to the caller.
        return self.upsert_to_qdrant(points, collection_name)

    def query_multivector(self, multivector_input, collection_name: str, top_k: int = 10) -> Optional[list[int]]:
        """Return the IDs of the ``top_k`` points closest to ``multivector_input``.

        Args:
            multivector_input: Query vector(s), passed as-is to
                ``query_points(query=...)``.
            collection_name: Collection to search.
            top_k: Maximum number of results.

        Returns:
            Point IDs in the order returned by Qdrant, or None if the query
            raised (the error is logged).
        """
        try:
            start_time = time.perf_counter()  # monotonic clock for durations
            search_result = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=multivector_input,
                limit=top_k,
                # If quantization is enabled on the collection, rescoring can be
                # tuned via search_params=models.SearchParams(
                #   quantization=models.QuantizationSearchParams(...)).
            )
            elapsed_time = time.perf_counter() - start_time
            logger.debug(f"Search completed in {elapsed_time:.4f} seconds")
            return [point.id for point in search_result.points]
        except Exception as e:
            logger.error(f"Error during query: {e}")
            return None

    def __del__(self):
        """Close the underlying client when this wrapper is garbage-collected."""
        # If QdrantClient(...) raised inside __init__, the attribute was never
        # set; guard the lookup so __del__ does not raise AttributeError.
        client = getattr(self, "qdrant_client", None)
        if client is None:
            return
        try:
            client.close()
        except Exception:
            # Exceptions cannot propagate out of __del__, and during
            # interpreter shutdown module globals (even the logger) may
            # already be gone, so closing is strictly best-effort.
            pass