vision experiments
- experimental/images/plant-001.jpeg +0 -0
- experimental/images/plant-002.jpeg +0 -0
- experimental/vision001.py +213 -0
- experimental/vision002.py +210 -0
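Both scripts share the same core idea: embed a local image with CLIP, query the laion knn-service with that embedding, then re-rank the returned captions (and the images behind the returned URLs) by dot product against the query embedding. A minimal sketch of that ranking step is below; the embed_text parameter stands in for a text-embedding call such as app_client.text_to_embedding and is an assumption for illustration, not part of the scripts.

import torch

def rank_captions(captions, query_embedding, embed_text):
    # Embed each caption (embed_text returns a (1, d) tensor) and sort the
    # captions by dot product with the (1, d) query embedding, best first.
    caption_vectors = torch.cat([embed_text(c) for c in captions], dim=0)  # (n, d)
    sims = caption_vectors @ query_embedding.T                             # (n, 1)
    return sorted(zip(sims[:, 0].tolist(), captions), reverse=True)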
experimental/images/plant-001.jpeg
ADDED
experimental/images/plant-002.jpeg
ADDED
experimental/vision001.py
ADDED
@@ -0,0 +1,213 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import json
+import os
+import time
+
+import numpy as np
+import requests
+import torch
+
+from clip_app_client import ClipAppClient
+from clip_retrieval.clip_client import ClipClient, Modality
+clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
+map_clip_to_clip_retreval = {
+    "ViT-L/14": "laion5B-L-14",
+}
+
+
+def safe_url(url):
+    import urllib.parse
+    url = urllib.parse.quote(url, safe=':/')
+    # if the url contains '.jpg' (possibly twice), keep everything up to the first one
+    if url.count('.jpg') > 0:
+        url = url.split('.jpg')[0] + '.jpg'
+    return url
+
+# test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
+test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")
+
+app_client = ClipAppClient()
+clip_retrieval_client = ClipClient(
+    url=clip_retrieval_service_url,
+    indice_name=map_clip_to_clip_retreval[app_client.clip_model],
+    # use_safety_model = False,
+    # use_violence_detector = False,
+    # use_mclip = False,
+    num_images = 300,
+    # modality = Modality.TEXT,
+    # modality = Modality.TEXT,
+)
+preprocessed_image = app_client.preprocess_image(test_image_path)
+preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
+print (f"embeddings: {preprocessed_image_embeddings.shape}")
+
+embedding_as_list = preprocessed_image_embeddings[0].tolist()
+results = clip_retrieval_client.query(embedding_input=embedding_as_list)
+
+# hints = ""
+# for result in results:
+#     url = safe_url(result["url"])
+#     similarty = float("{:.4f}".format(result["similarity"]))
+#     title = result["caption"]
+#     print (f"{similarty} \"{title}\" {url}")
+#     if len(hints) > 0:
+#         hints += f", \"{title}\""
+#     else:
+#         hints += f"\"{title}\""
+# print("---")
+# print(hints)
+
+image_labels = [r['caption'] for r in results]
+image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
+image_label_vectors = torch.cat(image_label_vectors, dim=0)
+dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
+similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
+similarity_image_label.sort(reverse=True)
+for similarity, image_label in similarity_image_label:
+    print (f"{similarity} {image_label}")
+
+print (f"----\n")
+
+# now do the same for images
+def _safe_image_url_to_embedding(url, safe_return):
+    try:
+        return app_client.image_url_to_embedding(url)
+    except Exception:
+        return safe_return
+image_urls = [safe_url(r['url']) for r in results]
+image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
+image_vectors = torch.cat(image_vectors, dim=0)
+dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
+similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
+similarity_image.sort(reverse=True)
+for similarity, image_label in similarity_image:
+    print (f"{similarity} {image_label}")
+
+def mean_template(embeddings):
+    template = torch.mean(embeddings, dim=0, keepdim=True)
+    return template
+
+def principal_component_analysis_template(embeddings):
+    mean = torch.mean(embeddings, dim=0)
+    embeddings_centered = embeddings - mean  # Subtract the mean
+    u, s, v = torch.svd(embeddings_centered)  # Perform SVD
+    template = v[:, 0]  # The first right-singular vector is the first principal component
+    return template
+
+def clustering_templates(embeddings, n_clusters=5):
+    from sklearn.cluster import KMeans
+    import numpy as np
+
+    kmeans = KMeans(n_clusters=n_clusters)
+    embeddings_np = embeddings.numpy()  # Convert to numpy
+    clusters = kmeans.fit_predict(embeddings_np)
+
+    templates = []
+    for cluster in np.unique(clusters):
+        cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
+        templates.append(torch.from_numpy(cluster_mean))  # Convert back to tensor
+    return templates
+
+# create templates using clustering
+print("create templates using clustering")
+merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
+clusters = clustering_templates(merged_embeddings, n_clusters=5)
+# convert from list to 2d matrix
+clusters = torch.stack(clusters, dim=0)
+dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
+cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
+cluster_similarity.sort(reverse=True)
+for similarity, idx in cluster_similarity:
+    print (f"{similarity} {idx}")
+# template = highest scoring cluster
+# template = clusters[cluster_similarity[0][1]]
+template = preprocessed_image_embeddings * (len(clusters)-1)
+for i in range(1, len(clusters)):
+    template -= clusters[cluster_similarity[i][1]]
+print("---")
+print("searching based on template")
+results = clip_retrieval_client.query(embedding_input=template[0].tolist())
+hints = ""
+for result in results:
+    url = safe_url(result["url"])
+    similarty = float("{:.4f}".format(result["similarity"]))
+    title = result["caption"]
+    print (f"{similarty} \"{title}\" {url}")
+    if len(hints) > 0:
+        hints += f", \"{title}\""
+    else:
+        hints += f"\"{title}\""
+print(hints)
+
+
+# cluster_num = 1
+# for template in clusters:
+#     print("---")
+#     print(f"cluster {cluster_num} of {len(clusters)}")
+#     results = clip_retrieval_client.query(embedding_input=template.tolist())
+#     hints = ""
+#     for result in results:
+#         url = safe_url(result["url"])
+#         similarty = float("{:.4f}".format(result["similarity"]))
+#         title = result["caption"]
+#         print (f"{similarty} \"{title}\" {url}")
+#         if len(hints) > 0:
+#             hints += f", \"{title}\""
+#         else:
+#             hints += f"\"{title}\""
+#     print(hints)
+#     cluster_num += 1
+
+
+# create a template
+# mean
+# image_label_template = mean_template(image_label_vectors)
+# image_template = mean_template(image_vectors)
+# pca
+# image_label_template = principal_component_analysis_template(image_label_vectors)
+# image_template = principal_component_analysis_template(image_vectors)
+# clustering
+# image_label_template = clustering_template(image_label_vectors)
+# image_template = clustering_template(image_vectors)
+
+# take the embedding and subtract the template
+# image_label_template = preprocessed_image_embeddings - image_label_template
+# image_template = preprocessed_image_embeddings - image_template
+# image_label_template = image_label_template - preprocessed_image_embeddings
+# image_template = image_template - preprocessed_image_embeddings
+# normalize
+# image_label_template = image_label_template / image_label_template.norm()
+# image_template = image_template / image_template.norm()
+
+# results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
+# hints = ""
+# print("---")
+# print("average of image labels")
+# for result in results:
+#     url = safe_url(result["url"])
+#     similarty = float("{:.4f}".format(result["similarity"]))
+#     title = result["caption"]
+#     print (f"{similarty} \"{title}\" {url}")
+#     if len(hints) > 0:
+#         hints += f", \"{title}\""
+#     else:
+#         hints += f"\"{title}\""
+# print(hints)
+
+# print("---")
+# print("average of images")
+# results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
+# hints = ""
+# for result in results:
+#     url = safe_url(result["url"])
+#     similarty = float("{:.4f}".format(result["similarity"]))
+#     title = result["caption"]
+#     print (f"{similarty} \"{title}\" {url}")
+#     if len(hints) > 0:
+#         hints += f", \"{title}\""
+#     else:
+#         hints += f"\"{title}\""
+# print(hints)
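The step that distinguishes vision001.py from a plain nearest-neighbour lookup is how the final query template is built: the query embedding is scaled by the number of non-best clusters and every cluster centroid except the best-matching one is subtracted from it. A self-contained sketch of that construction follows; the function name contrastive_template and the toy dimensions are illustrative, not part of the script.

import torch

def contrastive_template(query_embedding, cluster_centroids):
    # query_embedding: (1, d); cluster_centroids: (k, d)
    # Rank centroids by dot product with the query, keep the best one implicitly,
    # and subtract the remaining k-1 centroids from a (k-1)-scaled copy of the query.
    sims = cluster_centroids @ query_embedding.T            # (k, 1)
    order = torch.argsort(sims[:, 0], descending=True)      # best centroid first
    k = cluster_centroids.shape[0]
    template = query_embedding * (k - 1)
    for idx in order[1:]:                                    # skip the best cluster
        template = template - cluster_centroids[idx]
    return template

# toy usage with random stand-ins for CLIP embeddings
query = torch.randn(1, 768)
centroids = torch.randn(5, 768)
print(contrastive_template(query, centroids).shape)         # torch.Size([1, 768])

The intent appears to be to push the query away from the retrieved neighbourhoods that do not match the image rather than toward the best one; the commented-out alternatives (mean, PCA, highest-scoring cluster) are other template choices the script experiments with.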
experimental/vision002.py
ADDED
@@ -0,0 +1,210 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import json
+import os
+import time
+
+import numpy as np
+import requests
+import torch
+
+from clip_app_client import ClipAppClient
+from clip_retrieval.clip_client import ClipClient, Modality
+clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
+map_clip_to_clip_retreval = {
+    "ViT-L/14": "laion5B-L-14",
+}
+
+
+def safe_url(url):
+    import urllib.parse
+    url = urllib.parse.quote(url, safe=':/')
+    # if the url contains '.jpg' (possibly twice), keep everything up to the first one
+    if url.count('.jpg') > 0:
+        url = url.split('.jpg')[0] + '.jpg'
+    return url
+
+def _safe_image_url_to_embedding(url, safe_return):
+    try:
+        return app_client.image_url_to_embedding(url)
+    except Exception:
+        return safe_return
+
+def mean_template(embeddings):
+    template = torch.mean(embeddings, dim=0, keepdim=True)
+    return template
+
+def principal_component_analysis_template(embeddings):
+    mean = torch.mean(embeddings, dim=0)
+    embeddings_centered = embeddings - mean  # Subtract the mean
+    u, s, v = torch.svd(embeddings_centered)  # Perform SVD
+    template = v[:, 0]  # The first right-singular vector is the first principal component
+    return template
+
+def clustering_templates(embeddings, n_clusters=5):
+    from sklearn.cluster import KMeans
+    import numpy as np
+
+    kmeans = KMeans(n_clusters=n_clusters)
+    embeddings_np = embeddings.numpy()  # Convert to numpy
+    clusters = kmeans.fit_predict(embeddings_np)
+
+    templates = []
+    for cluster in np.unique(clusters):
+        cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
+        templates.append(torch.from_numpy(cluster_mean))  # Convert back to tensor
+    return templates
+
+# test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
+test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
+# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")
+
+app_client = ClipAppClient()
+clip_retrieval_client = ClipClient(
+    url=clip_retrieval_service_url,
+    indice_name=map_clip_to_clip_retreval[app_client.clip_model],
+    # use_safety_model = False,
+    # use_violence_detector = False,
+    # use_mclip = False,
+    # num_images = 300,
+    # modality = Modality.TEXT,
+    # modality = Modality.TEXT,
+)
+preprocessed_image = app_client.preprocess_image(test_image_path)
+preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
+
+print (f"embeddings: {preprocessed_image_embeddings.shape}")
+
+
+template = preprocessed_image_embeddings
+for step_num in range(3):
+    print (f"\n\n---- Step {step_num} ----")
+
+    embedding_as_list = template[0].tolist()
+    results = clip_retrieval_client.query(embedding_input=embedding_as_list)
+
+    # get best matching labels
+    image_labels = [r['caption'] for r in results]
+    image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
+    image_label_vectors = torch.cat(image_label_vectors, dim=0)
+    dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
+    similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
+    similarity_image_label.sort(reverse=True)
+    for similarity, image_label in similarity_image_label:
+        print (f"{similarity} {image_label}")
+
+    # now do the same for images
+    image_urls = [safe_url(r['url']) for r in results]
+    image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
+    image_vectors = torch.cat(image_vectors, dim=0)
+    dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
+    similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
+    similarity_image.sort(reverse=True)
+    for similarity, image_label in similarity_image:
+        print (f"{similarity} {image_label}")
+    # remove images with low similarity as these will be images that did not load
+    image_vectors = torch.stack([image_vectors[i] for i in range(len(image_vectors)) if float(dot_product[i][0]) > 0.001], dim=0)
+
+    # create templates using clustering
+    print("create templates using clustering")
+    merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
+    # merged_embeddings = image_label_vectors # only use labels
+    # merged_embeddings = image_vectors # only use images
+    clusters = clustering_templates(merged_embeddings, n_clusters=5)
+    # convert from list to 2d matrix
+    clusters = torch.stack(clusters, dim=0)
+    dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
+    cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
+    cluster_similarity.sort(reverse=True)
+    for similarity, idx in cluster_similarity:
+        print (f"{similarity} {idx}")
+    # template = highest scoring cluster
+    # template = clusters[cluster_similarity[0][1]]
+    template = preprocessed_image_embeddings * (len(clusters)-1)
+    for i in range(1, len(clusters)):
+        template -= clusters[cluster_similarity[i][1]]
+    print("---")
+    print("searching based on template")
+    results = clip_retrieval_client.query(embedding_input=template[0].tolist())
+    hints = ""
+    for result in results:
+        url = safe_url(result["url"])
+        similarty = float("{:.4f}".format(result["similarity"]))
+        title = result["caption"]
+        print (f"{similarty} \"{title}\" {url}")
+        if len(hints) > 0:
+            hints += f", \"{title}\""
+        else:
+            hints += f"\"{title}\""
+    print(hints)
+
+
+# cluster_num = 1
+# for template in clusters:
+#     print("---")
+#     print(f"cluster {cluster_num} of {len(clusters)}")
+#     results = clip_retrieval_client.query(embedding_input=template.tolist())
+#     hints = ""
+#     for result in results:
+#         url = safe_url(result["url"])
+#         similarty = float("{:.4f}".format(result["similarity"]))
+#         title = result["caption"]
+#         print (f"{similarty} \"{title}\" {url}")
+#         if len(hints) > 0:
+#             hints += f", \"{title}\""
+#         else:
+#             hints += f"\"{title}\""
+#     print(hints)
+#     cluster_num += 1
+
+
+# create a template
+# mean
+# image_label_template = mean_template(image_label_vectors)
+# image_template = mean_template(image_vectors)
+# pca
+# image_label_template = principal_component_analysis_template(image_label_vectors)
+# image_template = principal_component_analysis_template(image_vectors)
+# clustering
+# image_label_template = clustering_template(image_label_vectors)
+# image_template = clustering_template(image_vectors)
+
+# take the embedding and subtract the template
+# image_label_template = preprocessed_image_embeddings - image_label_template
+# image_template = preprocessed_image_embeddings - image_template
+# image_label_template = image_label_template - preprocessed_image_embeddings
+# image_template = image_template - preprocessed_image_embeddings
+# normalize
+# image_label_template = image_label_template / image_label_template.norm()
+# image_template = image_template / image_template.norm()
+
+# results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
+# hints = ""
+# print("---")
+# print("average of image labels")
+# for result in results:
+#     url = safe_url(result["url"])
+#     similarty = float("{:.4f}".format(result["similarity"]))
+#     title = result["caption"]
+#     print (f"{similarty} \"{title}\" {url}")
+#     if len(hints) > 0:
+#         hints += f", \"{title}\""
+#     else:
+#         hints += f"\"{title}\""
+# print(hints)
+
+# print("---")
+# print("average of images")
+# results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
+# hints = ""
+# for result in results:
+#     url = safe_url(result["url"])
+#     similarty = float("{:.4f}".format(result["similarity"]))
+#     title = result["caption"]
+#     print (f"{similarty} \"{title}\" {url}")
+#     if len(hints) > 0:
+#         hints += f", \"{title}\""
+#     else:
+#         hints += f"\"{title}\""
+# print(hints)
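vision002.py wraps the same retrieve, embed, cluster, and subtract cycle in a three-step loop so that the template produced by one round seeds the query of the next, and it drops the zero embeddings left behind by failed image downloads before clustering. A stripped-down sketch of that control flow is shown below; query_index, embed_results, and build_template are placeholder callables, not functions from the script.

import torch

def iterative_template_search(image_embedding, query_index, embed_results,
                              build_template, steps=3, threshold=0.001):
    # image_embedding: (1, d) CLIP embedding of the local image.
    # query_index(list) -> retrieval results; embed_results(results) -> (n, d) tensor
    # with zero rows for items that failed to download; build_template -> (1, d).
    template = image_embedding
    for step_num in range(steps):
        results = query_index(template[0].tolist())
        neighbour_vectors = embed_results(results)
        # drop rows with ~zero similarity to the query (failed downloads)
        sims = neighbour_vectors @ image_embedding.T          # (n, 1)
        neighbour_vectors = neighbour_vectors[sims[:, 0] > threshold]
        # rebuild the template from the query embedding and its neighbours
        template = build_template(image_embedding, neighbour_vectors)
    return template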