File size: 2,343 Bytes
d6c88ae
0d59440
d6c88ae
 
 
 
 
0d59440
d6c88ae
 
0d59440
d6c88ae
 
0d59440
d6c88ae
 
 
 
0d59440
d6c88ae
 
 
 
 
 
0d59440
d6c88ae
 
 
 
 
0d59440
d6c88ae
 
 
 
 
 
 
0d59440
d6c88ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
from io import BytesIO

import faiss
import numpy as np
import streamlit as st
import torch
import wget
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# dataset = load_dataset("nlphuji/flickr30k", streaming=True)
# df = pd.DataFrame.from_dict(dataset["train"])

# Load the pre-trained multilingual sentence encoder.
# NOTE(review): SentenceTransformer tokenizes internally, so the separate
# tokenizer is only needed for manual tokenization paths — confirm before
# removing it.
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)

# # Load the pre-trained image model
# image_model_name = 'image_model.ckpt'
# image_model_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/images/vqvae.png'
# wget.download(image_model_url, image_model_name)
# image_model = torch.load(image_model_name, map_location=torch.device('cpu'))
# image_model.eval()

# Download and load the FAISS index of image embeddings.
# NOTE(review): these URLs look like placeholders (a random test model repo);
# verify they point at real artifacts before deploying.
index_name = 'index.faiss'
index_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/faiss.index'
wget.download(index_url, index_name)
index = faiss.read_index(index_name)

# Download the JSON file mapping FAISS row ids (string keys, since JSON
# object keys are always strings) to image URLs.
# BUG FIX: removed the dead `image_map = {}` assignment that was immediately
# overwritten by json.load; `json` is now imported at the top of the file.
image_map_name = 'image_map.json'
image_map_url = 'https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/image_map.json'
wget.download(image_map_url, image_map_name)
with open(image_map_name, 'r') as f:
    image_map = json.load(f)

def search(query, k=5):
    """Return the image URLs of the top-*k* nearest neighbors of *query*.

    Parameters
    ----------
    query : str
        Free-text search query.
    k : int, optional
        Number of results to return (default 5).

    Returns
    -------
    list[str]
        Image URLs ordered by increasing distance in the FAISS index.
    """
    # BUG FIX: SentenceTransformer.encode expects raw strings (it tokenizes
    # internally) and returns a numpy array by default. The original passed
    # pre-tokenized tensors and then called .detach() on the ndarray, which
    # raises AttributeError. Encoding a one-element list yields the 2-D
    # (1, dim) array that index.search requires.
    query_embedding = model.encode([query])

    # FAISS requires contiguous float32 input.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)

    # D: distances, I: row ids of the k nearest neighbors, each shape (1, k).
    D, I = index.search(query_embedding, k)

    # Map each row id to its URL; image_map keys are string ids because the
    # mapping is loaded from JSON.
    return [image_map[str(i)] for i in I[0]]

# --- Streamlit UI -----------------------------------------------------------
st.title("Image Search App")

query = st.text_input("Enter your search query here:")

# Only run the lookup when the button is pressed AND the box is non-empty
# (short-circuit `and` is equivalent to the nested ifs it replaces).
if st.button("Search") and query:
    # Render the nearest-neighbour hits as a row of thumbnails.
    st.image(search(query), width=200)

if __name__ == '__main__':
    # BUG FIX: the original body was dead-broken in three ways:
    #   * it called run_app(), which is not defined anywhere (NameError);
    #   * st.set_page_config() must be the *first* Streamlit command in the
    #     script — calling it here, after st.title() has already run, raises
    #     StreamlitAPIException;
    #   * the bare st.cache(allow_output_mutation=True) call decorated
    #     nothing and had no effect.
    # Under Streamlit the script body above already executes top-to-bottom on
    # every rerun, so no explicit entry point is needed. Move set_page_config
    # to the very top of the file if a custom page title/layout is wanted.
    pass