# remove_bg_api / handler.py
# Uploaded by whlzy via huggingface_hub ("Upload folder using huggingface_hub",
# commit f9cfd2a, verified; 1.55 kB).
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Dict, List, Any
import base64
from io import BytesIO
import os
# Use the default CUDA device when PyTorch can see a GPU; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler():
    """Inference Endpoint handler that removes an image's background.

    The segmentation model ``whlzy/remove_bg_api`` is loaded once at
    construction. Each call predicts a foreground mask for one image and
    attaches it to that image as the alpha channel.
    """

    def __init__(self, path=""):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            path: Model directory supplied by the endpoint runtime. It is not
                used here; the model is always loaded from the Hub repo
                ``whlzy/remove_bg_api``.
        """
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'whlzy/remove_bg_api',
            # The repo provides its own modelling code, which must be trusted to load.
            trust_remote_code=True,
            # Optional access token; None when the env var is unset.
            token=os.environ.get("HUGGINGFACE_TOKEN"),
        )
        self.model.to(device)
        self.model.eval()
        image_size = (1024, 1024)
        self.transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std (3 entries each), so the
            # tensor must have exactly 3 channels when it reaches this step.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Any) -> Image.Image:
        """Remove the background from one image.

        Args:
            data: Request payload supporting ``.pop``. The image is popped from
                ``data["inputs"]``. If that key is absent, ``data`` itself is
                used. The image may be a PIL image or base64 text/bytes of an
                encoded image file.

        Returns:
            An RGBA copy of the input image whose alpha channel is the
            predicted foreground mask, resized to the image's size.
        """
        image = data.pop("inputs", data)
        if not isinstance(image, Image.Image):
            image = self.decode_base64_image(image)
        # Normalize expects 3 channels. Grayscale, palette and RGBA uploads
        # would otherwise fail there. convert() also returns a copy, so the
        # caller's image is never mutated by putalpha below.
        image = image.convert("RGB")
        # Inputs must live on the same device as the model (CPU fallback included).
        input_images = self.transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # [-1] selects the last entry of the model's output sequence
            # (presumably the final/finest mask; confirm against the model code).
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image.size)
        image.putalpha(mask)
        return image

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image into a PIL image.

        Args:
            image_string: Base64 payload (``str`` or ``bytes``) of an image
                file in any format Pillow can open. A ``data:`` URL such as
                ``"data:image/png;base64,<payload>"`` is also accepted.

        Returns:
            The decoded ``PIL.Image.Image`` with its pixel data loaded.

        Raises:
            binascii.Error: if the payload is not valid base64.
            PIL.UnidentifiedImageError: if the decoded bytes are not an image.
        """
        if isinstance(image_string, str) and image_string.startswith("data:"):
            # Drop the data-URL header; b64decode would otherwise keep some of
            # its characters (e.g. "/" and letters) and corrupt the bytes.
            image_string = image_string.partition(",")[2]
        buffer = BytesIO(base64.b64decode(image_string))
        image = Image.open(buffer)
        # Image.open is lazy; load now so corrupt data fails here rather
        # than later inside the transform pipeline.
        image.load()
        return image