File size: 2,858 Bytes
deb090d
5b65de2
 
deb090d
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c3bbd6
deb090d
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
8146726
5b65de2
8146726
 
 
5b65de2
8146726
 
5b65de2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from qdrant_client import QdrantClient, models, grpc
from qdrant_client.http.models import PayloadSchemaType
import logging
from dotenv import load_dotenv
import asyncio

# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)

# Load environment variables (python-dotenv) before any setting is read below.
load_dotenv()

# Qdrant connection settings; each is None when its variable is unset.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_URL = os.environ.get("QDRANT_URL")

class QdrantManager:
    """Wrapper around a synchronous ``QdrantClient`` for collection setup.

    The client is built from the module-level ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` settings and prefers the gRPC transport.
    """

    # Default dimensionality for new collections. It must match the output
    # size of whichever embedding model writes into the collection.
    DEFAULT_VECTOR_SIZE = 384

    def __init__(self):
        """Create the Qdrant client from the environment-derived settings."""
        self.qdrant_client = QdrantClient(
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            prefer_grpc=True,
        )
        logger.info("Connected to Qdrant")

    @staticmethod
    async def _run_blocking(func, **kwargs):
        """Run a blocking client call in the default executor and await it."""
        # ``QdrantClient`` methods are synchronous: their results are plain
        # values and cannot be awaited directly. Off-loading them to a worker
        # thread also keeps the event loop responsive during the network call.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    async def get_or_create_company_collection(
        self, collection_name: str, vector_size: int = DEFAULT_VECTOR_SIZE
    ) -> str:
        """
        Get or create a collection for a company.

        If a collection with this name already exists it is returned as-is.
        Otherwise it is created with cosine distance and payload indices on
        ``document_id`` (keyword) and ``content`` (full text).

        Args:
            collection_name: Name of the collection
            vector_size: Dimensionality of the stored vectors; only used when
                the collection has to be created

        Returns:
            str: Collection name

        Raises:
            ValueError: If the lookup, collection creation or index creation
                fails; the original exception is chained as the cause
        """
        try:
            # Honour the "get" half of the contract: do not try to re-create
            # a collection that is already there.
            listing = await self._run_blocking(self.qdrant_client.get_collections)
            if any(item.name == collection_name for item in listing.collections):
                logger.info("Using existing collection: %s", collection_name)
                return collection_name

            logger.info("Creating new collection: %s", collection_name)

            # Create collection with vector configuration.
            # NOTE(review): per the Qdrant multitenancy guide, m=0 turns off
            # the global HNSW graph and payload_m builds links per indexed
            # payload value instead - confirm this is the intended setup.
            await self._run_blocking(
                self.qdrant_client.create_collection,
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }

            for field_name, schema_type in payload_indices.items():
                await self._run_blocking(
                    self.qdrant_client.create_payload_index,
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema_type,
                )

            logger.info("Successfully created collection: %s", collection_name)
            return collection_name

        except Exception as e:
            error_msg = f"Failed to create collection {collection_name}: {str(e)}"
            # logger.exception records the active traceback with the message.
            logger.exception(error_msg)
            raise ValueError(error_msg) from e

# Example usage (uncomment to run this module as a script):
# async def main():
#     try:
#         qdrant_manager = QdrantManager()
#         collection_name = "ca-documents"
#         result = await qdrant_manager.get_or_create_company_collection(collection_name)
#         print(f"Collection name: {result}")
#     except Exception as e:
#         print(f"Error: {e}")

# if __name__ == "__main__":
#     asyncio.run(main())