EmbedDior / embedding.py
HadrienCr's picture
Add examples
03830c7
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
def get_device() -> str:
    """Return the name of the torch device to run on.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): MPS selection is left disabled, exactly as in the original.
    # device = "mps" if torch.mps.is_available() else device
    return device
def get_model_and_processor(model_name: str, device: str):
    """Load a zero-shot image-classification model and its processor.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained``
            calls.
        device: Device the loaded model is moved to with ``model.to(device)``.

    Returns:
        A ``(model, processor)`` tuple. Note the order: the model comes first
        and has already been moved to ``device``; the processor is returned
        as loaded.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)
    return model.to(device), processor
def to_embedding(images: np.ndarray, processor, model, device: str):
    """Compute image embeddings for a batch of images.

    Args:
        images: Images forwarded to the processor as its ``images=`` argument.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")``; its result must
            expose a ``"pixel_values"`` entry.
        model: Object exposing ``get_image_features``; it is called with the
            pixel values as its single positional argument.
        device: Device the pixel values are moved to before the forward pass.
            Presumably the same device the model lives on -- confirm at the
            call site.

    Returns:
        The output of ``model.get_image_features`` moved to the CPU and
        converted to a NumPy array.
    """
    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]
    # Gradients are not needed for feature extraction; disable tracking only
    # around the forward pass.
    with torch.no_grad():
        features = model.get_image_features(pixel_values.to(device))
    return features.cpu().numpy()
def embed(batch, processor, model, device: str, remove_background: bool):
    """Add an embeddings entry to ``batch`` and return the batch.

    Args:
        batch: Mapping that holds the input images under ``"image"`` (or
            ``"image_no_bg"`` when ``remove_background`` is true). It is
            modified in place.
        processor: Forwarded to ``to_embedding``.
        model: Forwarded to ``to_embedding``.
        device: Forwarded to ``to_embedding``.
        remove_background: Selects the ``"_no_bg"`` variant of both the input
            key and the output key.

    Returns:
        The same ``batch`` object, with the embeddings stored under
        ``"embeddings"`` or ``"embeddings_no_bg"``.
    """
    # One suffix drives both keys, so the input and output columns always pair up.
    suffix = "_no_bg" if remove_background else ""
    images = np.array(batch["image" + suffix])
    batch["embeddings" + suffix] = to_embedding(images, processor, model, device)
    return batch