import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel


BASELINE_MODEL = "openai/clip-vit-base-patch32"
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd"

# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
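# Each line of the vector file holds one precomputed image embedding:
#   <image filename> <TAB> <comma-separated float vector components>
# A minimal sketch of how such a file could be produced offline (an assumption
# for illustration; it reuses the model/processor pair loaded by load_model()
# below and is not run by this app):
#
#   model, processor = load_model()
#   with open(IMAGE_VECTOR_FILE, "w") as fout:
#       for fname in os.listdir(IMAGES_DIR):
#           image = Image.open(os.path.join(IMAGES_DIR, fname))
#           inputs = processor(images=image, return_tensors="jax")
#           vec = np.asarray(model.get_image_features(**inputs))[0]
#           fout.write("{:s}\t{:s}\n".format(
#               fname, ",".join(["{:.7e}".format(x) for x in vec])))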

# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"


@st.cache(allow_output_mutation=True)
def load_index():
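    """Read precomputed image vectors from a TSV file and build an NMSLib
    HNSW index over them. Returns (filenames, index), where row i of the
    index corresponds to filenames[i]."""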
    filenames, image_vecs = [], []
    with open(IMAGE_VECTOR_FILE, "r") as fvec:
        for line in fvec:
            cols = line.strip().split('\t')
            filename = cols[0]
            image_vec = np.array([float(x) for x in cols[1].split(',')])
            filenames.append(filename)
            image_vecs.append(image_vec)
    V = np.array(image_vecs)
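    # Build an HNSW index using cosine distance for fast approximate
    # nearest-neighbor search ('post': 2 enables extra graph post-processing).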
    index = nmslib.init(method='hnsw', space='cosinesimil')
    index.addDataPointBatch(V)
    index.createIndex({'post': 2}, print_progress=True)
    return filenames, index


@st.cache(allow_output_mutation=True)
def load_model():
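    """Load the fine-tuned CLIP model; the processor (image preprocessing
    and tokenization) is the unmodified one from the baseline checkpoint."""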
    model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
    processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
    return model, processor


def app():
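    """Streamlit page: retrieve and display the nearest neighbors of a query
    image in the CLIP embedding space."""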
    filenames, index = load_index()
    model, processor = load_model()

    st.title("Image to Image Retrieval")
    st.markdown("""
        The CLIP model from OpenAI is trained in a self-supervised manner using 
        contrastive learning to project images and caption text onto a common 
        embedding space. We have fine-tuned the model using the RSICD dataset
        (10k images and ~50k captions from the remote sensing domain).
        
        This demo shows the image-to-image retrieval capabilities of this model:
        given an image file name as a query, we use our fine-tuned CLIP model
        to project the query image into the image/caption embedding space and
        search for nearby images (by cosine similarity) in that space.

        The image vectors were generated ahead of time with our fine-tuned CLIP
        model, and NMSLib is used for fast approximate nearest-neighbor search
        over them.

        You will need an image file name to start; we recommend copying the
        file name from one of the results of the text-to-image search.
    """)

    image_file = st.text_input("Image Query (filename):")
    if st.button("Query"):
        # Load the query image and project it into the joint embedding space.
        image = Image.open(os.path.join(IMAGES_DIR, image_file))
        inputs = processor(images=image, return_tensors="jax", padding=True)
        query_vec = model.get_image_features(**inputs)
        query_vec = np.asarray(query_vec)
        # Retrieve 11 neighbors so that 10 remain after dropping the query
        # image itself from the results.
        ids, distances = index.knnQuery(query_vec, k=11)
        result_filenames = [filenames[i] for i in ids]
        images, captions = [], []
        for result_filename, score in zip(result_filenames, distances):
            # Skip the query image itself if it appears among the neighbors.
            if result_filename == image_file:
                continue
            images.append(
                plt.imread(os.path.join(IMAGES_DIR, result_filename)))
            # NMSLib returns cosine distances; 1 - distance is the similarity.
            captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
        # Keep at most 10 results for display.
        images = images[0:10]
        captions = captions[0:10]
        st.image(images[0:3], caption=captions[0:3])
        st.image(images[3:6], caption=captions[3:6])
        st.image(images[6:9], caption=captions[6:9])
        st.image(images[9:], caption=captions[9:])