image-retrieval / utils /get_embeddings.py
nampham1106's picture
first commit
ab9b7a8
raw
history blame
870 Bytes
import os
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
# Run on the GPU whenever PyTorch reports a usable CUDA device, else on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessing pipeline applied to every image before it reaches the network:
# resize to a fixed 224x224 and convert the PIL image into a tensor.
# NOTE(review): no mean/std normalization is applied here, although the
# pretrained ImageNet weights are presumably trained on normalized inputs --
# confirm whether this is intentional before changing it, since any change
# alters every embedding produced with this pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def preprocess_image(image_path):
    """Load an image from disk and turn it into a model-ready tensor.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The result of applying the module-level ``transform`` to the image
        after converting it to RGB.
    """
    # Open through a context manager so the underlying file handle is always
    # released, even for multi-frame formats where Pillow would otherwise keep
    # the file open after loading. ``convert`` returns an independent image,
    # so it stays usable once the source file has been closed.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')
    return transform(rgb_img)
def create_resnet18_model():
    """Build a ResNet-18 feature extractor from ImageNet-pretrained weights.

    Returns:
        An ``nn.Sequential`` made of every top-level child of torchvision's
        ResNet-18 except the last one, so the network emits features rather
        than the output of its final layer. The model is returned as built;
        it is not moved to any particular device here.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Keep all top-level children but the final one.
    *feature_layers, _ = backbone.children()
    return nn.Sequential(*feature_layers)
def extract_features(model, processed_image):
    """Run a single preprocessed image through ``model`` and return its features.

    Args:
        model: A ``torch.nn.Module`` used as the feature extractor. It is
            switched to evaluation mode as a side effect.
        processed_image: A tensor for one image without a batch dimension;
            a leading batch dimension of size 1 is added before the forward
            pass.

    Returns:
        The model output with all size-1 dimensions squeezed out, converted
        to plain Python numbers / nested lists via ``Tensor.tolist()``.
    """
    # Put the input on the same device as the model's weights so the forward
    # pass cannot fail with a device mismatch when the caller left the model
    # on CPU while CUDA is available. Parameter-free models fall back to the
    # module-level ``device``.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    batch = processed_image.unsqueeze(dim=0).to(target_device)
    model.eval()
    with torch.inference_mode():
        features = model(batch)
    return features.squeeze().tolist()