tonyassi committed on
Commit
9038e96
·
verified ·
1 Parent(s): d9d9ba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -53
app.py CHANGED
@@ -5,10 +5,10 @@ import random
5
  import numpy as np
6
  import time
7
 
8
-
9
- #ds = load_dataset("tonyassi/lucy4-embeddings", split='train')
10
  ds = load_dataset("tonyassi/finesse1-embeddings", split='train')
11
- #ds = load_dataset("tonyassi/lucy5-embeddings", split='train')
 
12
  id_to_row = {row['id']: row for row in ds}
13
  remaining_ds = None
14
  preference_embedding = []
@@ -28,41 +28,13 @@ def get_random_images(dataset, num):
28
 
29
  return random_images, new_dataset
30
 
31
- """
32
- def find_similar_images(dataset, num, embedding):
33
- start_time = time.time()
34
- # Find the most similar images in dataset
35
- dataset.add_faiss_index(column='embeddings')
36
- embedding = np.array(embedding)
37
- scores, retrieved_examples = dataset.get_nearest_examples('embeddings', embedding, k=num)
38
-
39
- print('time 2.1:', time.time()-start_time)
40
-
41
- # Create a new dataset without these images
42
- dataset.drop_index('embeddings')
43
- print('time 2.2:', time.time()-start_time)
44
- remaining_indices = [i for i in range(len(dataset)) if dataset[i]['id'] not in retrieved_examples['id']]
45
- print('time 2.3:', time.time()-start_time)
46
- new_dataset = dataset.select(remaining_indices)
47
-
48
- print('time 2.4:', time.time()-start_time)
49
- return retrieved_examples, new_dataset
50
-
51
- """
52
-
53
  def find_similar_images(dataset, num, embedding):
54
- start_time = time.time()
55
-
56
  # Ensure FAISS index exists and search for similar images
57
- #if not dataset.has_faiss_index('embeddings'):
58
  dataset.add_faiss_index(column='embeddings')
59
  scores, retrieved_examples = dataset.get_nearest_examples('embeddings', np.array(embedding), k=num)
60
 
61
- print('time 2.1:', time.time()-start_time)
62
-
63
  # Drop the FAISS index once the search is done (it is rebuilt on every call)
64
  dataset.drop_index('embeddings')
65
- print('time 2.2:', time.time()-start_time)
66
 
67
  # Extract all dataset IDs and use a set to find remaining indices
68
  dataset_ids = dataset['id']
@@ -70,17 +42,12 @@ def find_similar_images(dataset, num, embedding):
70
 
71
  # Use a list comprehension with enumerate for faster indexing
72
  remaining_indices = [i for i, id in enumerate(dataset_ids) if id not in retrieved_ids_set]
73
-
74
- print('time 2.3:', time.time()-start_time)
75
 
76
  # Create a new dataset without the retrieved images
77
  new_dataset = dataset.select(remaining_indices)
78
 
79
- print('time 2.4:', time.time()-start_time)
80
  return retrieved_examples, new_dataset
81
 
82
-
83
-
84
  def average_embedding(embedding1, embedding2):
85
  embedding1 = np.array(embedding1)
86
  embedding2 = np.array(embedding2)
@@ -89,7 +56,6 @@ def average_embedding(embedding1, embedding2):
89
  ###################################################################################
90
 
91
  def load_images():
92
- print('load_images()')
93
  print("ds", ds.num_rows)
94
 
95
  global remaining_ds
@@ -108,23 +74,15 @@ def load_images():
108
 
109
 
110
  def select_image(evt: gr.SelectData, gallery, preference_gallery):
111
- start_time = time.time()
112
-
113
- print('select_image()')
114
-
115
  global remaining_ds
116
  print("remaining_ds", remaining_ds.num_rows)
117
 
118
  # Selected image
119
  selected_id = int(evt.value['caption'])
120
- print('ID', selected_id)
121
- #selected_row = ds.filter(lambda row: row['id'] == selected_id)[0]
122
  selected_row = id_to_row[selected_id]
123
  selected_embedding = selected_row['embeddings']
124
  selected_image = selected_row['image']
125
 
126
- print('time 1:', time.time()-start_time)
127
-
128
  # Update preference embedding
129
  global preference_embedding
130
  if len(preference_embedding) == 0:
@@ -132,18 +90,12 @@ def select_image(evt: gr.SelectData, gallery, preference_gallery):
132
  else:
133
  preference_embedding = average_embedding(preference_embedding, selected_embedding)
134
 
135
- print('time 2:', time.time()-start_time)
136
-
137
  # Find images which are most similar to the preference embedding
138
  simlar_images, remaining_ds = find_similar_images(remaining_ds, 5, preference_embedding)
139
 
140
- print('time 3:', time.time()-start_time)
141
-
142
  # Create a list of tuples [(img1,caption1),(img2,caption2)...]
143
  result = list(zip(simlar_images['image'], [str(id) for id in simlar_images['id']]))
144
 
145
- print('time 4:', time.time()-start_time)
146
-
147
  # Get random images
148
  rand_imgs, remaining_ds = get_random_images(remaining_ds, 5)
149
  # Create a list of tuples [(img1,caption1),(img2,caption2)...]
@@ -157,8 +109,6 @@ def select_image(evt: gr.SelectData, gallery, preference_gallery):
157
  else:
158
  final_preference_gallery = [selected_image] + preference_gallery
159
 
160
- print('time 5:', time.time()-start_time)
161
-
162
  return gr.Gallery(value=final_result, selected_index=None), final_preference_gallery
163
 
164
  ###################################################################################
 
5
  import numpy as np
6
  import time
7
 
8
+ # Dataset
 
9
  ds = load_dataset("tonyassi/finesse1-embeddings", split='train')
10
+
11
+
12
  id_to_row = {row['id']: row for row in ds}
13
  remaining_ds = None
14
  preference_embedding = []
 
28
 
29
  return random_images, new_dataset
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def find_similar_images(dataset, num, embedding):
 
 
32
  # Ensure FAISS index exists and search for similar images
 
33
  dataset.add_faiss_index(column='embeddings')
34
  scores, retrieved_examples = dataset.get_nearest_examples('embeddings', np.array(embedding), k=num)
35
 
 
 
36
  # Drop the FAISS index once the search is done (it is rebuilt on every call)
37
  dataset.drop_index('embeddings')
 
38
 
39
  # Extract all dataset IDs and use a set to find remaining indices
40
  dataset_ids = dataset['id']
 
42
 
43
  # Use a list comprehension with enumerate for faster indexing
44
  remaining_indices = [i for i, id in enumerate(dataset_ids) if id not in retrieved_ids_set]
 
 
45
 
46
  # Create a new dataset without the retrieved images
47
  new_dataset = dataset.select(remaining_indices)
48
 
 
49
  return retrieved_examples, new_dataset
50
 
 
 
51
  def average_embedding(embedding1, embedding2):
52
  embedding1 = np.array(embedding1)
53
  embedding2 = np.array(embedding2)
 
56
  ###################################################################################
57
 
58
  def load_images():
 
59
  print("ds", ds.num_rows)
60
 
61
  global remaining_ds
 
74
 
75
 
76
  def select_image(evt: gr.SelectData, gallery, preference_gallery):
 
 
 
 
77
  global remaining_ds
78
  print("remaining_ds", remaining_ds.num_rows)
79
 
80
  # Selected image
81
  selected_id = int(evt.value['caption'])
 
 
82
  selected_row = id_to_row[selected_id]
83
  selected_embedding = selected_row['embeddings']
84
  selected_image = selected_row['image']
85
 
 
 
86
  # Update preference embedding
87
  global preference_embedding
88
  if len(preference_embedding) == 0:
 
90
  else:
91
  preference_embedding = average_embedding(preference_embedding, selected_embedding)
92
 
 
 
93
  # Find images which are most similar to the preference embedding
94
  simlar_images, remaining_ds = find_similar_images(remaining_ds, 5, preference_embedding)
95
 
 
 
96
  # Create a list of tuples [(img1,caption1),(img2,caption2)...]
97
  result = list(zip(simlar_images['image'], [str(id) for id in simlar_images['id']]))
98
 
 
 
99
  # Get random images
100
  rand_imgs, remaining_ds = get_random_images(remaining_ds, 5)
101
  # Create a list of tuples [(img1,caption1),(img2,caption2)...]
 
109
  else:
110
  final_preference_gallery = [selected_image] + preference_gallery
111
 
 
 
112
  return gr.Gallery(value=final_result, selected_index=None), final_preference_gallery
113
 
114
  ###################################################################################