Spaces:
Running
Running
import os | |
from qdrant_client import QdrantClient, models, grpc | |
from qdrant_client.http.models import PayloadSchemaType | |
import logging | |
from dotenv import load_dotenv | |
import asyncio | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Run dotenv loading first so the settings below can pick up values it provides.
load_dotenv()

# Qdrant connection settings read from the environment; each is None when
# the corresponding variable is not set.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_URL = os.environ.get("QDRANT_URL")
class QdrantManager:
    """Manage Qdrant collections through a synchronous ``QdrantClient``.

    The client is built from the module-level ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` settings and is configured to prefer gRPC.
    """

    # Vector dimensionality used for newly created collections.
    # NOTE(review): an earlier comment cited text-embedding-3-small (1536
    # dimensions) while the code used 384 — confirm which embedding model
    # actually feeds these collections.
    DEFAULT_VECTOR_SIZE = 384

    def __init__(self):
        """Create the Qdrant client from the environment-derived settings."""
        if not QDRANT_URL:
            # Not fatal: the client is still constructed, but it will not be
            # pointed at the intended server.
            logger.warning(
                "QDRANT_URL is not set; QdrantClient will use its default location"
            )
        self.qdrant_client = QdrantClient(
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            prefer_grpc=True,
        )
        print("Connected to Qdrant")

    async def get_or_create_company_collection(
        self,
        collection_name: str,
        vector_size: int = DEFAULT_VECTOR_SIZE,
    ) -> str:
        """Return ``collection_name``, creating the collection if it is missing.

        A new collection is created with cosine distance and then gets
        payload indices on ``document_id`` (keyword) and ``content`` (text).
        An existing collection is returned as-is and is not modified.

        Args:
            collection_name: Name of the collection.
            vector_size: Dimensionality of the stored vectors. Only used
                when the collection has to be created.

        Returns:
            str: Collection name.

        Raises:
            ValueError: If ``collection_name`` is empty, or if looking up or
                creating the collection (or its payload indices) fails.
        """
        if not collection_name:
            raise ValueError("collection_name must be a non-empty string")

        client = self.qdrant_client
        try:
            # QdrantClient is synchronous: its methods block and return plain
            # values, so they cannot be awaited directly. Run them in a worker
            # thread to keep the event loop responsive.
            if await asyncio.to_thread(client.collection_exists, collection_name):
                print(f"Using existing collection: {collection_name}")
                return collection_name

            print(f"Creating new collection: {collection_name}")
            await asyncio.to_thread(
                client.create_collection,
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                # NOTE(review): m=0 together with payload_m=16 looks like
                # Qdrant's per-payload (multitenancy) HNSW setup, where the
                # global graph is disabled — confirm this is intended.
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Index the payload fields used for filtering and text matching.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                await asyncio.to_thread(
                    client.create_payload_index,
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            print(f"Successfully created collection: {collection_name}")
            return collection_name
        except Exception as e:
            error_msg = f"Failed to create collection {collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise ValueError(error_msg) from e
# Example usage:
#
# async def main():
#     try:
#         qdrant_manager = QdrantManager()
#         collection_name = "ca-documents"
#         result = await qdrant_manager.get_or_create_company_collection(collection_name)
#         print(f"Collection name: {result}")
#     except Exception as e:
#         print(f"Error: {e}")
#
# if __name__ == "__main__":
#     asyncio.run(main())