Annas Dev
try vit
8f93744
raw
history blame
1.05 kB
from src.model import simlarity_model as model
from src.util import image as image_util
from src.util import matrix
from .model_implements.mobilenet_v3 import ModelnetV3
from .model_implements.vit_base import VitBase
class Similarity:
    """Compare images by the cosine distance between their model features."""

    # Side length (pixels) of the square input used by every registered model.
    IMAGE_SIZE = 224

    def get_models(self):
        """Return the similarity models available for comparison.

        A new list with freshly constructed ``ModelnetV3`` / ``VitBase``
        instances is built on every call.
        NOTE(review): if constructing these models loads weights, callers may
        want to cache this result — confirm the cost before calling per request.
        """
        return [
            model.SimilarityModel(
                name='Mobilenet V3',
                image_size=self.IMAGE_SIZE,
                model_cls=ModelnetV3(),
            ),
            model.SimilarityModel(
                name='Vision Transformer',
                image_size=self.IMAGE_SIZE,
                model_cls=VitBase(),
            ),
        ]

    def check_similarity(self, img_urls, model, *, return_distances=False):
        """Load images, extract features, and compare each to the first image.

        Args:
            img_urls: iterable of image URLs. Falsy entries (``""``, ``None``)
                are skipped. The first remaining image is the reference.
            model: a model descriptor such as those returned by
                ``get_models()``. It must expose ``image_size`` and
                ``model_cls.extract_feature``. Inside this method the parameter
                shadows the module alias ``model`` imported at file level.
            return_distances: keyword-only. When True, return the list of
                ``matrix.cosine(reference, other)`` values, one per
                non-reference image, in input order.

        Returns:
            ``'oke'`` by default (the historical return value). When
            ``return_distances`` is True, the list of distances (empty if
            fewer than two images were loaded).
        """
        size = (model.image_size, model.image_size)
        # Skip blank/missing URLs instead of handing them to the loader.
        imgs = [
            image_util.load_image_url(url, required_size=size)
            for url in img_urls
            if url
        ]

        distances = []
        if imgs:
            # Only call the extractor with at least one image; its behaviour on
            # an empty batch is not established here.
            features = model.model_cls.extract_feature(imgs)
            reference = features[0]
            distances = [matrix.cosine(reference, other) for other in features[1:]]

        if return_distances:
            return distances
        return 'oke'