File size: 1,920 Bytes
d1df841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from dataclasses import dataclass, asdict
from pathlib import Path
import random
import numpy as np
from pymilvus import MilvusClient


@dataclass()
class MilvusServer:
    """Location of the Milvus instance this module talks to.

    ``uri`` has a class-level default, so it can be read either from an
    instance or directly as ``MilvusServer.uri`` (which ``VectorDB`` does
    when building its default client).
    """

    # Default target: a local database file named "milvus.db".
    uri: str = "milvus.db"


@dataclass()
class EmbeddingCollectionSchema:
    """Settings for one embedding collection.

    Every field name matches a keyword argument of
    ``MilvusClient.create_collection``; ``VectorDB.create_collection`` passes
    ``asdict(schema)`` straight through, so do not rename or add fields
    without checking that signature.
    """

    # Name of the collection in Milvus.
    collection_name: str
    # Name of the field that holds the vector.
    vector_field_name: str
    # Length of each stored vector.
    dimension: int
    # Forwarded unchanged as create_collection(auto_id=...).
    auto_id: bool
    # Forwarded unchanged as create_collection(enable_dynamic_field=...).
    enable_dynamic_field: bool
    # Similarity metric name, e.g. "COSINE".
    metric_type: str


# Collection for image embeddings: 512-dimensional vectors compared by cosine
# similarity, stored under the "embedding" field.
ImageEmbeddingCollectionSchema = EmbeddingCollectionSchema(
    collection_name="image_embeddings",
    dimension=512,
    metric_type="COSINE",
    vector_field_name="embedding",
    # Both flags are forwarded as-is to MilvusClient.create_collection.
    auto_id=True,
    enable_dynamic_field=True,
)

# Collection for text embeddings: 384-dimensional vectors compared by cosine
# similarity, stored under the "embedding" field.
TextEmbeddingCollectionSchema = EmbeddingCollectionSchema(
    collection_name="text_embeddings",
    dimension=384,
    metric_type="COSINE",
    vector_field_name="embedding",
    # Both flags are forwarded as-is to MilvusClient.create_collection.
    auto_id=True,
    enable_dynamic_field=True,
)


class VectorDB:
    """Small wrapper around a ``MilvusClient`` for collection setup and inserts."""

    def __init__(self, client: "MilvusClient | None" = None):
        """Wrap *client*, or open a new client at ``MilvusServer.uri``.

        Args:
            client: An existing client to use. When omitted, a new
                ``MilvusClient`` is created for this instance.
        """
        # Built here rather than as a default argument: a default would be
        # evaluated once at import time, opening a connection merely by
        # importing this module and sharing it across every VectorDB().
        if client is None:
            client = MilvusClient(uri=MilvusServer.uri)
        self.client = client

    def create_collection(self, schema: "EmbeddingCollectionSchema") -> bool:
        """Create the collection described by *schema* unless it already exists.

        Every field of *schema* is forwarded as a keyword argument to
        ``MilvusClient.create_collection``. An existing collection with the
        same name is reused as-is; it is neither dropped nor compared against
        *schema*.

        Returns:
            True in both cases. Errors raised by the client propagate.
        """
        name = schema.collection_name
        if self.client.has_collection(collection_name=name):
            print(f"Collection {name} already exists")
            return True
        print(f"Creating collection {name}")
        self.client.create_collection(**asdict(schema))
        print(f"Collection {name} created")
        return True

    def insert_record(
        self,
        collection_name: str,
        embedding: np.ndarray,
        file_path: str,
        *,
        vector_field_name: str = "embedding",
    ) -> bool:
        """Insert one embedding together with its source path.

        Args:
            collection_name: Target collection.
            embedding: Vector stored under *vector_field_name*.
            file_path: Stored under the ``"filename"`` key of the record.
            vector_field_name: Key for the vector; should match the
                ``vector_field_name`` the collection was created with.
                Defaults to ``"embedding"`` for backward compatibility.

        Returns:
            True on success; False if the client raised (the error is printed).
        """
        record = {vector_field_name: embedding, "filename": file_path}
        try:
            self.client.insert(collection_name=collection_name, data=record)
        except Exception as e:  # best-effort insert: report and signal failure
            print(f"Error inserting record: {e}")
            return False
        return True