CA-Foundation / backend /Qdrant.py
“vinit5112”
Add all code
deb090d
raw
history blame
2.93 kB
import os
from qdrant_client import QdrantClient, models
from qdrant_client.models import PayloadSchemaType
import logging
from dotenv import load_dotenv
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Load settings from a local .env file (python-dotenv) before reading them below.
load_dotenv()
# Configuration: the Qdrant endpoint and API key are read from the environment.
# NOTE(review): a hard-coded endpoint URL and API token used to sit here as
# commented-out code; they were removed from this comment. If that token was
# ever valid it should be rotated, since it has been committed to history.
# Each value is None when the corresponding environment variable is unset.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
class QdrantManager:
    """Small wrapper around ``QdrantClient`` for provisioning collections.

    The client is built from the module-level ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` settings.
    """

    def __init__(self):
        """Build the Qdrant client from the module-level settings."""
        self.qdrant_client = QdrantClient(
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
        )
        print("Connected to Qdrant")

    def get_or_create_company_collection(self, collection_name: str) -> str:
        """Return ``collection_name``, creating the collection if it is missing.

        An existing collection is reused as-is: it is neither re-created nor
        re-indexed. A missing collection is created with a 1536-dimensional
        cosine vector space plus payload indices on ``document_id`` (keyword)
        and ``content`` (text).

        Args:
            collection_name: Name of the collection.

        Returns:
            str: Collection name.

        Raises:
            ValueError: If the existence check, the collection creation, or
                the payload index creation fails. The original exception is
                chained as the cause.
        """
        try:
            # "Get" half of get-or-create: creating a collection that already
            # exists is an error on the server, so check first and reuse it.
            if self.qdrant_client.collection_exists(collection_name=collection_name):
                print(f"Collection already exists: {collection_name}")
                return collection_name

            print(f"Creating new collection: {collection_name}")

            # Vector size for text-embedding-3-small is 1536
            vector_size = 1536

            # Create collection with vector configuration
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                # NOTE(review): m=0 together with payload_m=16 presumably
                # disables the global HNSW graph and builds links only per
                # indexed payload value (Qdrant's multitenancy recipe).
                # Confirm this is intended for this collection.
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                self.qdrant_client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            print(f"Successfully created collection: {collection_name}")
            return collection_name
        except Exception as e:
            # Log with traceback, then surface a single exception type to callers.
            error_msg = f"Failed to create collection {collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise ValueError(error_msg) from e
# Example usage: run this module directly to provision the demo collection.
if __name__ == "__main__":
    try:
        manager = QdrantManager()
        provisioned = manager.get_or_create_company_collection("ca-documents")
        print(f"Collection name: {provisioned}")
    except Exception as exc:
        # Report any failure as a one-line message instead of a traceback.
        print(f"Error: {exc}")