# Page header captured with this file (author: marta-0, commit 6da6215,
# message "add files", size 1.8 kB) -- kept as a comment so the module parses.
import requests
import gradio as gr
import paddle
from paddleseg.cvlibs import Config
from matting.core import predict
from matting.model import *
from matting.dataset import MattingDataset
def download_file(http_address, file_name, timeout=60, chunk_size=1 << 20):
    """Download ``http_address`` and write the response body to ``file_name``.

    The body is streamed to disk in chunks, so large weight files are never
    held fully in memory. Both the HTTP response and the output file are
    closed deterministically.

    Args:
        http_address: URL to fetch; redirects are followed.
        file_name: Local path to write; an existing file is overwritten.
        timeout: Seconds to wait for the server to connect / send data before
            giving up (applies per read while streaming, not to the total
            download time).
        chunk_size: Number of bytes requested per chunk written to disk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so an
            error page is never saved as if it were the requested file.
        requests.RequestException: On connection failures or timeouts.
    """
    with requests.get(http_address, allow_redirects=True, stream=True,
                      timeout=timeout) as response:
        # Fail loudly here instead of writing an HTML error page that would
        # only surface later as an obscure weight-loading error.
        response.raise_for_status()
        with open(file_name, 'wb') as out_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                out_file.write(chunk)
# One row per selectable model: (PaddleSeg config file, pretrained weights
# file). The row order is the model index used by `inference` and must match
# the order of the choices in the "Model" radio button.
_MODEL_SPECS = [
    ('configs/modnet/modnet_mobilenetv2.yml', 'modnet-mobilenetv2.pdparams'),
    ('configs/modnet/modnet_resnet50_vd.yml', 'modnet-resnet50_vd.pdparams'),
    ('configs/modnet/modnet_hrnet_w18.yml', 'modnet-hrnet_w18.pdparams'),
]
# Remote directory holding every weights file listed above.
_WEIGHTS_BASE_URL = 'https://paddleseg.bj.bcebos.com/matting/models/'

cfgs = [config_file for config_file, _ in _MODEL_SPECS]
models_paths = [weights_file for _, weights_file in _MODEL_SPECS]

# Fetch the pretrained weights into the working directory, under the same
# names `inference` later loads them from (done at import time, in order).
for weights_file in models_paths:
    download_file(_WEIGHTS_BASE_URL + weights_file, weights_file)
def inference(image, chosen_model):
    """Run the selected MODNet matting model on a single image.

    A fresh ``Config`` is built from ``cfgs[chosen_model]`` on every call, and
    the model and validation transforms are taken from it. The call is then
    delegated to ``predict`` together with the matching weights file from
    ``models_paths``.

    Args:
        image: The image delivered by the Gradio "Input Image" component. It
            is wrapped in a one-element list for ``predict`` (presumably a
            numpy array -- TODO confirm ``predict`` accepts it directly).
        chosen_model: Integer index into ``cfgs`` / ``models_paths``. The
            radio input uses ``type='index'``, so it passes the position of
            the selected choice.

    Returns:
        Whatever ``predict`` returns; the original named it ``alpha_pred``,
        presumably the alpha matte -- TODO confirm against ``matting.core``.
    """
    # Inference is always forced onto the CPU.
    paddle.set_device('cpu')
    config = Config(cfgs[chosen_model])
    # Validation dataset is read before the model, as in the original flow.
    preprocessing = config.val_dataset.transforms
    network = config.model
    weights_path = models_paths[chosen_model]
    return predict(network,
                   model_path=weights_path,
                   transforms=preprocessing,
                   image_list=[image])
# Input widgets, in the positional order `inference(image, chosen_model)`
# expects. The radio uses type='index' so the handler receives the position
# of the chosen model rather than its label.
inputs = [
    gr.inputs.Image(label='Input Image'),
    gr.inputs.Radio(['MobileNetV2', 'ResNet50_vd', 'HRNet_W18'],
                    label='Model', type='index'),
]

# Example rows shown under the demo: (image path, model label).
_examples = [
    ['images/armchair.jpg', 'MobileNetV2'],
    ['images/cat.jpg', 'ResNet50_vd'],
    ['images/plant.jpg', 'HRNet_W18'],
]

demo = gr.Interface(
    inference,
    inputs,
    gr.outputs.Image(label='Output'),
    title='PaddleSeg - Matting',
    examples=_examples,
)
demo.launch()