Spaces:
Sleeping
Sleeping
import logging
import os
from typing import Optional

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import PayloadSchemaType
# Module-level logger. This file does not configure handlers or levels itself.
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file (if any) so the os.getenv calls
# below can pick up values defined there.
load_dotenv()

# Configuration: Qdrant connection settings come only from the environment.
# Never hard-code credentials in this file, not even in comments; anything
# committed to source control must be treated as leaked.
# NOTE(review): an API key was previously committed here in a comment; rotate it.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
class QdrantManager:
    """Wrapper around a QdrantClient that provisions document collections."""

    # Vector size for OpenAI's text-embedding-3-small embeddings.
    VECTOR_SIZE = 1536

    def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None):
        """Create the underlying Qdrant client.

        Args:
            url: Qdrant endpoint. Defaults to the module-level QDRANT_URL,
                which is read from the environment.
            api_key: Qdrant API key. Defaults to the module-level
                QDRANT_API_KEY, which is read from the environment.
        """
        url = url if url is not None else QDRANT_URL
        api_key = api_key if api_key is not None else QDRANT_API_KEY
        if not url:
            # A missing URL is usually a deployment mistake; surface it
            # instead of letting the client fall back silently.
            logger.warning(
                "QDRANT_URL is not set; QdrantClient will use its default "
                "connection settings"
            )
        self.qdrant_client = QdrantClient(url=url, api_key=api_key)
        # Building the client object is not proof that the server is
        # reachable, so do not claim a live connection here.
        logger.info("Qdrant client initialised for %s", url)

    def get_or_create_company_collection(self, collection_name: str) -> str:
        """Return *collection_name*, creating the collection first if it is missing.

        A newly created collection gets cosine-distance vectors of size
        VECTOR_SIZE plus payload indexes on ``document_id`` (keyword) and
        ``content`` (full text). An existing collection is returned as-is;
        its configuration is neither checked nor modified.

        Args:
            collection_name: Name of the collection.

        Returns:
            str: The collection name.

        Raises:
            ValueError: If checking for or creating the collection fails.
        """
        try:
            if self.qdrant_client.collection_exists(collection_name):
                logger.info("Using existing collection: %s", collection_name)
                return collection_name
            self._create_collection(collection_name)
        except Exception as e:
            error_msg = f"Failed to get or create collection {collection_name}: {e}"
            logger.exception(error_msg)
            raise ValueError(error_msg) from e
        logger.info("Successfully created collection: %s", collection_name)
        return collection_name

    def _create_collection(self, collection_name: str) -> None:
        """Create *collection_name* with its vector, HNSW and payload-index settings."""
        logger.info("Creating new collection: %s", collection_name)
        self.qdrant_client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=self.VECTOR_SIZE,
                distance=models.Distance.COSINE,
            ),
            # Per Qdrant's indexing docs, m=0 skips the global HNSW graph and
            # payload_m builds payload-aware links instead.
            # NOTE(review): this presumably targets searches filtered by an
            # indexed field such as document_id; unfiltered searches may be
            # slow. Confirm this is intended.
            hnsw_config=models.HnswConfigDiff(
                payload_m=16,
                m=0,
            ),
        )
        payload_indices = {
            "document_id": PayloadSchemaType.KEYWORD,
            "content": PayloadSchemaType.TEXT,
        }
        for field_name, schema_type in payload_indices.items():
            self.qdrant_client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=schema_type,
            )
# Example usage
if __name__ == "__main__":
    try:
        qdrant_manager = QdrantManager()
        collection_name = "ca-documents"
        result = qdrant_manager.get_or_create_company_collection(collection_name)
        print(f"Collection name: {result}")
    except Exception as e:
        # Report the error on stderr and exit with a non-zero status so that
        # shells and CI can detect the failure.
        raise SystemExit(f"Error: {e}") from e