|
import torch |
|
import numpy as np |
|
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification |
|
|
|
|
|
def get_device():
    """Pick the torch device string to run inference on.

    Returns:
        ``"cuda"`` if ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
def get_model_and_processor(model_name: str, device: str):
    """Load a zero-shot image-classification checkpoint and its processor.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
        device: Device the model is moved to (e.g. the value of ``get_device()``).

    Returns:
        A ``(model, processor)`` tuple. The model is already on *device*, and
        the processor is loaded from the same checkpoint.
    """
    checkpoint_processor = AutoProcessor.from_pretrained(model_name)
    checkpoint_model = AutoModelForZeroShotImageClassification.from_pretrained(
        model_name
    )
    checkpoint_model = checkpoint_model.to(device)
    return checkpoint_model, checkpoint_processor
|
|
|
|
|
def to_embedding(images: np.ndarray, processor, model, device: str):
    """Encode a batch of images into image-feature vectors.

    Args:
        images: Images in any form ``processor(images=...)`` accepts.
        processor: Callable returning a mapping with a ``"pixel_values"``
            torch tensor when called with ``return_tensors="pt"``.
        model: Object exposing ``get_image_features(pixel_values)``.
        device: Device the pixel values are moved to before the forward pass.

    Returns:
        The image features as a NumPy array on the CPU.
    """
    encoded = processor(images=images, return_tensors="pt")
    pixel_values = encoded["pixel_values"].to(device)
    # Inference only: disabling autograd lets the result go straight to
    # ``.numpy()`` (which refuses tensors that require grad) and saves memory.
    with torch.no_grad():
        features = model.get_image_features(pixel_values)
        result = features.cpu().numpy()
    return result
|
|
|
|
|
def embed(batch, processor, model, device: str, remove_background: bool):
    """Embed the images of *batch* and store the result back on it.

    Reads ``batch["image_no_bg"]`` when *remove_background* is true, otherwise
    ``batch["image"]``. The images are encoded with ``to_embedding``, and the
    result is written to ``batch["embeddings_no_bg"]`` or
    ``batch["embeddings"]`` respectively.

    Args:
        batch: Mapping of column name to a batch of values. Presumably a
            batched ``datasets.map`` batch — TODO confirm against caller.
        processor: Image processor forwarded to ``to_embedding``.
        model: Model forwarded to ``to_embedding``.
        device: Device forwarded to ``to_embedding``.
        remove_background: Selects the background-removed image column and
            the ``_no_bg`` output key.

    Returns:
        The same *batch* object, mutated in place.
    """
    suffix = "_no_bg" if remove_background else ""
    source_key = "image_no_bg" if remove_background else "image"
    # Convert each image on its own instead of stacking the whole batch with
    # np.array: stacking raises "inhomogeneous shape" when images differ in
    # size. The processor accepts a list of images and resizes each itself.
    images = [np.asarray(image) for image in batch[source_key]]
    batch["embeddings" + suffix] = to_embedding(images, processor, model, device)
    return batch
|
|