File size: 2,552 Bytes
ab9b7a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from tqdm.auto import tqdm
from utils.utils import create_client
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility
from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model

# Name of the Milvus collection holding the ResNet-18 image embeddings.
COLLECTION_NAME = "Resnet18"
# Length of each embedding vector stored in the collection
# (presumably the ResNet-18 feature size returned by extract_features — confirm).
EMBEDDING_DIM = 512
# Directory of images to index. Set the IMAGE_FOLDER environment variable to
# point at a different dataset instead of editing this machine-specific default.
IMAGE_FOLDER = os.environ.get(
    "IMAGE_FOLDER", "/home/nampham/Desktop/image-retrieval/data/images_mr"
)

# Created at import time; presumably opens the Milvus connection that
# `utility` / `Collection` below rely on — confirm in utils.utils.
client = create_client()

def load_collection():
    """Return the image collection, building and populating it on first use.

    If ``COLLECTION_NAME`` already exists it is loaded and its load state is
    printed. Otherwise a new collection is created, every image in
    ``IMAGE_FOLDER`` is embedded with a fresh ResNet-18 model and inserted,
    and an index is built (``create_index`` also loads the collection).

    Returns:
        The pymilvus ``Collection`` for ``COLLECTION_NAME``.
    """
    if not utility.has_collection(COLLECTION_NAME):
        # First run: build the collection from scratch.
        print("Please create a collection and insert data!")
        collection = create_collection()
        insert_data(create_resnet18_model(), collection, IMAGE_FOLDER)
        create_index(collection)
        return collection

    # Collection already exists: just bring it into a searchable state.
    print("Load and use collection right now!")
    collection = Collection(COLLECTION_NAME)
    collection.load()
    print(utility.load_state(COLLECTION_NAME))
    return collection

def create_collection():
    """Create the ``COLLECTION_NAME`` collection with an ID + embedding schema.

    Schema:
        image_id (INT64, primary key): supplied explicitly on insert
            (``auto_id`` is disabled) so each row keeps the numeric ID
            derived from its image file name.
        image_embedding (FLOAT_VECTOR of ``EMBEDDING_DIM`` floats).

    Returns:
        The newly created pymilvus ``Collection``.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )

    image_embedding = FieldSchema(
        name="image_embedding",
        dtype=DataType.FLOAT_VECTOR,
        # pymilvus requires a dimension on every vector field.
        dim=EMBEDDING_DIM,
        description="Image Embedding",
    )

    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # IDs are provided by the caller; an auto_id schema rejects inserts
        # that include primary-key values.
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )

    return Collection(name=COLLECTION_NAME, schema=schema)

def insert_data(model, collection, image_folder):
    """Embed every numerically named image in *image_folder* and insert it.

    A file whose name starts with a decimal integer stem (``42.jpg``,
    ``0042.png``) is treated as image ``42``; that integer becomes the
    row's primary key. Files without a numeric stem (``.DS_Store``,
    ``notes.txt``) are skipped rather than crashing ``int()``.

    Args:
        model: Feature extractor passed through to ``extract_features``.
        collection: pymilvus ``Collection`` to insert into. The primary key
            is supplied explicitly, so its schema must not use ``auto_id``.
        image_folder: Directory containing the images.

    Returns:
        The result of ``collection.insert``, or ``None`` when the folder
        contains no matching images (nothing is inserted or flushed).
    """
    # (id, real file name) pairs. Keeping the actual name means zero-padded
    # or non-.jpg files resolve to a path that exists.
    images = []
    for file_name in os.listdir(image_folder):
        stem = file_name.partition(".")[0]
        if stem.isdecimal():
            images.append((int(stem), file_name))
    images.sort()

    if not images:
        return None

    image_ids = [image_id for image_id, _ in images]
    image_embeddings = []
    for _, file_name in tqdm(images):
        image_path = os.path.join(image_folder, file_name)
        processed_image = preprocess_image(image_path)
        image_embeddings.append(extract_features(model, processed_image))

    # Column order matches the schema field order: [image_id, image_embedding].
    ins_resp = collection.insert([image_ids, image_embeddings])
    collection.flush()
    return ins_resp
    

def create_index(collection):
    """Build an IVF_FLAT / L2 index on the embedding field, then load it.

    Args:
        collection: pymilvus ``Collection`` whose schema contains a vector
            field named ``image_embedding``.
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        # nlist = number of IVF cluster units. 128 is set explicitly rather
        # than relying on the server-side default for an empty params dict.
        "params": {"nlist": 128},
    }

    # Must match the vector field name declared in the collection schema.
    collection.create_index(
        field_name="image_embedding",
        index_params=index_params,
    )

    # Load the collection so it can be searched.
    collection.load()