File size: 2,173 Bytes
0d59440
92d8b8b
d6c88ae
 
 
 
 
0d59440
482cb2b
d6c88ae
42f52e5
b54da20
 
482cb2b
b54da20
 
 
0d59440
d6c88ae
e3bc95e
d6c88ae
 
0d59440
b54da20
 
 
 
 
 
0d8779b
489b7f2
 
 
 
0d59440
d6c88ae
e3bc95e
 
d6c88ae
e3bc95e
 
 
 
0d59440
d6c88ae
 
489b7f2
 
 
 
d6c88ae
 
489b7f2
d6c88ae
 
 
 
e3bc95e
 
0a6c9ba
482cb2b
b54da20
d6c88ae
 
 
 
 
 
 
b54da20
d6c88ae
 
b54da20
d6c88ae
f9d5dab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
st.set_page_config(page_title='Image Search App', layout='wide')
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import wget
from PIL import Image
# import io
from sentence_transformers import SentenceTransformer
import json
from zipfile import ZipFile
import zipfile
# from io import BytesIO
from PIL import Image
# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="shivangibithel/Flickr8k", filename="Images.zip")

# Pre-trained sentence encoder used to embed search queries.
model_name = "sentence-transformers/all-distilroberta-v1"

# Archive holding the image files; members are read as "Images/<image id>".
zip_path = "Images.zip"

# JSON object whose keys are image ids and whose values are the matching
# caption data (shape of the values not checked here — TODO confirm).
image_map_name = 'captions.json'

# Streamlit re-executes this whole script on every widget interaction.
# Cache the expensive loaders when the installed Streamlit provides
# ``st.cache_resource``; on older versions fall back to loading every run.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_resources():
    """Load the tokenizer, encoder, image archive and FAISS index.

    Returns:
        Tuple ``(tokenizer, model, zip_file, vectors, index)`` where
        ``vectors`` are the L2-normalised float32 caption embeddings and
        ``index`` is an exact ``IndexFlatL2`` over them.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = SentenceTransformer(model_name)

    # Kept open for the app's lifetime; search() reads members on demand.
    zip_file = zipfile.ZipFile(zip_path)

    # FAISS only accepts C-contiguous float32 arrays; np.load may yield
    # float64 depending on how the file was saved.
    vectors = np.ascontiguousarray(
        np.load("./sbert_text_features.npy"), dtype=np.float32
    )
    # Normalising in place makes L2 distance rank results like cosine similarity.
    faiss.normalize_L2(vectors)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return tokenizer, model, zip_file, vectors, index


tokenizer, model, zip_file, vectors, index = _load_resources()
vector_dimension = vectors.shape[1]

with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

# Position i in these lists corresponds to row i of ``vectors`` / the index.
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())

def search(query, k=5):
    """Display the ``k`` images whose caption embeddings best match ``query``.

    The query is embedded with the module-level ``model`` and L2-normalised
    so it is comparable with the normalised vectors stored in ``index``. Each
    hit is read from ``zip_file`` and rendered with ``st.image``.

    Args:
        query: Free-text search string.
        k: Maximum number of results to display.

    Returns:
        List of the image ids that were displayed, best match first.
    """
    # FAISS expects a 2-D float32 array of shape (1, dim).
    query_vector = np.asarray([model.encode(query)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    # IndexFlatL2 is an exhaustive search, so no nprobe tuning is needed.
    _, neighbours = index.search(query_vector, k)

    shown = []
    for i in neighbours[0]:
        # FAISS pads results with -1 when the index holds fewer than k
        # vectors; image_list[-1] would otherwise show an unrelated image.
        if i < 0:
            continue
        image_id = str(image_list[i])
        # Close the archive member once the image has been decoded and sent.
        with zip_file.open("Images/" + image_id) as image_data:
            image = Image.open(image_data)
            image.load()
            st.image(image, caption=image_id, width=200)
        shown.append(image_id)
    return shown


st.title("Image Search App")

query = st.text_input("Enter your search query here:")
search_clicked = st.button("Search")

# Run a search only when the button was pressed and the box is non-empty;
# search() renders its own results.
if search_clicked and query:
    search(query)

# No entry-point call is needed: Streamlit executes this script from top to
# bottom (as module "__main__") on every rerun, so the top-level statements
# above already build the UI. Calling an undefined helper here would raise
# NameError on each run, and ``st.cache(...)`` without decorating a function
# has no effect.