# Spaces:
# Runtime error
| import torch | |
| import numpy as np | |
| import urllib | |
| from PIL import Image | |
| from torchvision import transforms | |
def load_model():
    """Load a pretrained DeepLabV3-ResNet101 model and its input transform.

    The model is fetched through ``torch.hub`` from the pinned
    ``pytorch/vision:v0.6.0`` release, switched to eval mode, and moved to
    CUDA when a GPU is available.

    Returns:
        tuple: ``(model, preprocess)`` where ``preprocess`` turns a PIL RGB
        image into a float tensor normalized with the per-channel mean/std
        below.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model.eval()  # inference only: freezes batch-norm statistics / dropout
    # Per-channel normalization constants (the usual ImageNet values used by
    # torchvision's pretrained models).
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    if torch.cuda.is_available():
        model.to('cuda')
    return model, preprocess
def remove_background(img, model, preprocess, threshold=0.98):
    """Return an RGBA copy of *img* whose background pixels are transparent.

    Args:
        img: PIL image in any mode. It is converted to RGB for the model and
            is NOT modified (the previous version called ``putalpha`` on it).
        model: segmentation model returning a dict whose ``'out'`` entry holds
            per-class logits for the batch (as produced by ``load_model``).
        preprocess: transform mapping a PIL RGB image to a normalized tensor.
        threshold: a pixel is treated as background when the softmax
            probability of class 0 exceeds this value (default 0.98).

    Returns:
        PIL.Image.Image: RGBA image with alpha 0 on background pixels and
        255 everywhere else.
    """
    # Work on a converted copy: handles L/P/LA/RGBA inputs (RGBA would feed 4
    # channels into a 3-channel Normalize) and leaves the caller's image intact.
    rgb = img.convert('RGB')
    input_batch = preprocess(rgb)[None, ...]  # add batch dimension
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
    with torch.no_grad():
        logits = model(input_batch)['out'][0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    # Class 0 is treated as background (index 0 of the label set used by
    # torchvision's pretrained DeepLabV3).
    background = (probs[0] > threshold).cpu().numpy()
    rgba = np.array(rgb.convert('RGBA'))
    rgba[..., 3] = np.where(background, 0, 255)
    return Image.fromarray(rgba.astype('uint8'))
| import os | |
def main(path_in="/localhome/mta122/PycharmProjects/logo_ai/final_nocherry_score/one/DRAGON/G"):
    """Remove the background from every image in *path_in* using DeepLabV3.

    Each readable image ``name.ext`` is written next to its source as
    ``name_result_deeplab.png``. Sub-directories, files PIL cannot open, and
    outputs left by a previous run are skipped.

    Args:
        path_in: directory of input images; defaults to the original
            hard-coded dataset directory.
    """
    suffix = '_result_deeplab.png'
    model, preprocess = load_model()
    for fpath_file in sorted(os.listdir(path_in)):
        fpath = os.path.join(path_in, fpath_file)
        # Skip directories and results written by an earlier run.
        if not os.path.isfile(fpath) or fpath_file.endswith(suffix):
            continue
        try:
            with Image.open(fpath) as src:
                # convert() works for every PIL mode (L, P, LA, RGBA, ...).
                # The old check `img.size[-1] > 3` tested the image HEIGHT,
                # not the channel count, and crashed on 2-D (L/P) images.
                # convert() also loads the pixels, so the file can be closed.
                img = src.convert('RGB')
        except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
            print(f'skipping {fpath}: {err}')
            continue
        result = remove_background(img, model, preprocess)
        # splitext strips only the extension; split('.') broke on any other dot.
        result.save(os.path.splitext(fpath)[0] + suffix)
    print('finished')
# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()