File size: 2,055 Bytes
d6c88ae
0d59440
d6c88ae
 
 
 
 
0d59440
d6c88ae
 
0d59440
d6c88ae
e3bc95e
d6c88ae
 
0d59440
489b7f2
 
0d8779b
489b7f2
0d8779b
489b7f2
0d8779b
489b7f2
 
 
 
0d59440
d6c88ae
e3bc95e
0d8779b
d6c88ae
e3bc95e
d6c88ae
e3bc95e
 
 
 
0d59440
d6c88ae
 
489b7f2
 
 
 
d6c88ae
 
489b7f2
d6c88ae
 
 
 
e3bc95e
 
0d8779b
d6c88ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
from io import BytesIO

from datasets import load_dataset
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import wget
from PIL import Image
from sentence_transformers import SentenceTransformer

# Load the pre-trained sentence encoder used to embed search queries.
model_name = "sentence-transformers/all-distilroberta-v1"
# NOTE(review): `tokenizer` is not used anywhere in this file; it is kept only so
# the module-level name stays available. Consider dropping it to skip the load.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)

# Build an exact (brute-force) L2 index over the precomputed caption embeddings.
# faiss.normalize_L2 works in place and requires a C-contiguous float32 array,
# so coerce whatever dtype/layout np.load produced.
vectors = np.ascontiguousarray(np.load("./sbert_text_features.npy"), dtype="float32")
vector_dimension = vectors.shape[1]
index = faiss.IndexFlatL2(vector_dimension)
# With unit-length vectors, L2 distance ranks neighbours the same as cosine similarity.
faiss.normalize_L2(vectors)
index.add(vectors)

# Map index row positions back to image ids.
# captions.json is a local file, so read it directly: wget.download() hands its
# argument to urllib, which rejects a scheme-less path such as "./captions.json".
image_map_name = 'captions.json'
image_map_url = './captions.json'

with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

# NOTE(review): assumes the key order of captions.json matches the row order of
# sbert_text_features.npy — confirm against the script that produced both files.
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())

def search(query, k=5):
    """Return local image paths for the ``k`` captions closest to ``query``.

    The query is embedded with the module-level ``model``, unit-normalised the
    same way as the indexed caption vectors, and looked up in ``index``. Each
    hit's row number is mapped to an image id through ``image_list``.

    Args:
        query: Free-text search string.
        k: Maximum number of results to return.

    Returns:
        A list of ``"./Images/<image_id>"`` paths, nearest first. The list may
        be shorter than ``k`` when the index holds fewer than ``k`` vectors.
    """
    # normalize_L2 works in place and needs a 2-D float32 array.
    query_vector = np.array([model.encode(query)], dtype="float32")
    faiss.normalize_L2(query_vector)

    # IndexFlatL2 scans every vector, so there is no nprobe to tune here.
    _distances, neighbor_ids = index.search(query_vector, k)

    # FAISS pads missing results with -1; skip them rather than letting
    # image_list[-1] silently return the last image in the list.
    return ["./Images/" + str(image_list[i]) for i in neighbor_ids[0] if i >= 0]

def run_app():
    """Render the search page: a query box, a Search button and the result images."""
    st.title("Image Search App")

    query = st.text_input("Enter your search query here:")
    # st.button is always evaluated first so the widget is rendered on every run.
    if st.button("Search") and query:
        image_urls = search(query)

        # Display the images
        st.image(image_urls, width=200)


if __name__ == '__main__':
    # `streamlit run` executes this file as __main__. set_page_config has to be
    # the first Streamlit command of the run, so it precedes everything in run_app().
    st.set_page_config(page_title='Image Search App', layout='wide')
    run_app()