shivangibithel committed on
Commit
489b7f2
·
1 Parent(s): b1394c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -9,18 +9,22 @@ from PIL import Image
9
  from io import BytesIO
10
  from sentence_transformers import SentenceTransformer
11
 
12
- # dataset = load_dataset("imagefolder", data_files="https://huggingface.co/datasets/nlphuji/flickr30k/blob/main/flickr30k-images.zip")
13
-
14
  # Load the pre-trained sentence encoder
15
  model_name = "sentence-transformers/all-distilroberta-v1"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = SentenceTransformer(model_name)
18
 
19
- # Load the FAISS index
20
- index_name = 'index.faiss'
21
- index_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/faiss_flickr8k.index'
22
- wget.download(index_url, index_name)
23
- index = faiss.read_index(index_name)
 
 
 
 
 
 
24
 
25
  # Map the image ids to the corresponding image URLs
26
  image_map_name = 'captions.json'
@@ -35,11 +39,13 @@ caption_list = list(caption_dict.values())
35
 
36
  def search(query, k=5):
37
  # Encode the query
38
- query_tokens = tokenizer.encode(query, return_tensors='pt')
39
- query_embedding = model.encode(query_tokens).detach().numpy()
 
 
40
 
41
  # Search for the nearest neighbors in the FAISS index
42
- D, I = index.search(query_embedding, k)
43
 
44
  # Map the image ids to the corresponding image URLs
45
  image_urls = []
 
9
  from io import BytesIO
10
  from sentence_transformers import SentenceTransformer
11
 
 
 
12
  # Load the pre-trained sentence encoder
13
  model_name = "sentence-transformers/all-distilroberta-v1"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  model = SentenceTransformer(model_name)
16
 
17
+ # # Load the FAISS index
18
+ # index_name = 'index.faiss'
19
+ # index_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/faiss_flickr8k.index'
20
+ # wget.download(index_url, index_name)
21
+ # index = faiss.read_index(index_name)
22
+
23
+ vectors = np.load("https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/blob/main/sbert_text_features.npy")
24
+ vector_dimension = vectors.shape[1]
25
+ index = faiss.IndexFlatL2(vector_dimension)
26
+ faiss.normalize_L2(vectors)
27
+ index.add(vectors)
28
 
29
  # Map the image ids to the corresponding image URLs
30
  image_map_name = 'captions.json'
 
39
 
40
  def search(query, k=5):
41
  # Encode the query
42
+ query_embedding = model.encode(query)
43
+ query_vector = np.array([query_embedding])
44
+ faiss.normalize_L2(query_vector)
45
+ index.nprobe = index.ntotal
46
 
47
  # Search for the nearest neighbors in the FAISS index
48
+ D, I = index.search(query_vector, k)
49
 
50
  # Map the image ids to the corresponding image URLs
51
  image_urls = []