# CA-Foundation / backend / Qdrant.py
# "vinit5112"
# async changes
# 5b65de2
import os
from qdrant_client import QdrantClient, models, grpc
from qdrant_client.http.models import PayloadSchemaType
import logging
from dotenv import load_dotenv
import asyncio
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)

# Load variables from a .env file (if any) before the settings below are read.
load_dotenv()

# Qdrant connection settings; each is None when its variable is unset.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantManager:
    """Manage Qdrant collections through a single shared client.

    The client is configured from the module-level ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` settings and is asked to prefer gRPC transport.
    """

    def __init__(self):
        """Build the synchronous Qdrant client used by every method."""
        self.qdrant_client = QdrantClient(
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            prefer_grpc=True,
        )
        logger.info("Connected to Qdrant")

    async def _run_blocking(self, func, **kwargs):
        """Await a blocking client call by running it in the default executor.

        ``QdrantClient`` methods are synchronous and return plain values, so
        they cannot be awaited directly; off-loading them also keeps the
        event loop free while the request is in flight.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    async def get_or_create_company_collection(
        self,
        collection_name: str,
        vector_size: int = 384,
    ) -> str:
        """
        Get or create a collection for a company.

        If a collection called ``collection_name`` already exists it is
        reused untouched; otherwise it is created with cosine distance and
        payload indices on ``document_id`` (keyword) and ``content`` (text).

        Args:
            collection_name: Name of the collection
            vector_size: Dimensionality of the stored vectors. Must match
                the embedding model that produces them; defaults to 384.

        Returns:
            str: Collection name

        Raises:
            ValueError: If the existence check, the collection creation or
                the payload index creation fails
        """
        try:
            already_there = await self._run_blocking(
                self.qdrant_client.collection_exists,
                collection_name=collection_name,
            )
            if already_there:
                logger.info("Using existing collection: %s", collection_name)
                return collection_name

            logger.info("Creating new collection: %s", collection_name)
            # Per Qdrant's multitenancy guidance, m=0 skips building the
            # global HNSW graph while payload_m builds graphs per indexed
            # payload value.
            # NOTE(review): neither payload index below is flagged as a
            # tenant key -- confirm this HNSW setup is intended.
            await self._run_blocking(
                self.qdrant_client.create_collection,
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                await self._run_blocking(
                    self.qdrant_client.create_payload_index,
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            logger.info("Successfully created collection: %s", collection_name)
            return collection_name
        except Exception as e:
            # Collapse every failure into ValueError, the documented error
            # type, while keeping the original exception as the cause.
            error_msg = f"Failed to create collection {collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise ValueError(error_msg) from e
# Example usage:
#
# async def main():
#     try:
#         qdrant_manager = QdrantManager()
#         collection_name = "ca-documents"
#         result = await qdrant_manager.get_or_create_company_collection(collection_name)
#         print(f"Collection name: {result}")
#     except Exception as e:
#         print(f"Error: {e}")
#
# if __name__ == "__main__":
#     asyncio.run(main())