Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
import urllib | |
from PIL import Image | |
from torchvision import transforms | |
def load_model():
    """Load the pretrained DeepLabV3-ResNet101 model and its input transform.

    The model is fetched through ``torch.hub`` from the ``pytorch/vision``
    repository at tag ``v0.6.0``, switched to evaluation mode and moved to
    the GPU when CUDA is available.

    Returns:
        tuple: ``(model, preprocess)``. ``preprocess`` turns a PIL image into
        a tensor and normalizes each channel with the mean/std listed below.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # Per-channel (R, G, B) statistics used to normalize the input tensor.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if torch.cuda.is_available():
        model.to('cuda')
    return model, preprocess
def remove_background(img, model, preprocess, threshold=0.98):
    """Return an RGBA copy of ``img`` with the predicted background transparent.

    The image is run through ``model``; the per-pixel class scores are turned
    into probabilities with a softmax over the class axis, and every pixel
    whose class-0 probability exceeds ``threshold`` gets alpha 0. All other
    pixels get alpha 255. Class index 0 is presumably the model's background
    class -- confirm against the model's label set.

    Args:
        img: PIL image. It is not modified; any mode is accepted because it
            is converted to RGB before being fed to the model.
        model: Callable returning a mapping whose ``'out'`` entry holds the
            class scores with a leading batch axis.
        preprocess: Callable converting an RGB PIL image to a normalized
            tensor without a batch axis.
        threshold: Class-0 probability above which a pixel is made
            transparent. Defaults to 0.98.

    Returns:
        PIL.Image.Image: A new RGBA image of the same size as ``img``.
    """
    # Feed the model a 3-channel image regardless of the input's mode.
    input_batch = preprocess(img.convert('RGB'))[None, ...]
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]
    class_probs = torch.nn.functional.softmax(output, dim=0)
    is_background = (class_probs[0] > threshold).cpu().numpy()

    # Work on an RGBA copy so the caller's image is left untouched.
    result_np = np.array(img.convert('RGBA'))
    result_np[..., 3] = np.where(is_background, 0, 255)
    return Image.fromarray(result_np)
import os | |
def main(path_in="/localhome/mta122/PycharmProjects/logo_ai/final_nocherry_score/one/DRAGON/G"):
    """Remove the background from every image in a directory.

    Each readable image file directly inside ``path_in`` is converted to RGB,
    passed through ``remove_background`` and saved next to the input as
    ``<name>_result_deeplab.png``. Sub-directories, files that cannot be
    opened as images and outputs of an earlier run are skipped.

    Args:
        path_in: Directory containing the input images.

    Raises:
        OSError: If ``path_in`` cannot be listed (for example, it does not
            exist).
    """
    result_suffix = '_result_deeplab.png'

    # List the directory first so a bad path fails before the model is loaded.
    file_names = sorted(os.listdir(path_in))
    model, preprocess = load_model()

    for file_name in file_names:
        # Do not re-process results written by a previous run.
        if file_name.endswith(result_suffix):
            continue
        fpath = os.path.join(path_in, file_name)
        if not os.path.isfile(fpath):
            continue
        try:
            # convert() loads the pixel data and drops any alpha channel, so
            # the file handle can be closed right away.
            with Image.open(fpath) as opened:
                img = opened.convert('RGB')
        except OSError as err:
            print(f'skipping {fpath}: {err}')
            continue
        result = remove_background(img, model, preprocess)
        # splitext only strips the final extension, so dots elsewhere in the
        # path do not truncate the output name.
        result.save(os.path.splitext(fpath)[0] + result_suffix)
    print('finished')
# Run the batch job only when executed as a script, not when imported.
if __name__ == '__main__':
    main()