Spaces:
Sleeping
Sleeping
File size: 2,552 Bytes
ab9b7a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
from tqdm.auto import tqdm
from utils.utils import create_client
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility
from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model
# Name of the Milvus collection that stores the ResNet-18 image embeddings.
COLLECTION_NAME = "Resnet18"
# Expected length of each embedding vector (presumably the 512-d ResNet-18
# feature size produced by extract_features -- TODO confirm).
EMBEDDING_DIM = 512
# Directory of images to index. It can be overridden with the IMAGE_FOLDER
# environment variable so the script is not tied to one machine; the default
# is the original author's local path.
IMAGE_FOLDER = os.environ.get(
    "IMAGE_FOLDER",
    "/home/nampham/Desktop/image-retrieval/data/images_mr",
)
# Created at import time. Not referenced directly in this section; presumably
# creating it establishes the Milvus connection used by `utility` and
# `Collection` below (NOTE(review): confirm in utils.utils.create_client).
client = create_client()
def load_collection():
    """Return the image-embedding collection, building it on first use.

    If a collection named ``COLLECTION_NAME`` already exists, it is loaded and
    its load state is printed. Otherwise a new collection is created, filled
    with embeddings of the images in ``IMAGE_FOLDER`` (using a fresh ResNet-18
    model), and indexed via ``create_index``.

    Returns:
        The pymilvus ``Collection`` object.
    """
    if not utility.has_collection(COLLECTION_NAME):
        print("Please create a collection and insert data!")
        fresh = create_collection()
        # Populate the new collection, then index it so it can be searched.
        insert_data(create_resnet18_model(), fresh, IMAGE_FOLDER)
        create_index(fresh)
        return fresh

    print("Load and use collection right now!")
    existing = Collection(COLLECTION_NAME)
    existing.load()
    print(utility.load_state(COLLECTION_NAME))
    return existing
def create_collection():
    """Create the ``COLLECTION_NAME`` collection with an id + embedding schema.

    Schema:
      * ``image_id``: INT64 primary key. Values are supplied by the caller
        (``insert_data`` passes ids parsed from image file names), so
        ``auto_id`` is disabled.
      * ``image_embedding``: FLOAT_VECTOR of dimension ``EMBEDDING_DIM``.

    Returns:
        The newly created pymilvus ``Collection``.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )
    image_embedding = FieldSchema(
        name="image_embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,  # vector fields must declare their dimension
        description="Image Embedding",
    )
    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # Ids are inserted explicitly by insert_data. With auto_id=True,
        # Milvus would reject the extra primary-key column at insert time.
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )
    return Collection(name=COLLECTION_NAME, schema=schema)
def insert_data(model, collection, image_folder):
    """Embed every numerically named image in *image_folder* and insert it.

    Files are expected to be named ``<integer id>.<ext>`` (e.g. ``42.jpg``).
    The integer becomes the row's ``image_id``, and the features that
    ``extract_features`` returns for that file become its ``image_embedding``.
    Directory entries whose stem is not a decimal integer (e.g. ``.DS_Store``)
    are skipped.

    Args:
        model: Feature extractor passed through to ``extract_features``.
        collection: Target collection; rows are inserted as
            ``[image_ids, image_embeddings]`` in that column order.
        image_folder: Directory containing the image files.

    Returns:
        Whatever ``collection.insert`` returns.
    """
    # Map id -> real file name, so the actual extension is kept instead of
    # assuming ".jpg".
    files_by_id = {}
    for file_name in os.listdir(image_folder):
        stem, _ext = os.path.splitext(file_name)
        if stem.isdecimal():
            files_by_id[int(stem)] = file_name
    image_ids = sorted(files_by_id)

    image_embeddings = []
    for image_id in tqdm(image_ids):
        image_path = os.path.join(image_folder, files_by_id[image_id])
        processed_image = preprocess_image(image_path)
        image_embeddings.append(extract_features(model, processed_image))

    ins_resp = collection.insert([image_ids, image_embeddings])
    collection.flush()
    return ins_resp
def create_index(collection, nlist=128):
    """Build an IVF_FLAT / L2 index on the embedding field, then load the collection.

    Args:
        collection: Collection whose vector field is named ``"image_embedding"``.
        nlist: Number of IVF cluster units passed in the index params.
            Defaults to 128.
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": nlist},
    }
    # Refer to the field by name: the FieldSchema object built in
    # create_collection is local to that function and not visible here.
    collection.create_index(
        field_name="image_embedding",
        index_params=index_params,
    )
    # Load the collection so it can be searched right away.
    collection.load()