tonyassi committed on
Commit 460f62f · verified · 1 Parent(s): 9f08b2b

Upload app.py

Files changed (1)
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
+ import gradio as gr
+ from PIL import Image
+ from datasets import load_dataset, Dataset
+ import random
+ import numpy as np
+ import time
+
+
+ #ds = load_dataset("tonyassi/lucy4-embeddings", split='train')
+ ds = load_dataset("tonyassi/finesse1-embeddings", split='train')
+ #ds = load_dataset("tonyassi/lucy5-embeddings", split='train')
+ id_to_row = {row['id']: row for row in ds}
+ remaining_ds = None
+ preference_embedding = []
+
+ ###################################################################################
+
+ def get_random_images(dataset, num):
+     # Select `num` random indices from the dataset
+     random_indices = random.sample(range(len(dataset)), num)
+
+     # Get the images at those indices
+     random_images = dataset.select(random_indices)
+
+     # Create a new dataset with the remaining images
+     remaining_indices = [i for i in range(len(dataset)) if i not in random_indices]
+     new_dataset = dataset.select(remaining_indices)
+
+     return random_images, new_dataset
+
+ """
+ def find_similar_images(dataset, num, embedding):
+     start_time = time.time()
+     # Find the most similar images in dataset
+     dataset.add_faiss_index(column='embeddings')
+     embedding = np.array(embedding)
+     scores, retrieved_examples = dataset.get_nearest_examples('embeddings', embedding, k=num)
+
+     print('time 2.1:', time.time()-start_time)
+
+     # Create a new dataset without these images
+     dataset.drop_index('embeddings')
+     print('time 2.2:', time.time()-start_time)
+     remaining_indices = [i for i in range(len(dataset)) if dataset[i]['id'] not in retrieved_examples['id']]
+     print('time 2.3:', time.time()-start_time)
+     new_dataset = dataset.select(remaining_indices)
+
+     print('time 2.4:', time.time()-start_time)
+     return retrieved_examples, new_dataset
+
+ """
+
+ def find_similar_images(dataset, num, embedding):
+     start_time = time.time()
+
+     # Ensure FAISS index exists and search for similar images
+     #if not dataset.has_faiss_index('embeddings'):
+     dataset.add_faiss_index(column='embeddings')
+     scores, retrieved_examples = dataset.get_nearest_examples('embeddings', np.array(embedding), k=num)
+
+     print('time 2.1:', time.time()-start_time)
+
+     # Drop FAISS index after use to avoid re-indexing
+     dataset.drop_index('embeddings')
+     print('time 2.2:', time.time()-start_time)
+
+     # Extract all dataset IDs and use a set to find remaining indices
+     dataset_ids = dataset['id']
+     retrieved_ids_set = set(retrieved_examples['id'])
+
+     # Use a list comprehension with enumerate for faster indexing
+     remaining_indices = [i for i, row_id in enumerate(dataset_ids) if row_id not in retrieved_ids_set]
+
+     print('time 2.3:', time.time()-start_time)
+
+     # Create a new dataset without the retrieved images
+     new_dataset = dataset.select(remaining_indices)
+
+     print('time 2.4:', time.time()-start_time)
+     return retrieved_examples, new_dataset
+
+
+
+
+ def average_embedding(embedding1, embedding2):
+     # Element-wise average of two embedding vectors (used to fold each new
+     # selection into the running preference embedding)
+     embedding1 = np.array(embedding1)
+     embedding2 = np.array(embedding2)
+     return (embedding1 + embedding2) / 2
+ ###################################################################################
+
+ def load_images():
+     print('load_images()')
+     print("ds", ds.num_rows)
+
+     global remaining_ds
+     remaining_ds = ds
+
+     global preference_embedding
+     preference_embedding = []
+
+     # Get random images
+     rand_imgs, remaining_ds = get_random_images(ds, 10)
+
+     # Create a list of tuples [(img1, caption1), (img2, caption2), ...]
+     result = list(zip(rand_imgs['image'], [str(img_id) for img_id in rand_imgs['id']]))
+
+     return result
+
+
+
+ def select_image(evt: gr.SelectData, gallery, preference_gallery):
+     start_time = time.time()
+
+     print('select_image()')
+
+     global remaining_ds
+     print("remaining_ds", remaining_ds.num_rows)
+
+     # Selected image
+     selected_id = int(evt.value['caption'])
+     print('ID', selected_id)
+     #selected_row = ds.filter(lambda row: row['id'] == selected_id)[0]
+     selected_row = id_to_row[selected_id]
+     selected_embedding = selected_row['embeddings']
+     selected_image = selected_row['image']
+
+     print('time 1:', time.time()-start_time)
+
+     # Update preference embedding
+     global preference_embedding
+     if len(preference_embedding) == 0:
+         preference_embedding = selected_embedding
+     else:
+         preference_embedding = average_embedding(preference_embedding, selected_embedding)
+
+     print('time 2:', time.time()-start_time)
+
+     # Find images which are most similar to the preference embedding
+     similar_images, remaining_ds = find_similar_images(remaining_ds, 5, preference_embedding)
+
+     print('time 3:', time.time()-start_time)
+
+     # Create a list of tuples [(img1, caption1), (img2, caption2), ...]
+     result = list(zip(similar_images['image'], [str(img_id) for img_id in similar_images['id']]))
+
+     print('time 4:', time.time()-start_time)
+
+     # Get random images
+     rand_imgs, remaining_ds = get_random_images(remaining_ds, 5)
+     # Create a list of tuples [(img1, caption1), (img2, caption2), ...]
+     random_result = list(zip(rand_imgs['image'], [str(img_id) for img_id in rand_imgs['id']]))
+
+     final_result = result + random_result
+
+     # Update preference gallery
+     if preference_gallery is None:
+         final_preference_gallery = [selected_image]
+     else:
+         final_preference_gallery = [selected_image] + preference_gallery
+
+     print('time 5:', time.time()-start_time)
+
+     return gr.Gallery(value=final_result, selected_index=None), final_preference_gallery
+ ###################################################################################
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+ <center><h1> Product Recommendation using Image Similarity </h1></center>
+
+ <center>by <a href="https://www.tonyassi.com/" target="_blank">Tony Assi</a></center>
+
+
+ <center> This is a demo of product recommendation using image similarity of user preferences. </center>
+
+ The user selects their favorite product, which is then added to the user preference group. The image embeddings of the user preference products are averaged into a preference embedding. Each round, ten products are displayed: the 5 products most similar to the user preference embedding and 5 random products. Embeddings are generated with [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224). The dataset used is [tonyassi/finesse1-embeddings](https://huggingface.co/datasets/tonyassi/finesse1-embeddings).
+ """)
+
+     product_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Products')
+     preference_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Preference', interactive=False)
+
+     demo.load(load_images, inputs=None, outputs=[product_gallery])
+     product_gallery.select(select_image, inputs=[product_gallery, preference_gallery], outputs=[product_gallery, preference_gallery])
+
+
+ demo.launch()
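
For reference, the retrieval flow described in the app's intro text can be exercised outside the Gradio UI. The sketch below is a minimal, standalone version of that loop (fold each selection into an averaged preference embedding, then query a FAISS index over the same dataset); the helper names update_preference and recommend, and the choice of k=5, are illustrative and not part of the committed app.py.

# Minimal sketch: the recommendation loop from app.py without the Gradio UI.
# Assumes the same dataset schema: 'id', 'image', and 'embeddings' columns.
import numpy as np
from datasets import load_dataset

ds = load_dataset("tonyassi/finesse1-embeddings", split="train")

def update_preference(preference, selected_embedding):
    # First selection seeds the preference; later selections are averaged in,
    # mirroring average_embedding() in app.py.
    selected = np.array(selected_embedding, dtype=np.float32)
    return selected if preference is None else (preference + selected) / 2

def recommend(dataset, preference, k=5):
    # Same calls as find_similar_images(): build a FAISS index over the
    # 'embeddings' column, fetch the k nearest rows, then drop the index.
    dataset.add_faiss_index(column="embeddings")
    scores, examples = dataset.get_nearest_examples("embeddings", preference, k=k)
    dataset.drop_index("embeddings")
    return examples["id"]

# Pretend the user picked the first product, then ask for 5 similar ones.
preference = update_preference(None, ds[0]["embeddings"])
print(recommend(ds, preference, k=5))

Unlike the app, this sketch does not remove already-shown items from the pool, so the selected item itself will come back as its own nearest neighbour.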