Spaces:
Sleeping
Sleeping
File size: 1,920 Bytes
d1df841 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from dataclasses import dataclass, asdict
from pathlib import Path
import random
import numpy as np
from pymilvus import MilvusClient
@dataclass
class MilvusServer:
    """Connection settings for the Milvus instance used by this module."""

    # URI passed to ``MilvusClient(uri=...)``. The default looks like a local
    # database file name rather than a network address -- TODO confirm.
    uri: str = "milvus.db"
@dataclass
class EmbeddingCollectionSchema:
    """Settings describing one embedding collection.

    The field names are significant: ``VectorDB.create_collection`` expands
    an instance with ``asdict`` and forwards every field as a keyword
    argument to the client's ``create_collection`` call, so renaming a field
    changes the keyword that is sent.
    """

    # Name of the collection to create / look up.
    collection_name: str
    # Name of the field that stores the embedding vector.
    vector_field_name: str
    # Length of each embedding vector.
    dimension: int
    # Forwarded as ``auto_id``; presumably lets the server generate primary
    # keys -- verify against the pymilvus documentation.
    auto_id: bool
    # Forwarded as ``enable_dynamic_field``; presumably allows extra,
    # undeclared fields on inserted rows -- verify against pymilvus.
    enable_dynamic_field: bool
    # Similarity metric name, e.g. "COSINE".
    metric_type: str
# Schema for the image-embedding collection: 512-dimensional vectors stored in
# the "embedding" field and compared with the COSINE metric.
ImageEmbeddingCollectionSchema = EmbeddingCollectionSchema(
    # Where the vectors live.
    collection_name="image_embeddings",
    vector_field_name="embedding",
    # Vector shape and similarity metric.
    dimension=512,
    metric_type="COSINE",
    # Collection behaviour flags.
    auto_id=True,
    enable_dynamic_field=True,
)
# Schema for the text-embedding collection: 384-dimensional vectors stored in
# the "embedding" field and compared with the COSINE metric.
TextEmbeddingCollectionSchema = EmbeddingCollectionSchema(
    # Where the vectors live.
    collection_name="text_embeddings",
    vector_field_name="embedding",
    # Vector shape and similarity metric.
    dimension=384,
    metric_type="COSINE",
    # Collection behaviour flags.
    auto_id=True,
    enable_dynamic_field=True,
)
class VectorDB:
    """Small convenience wrapper around a ``MilvusClient``.

    Creates embedding collections from an ``EmbeddingCollectionSchema`` and
    inserts single embedding records. Progress and insert failures are
    reported with ``print``.
    """

    def __init__(self, client: "MilvusClient | None" = None):
        """Store the client used by every other method.

        Args:
            client: Client to operate on. When omitted (or ``None``), a new
                ``MilvusClient`` is opened on ``MilvusServer.uri``.
        """
        # The default client is built here rather than in the signature: a
        # call expression used as a default runs once, at class-definition
        # time, which would open a connection as a side effect of importing
        # this module even when the caller supplies their own client.
        if client is None:
            client = MilvusClient(uri=MilvusServer.uri)
        self.client = client

    def create_collection(self, schema: "EmbeddingCollectionSchema") -> bool:
        """Create the collection described by *schema* unless it exists.

        Args:
            schema: Collection settings. Every field is forwarded as a
                keyword argument to the client's ``create_collection``.

        Returns:
            Always ``True``. Errors raised by the client propagate.
        """
        name = schema.collection_name
        if self.client.has_collection(collection_name=name):
            # An existing collection is left untouched (not dropped and
            # recreated), so previously inserted records are kept.
            print(f"Collection {name} already exists")
            return True

        print(f"Creating collection {name}")
        self.client.create_collection(**asdict(schema))
        print(f"Collection {name} created")
        return True

    def insert_record(
        self, collection_name: str, embedding: np.ndarray, file_path: str
    ) -> bool:
        """Insert one embedding and its file path into *collection_name*.

        The row is sent with the keys ``"embedding"`` and ``"filename"``.

        Args:
            collection_name: Target collection.
            embedding: Vector to store under ``"embedding"``.
            file_path: Value stored under ``"filename"``.

        Returns:
            ``True`` on success; ``False`` if the client raised, in which
            case the error is printed instead of propagated.
        """
        try:
            self.client.insert(
                collection_name=collection_name,
                data={"embedding": embedding, "filename": file_path},
            )
        except Exception as e:
            # Deliberately best-effort: callers get a boolean, not a raise.
            print(f"Error inserting record: {e}")
            return False
        return True
|