Spaces:
Runtime error
Runtime error
File size: 933 Bytes
21bf7d6 1513566 21bf7d6 1513566 21bf7d6 1513566 21bf7d6 1513566 21bf7d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import gradio as gr
from PIL import Image
import torchvision
import torch
# Names offered in the model selector of the UI. No model is actually loaded
# yet: predict() does not use the selected name and returns random output.
MODELS_TYPE = ["ModelA", "ModelB", "ModelC"]
def predict(input_image, model_name):
    """Run the (placeholder) DTM estimator on one uploaded image.

    Args:
        input_image: Image array supplied by the ``gr.Image`` input component,
            or ``None`` when the user submits without uploading an image.
            NOTE(review): assumed to be an H x W (x C) array with values in
            0-255 — confirm against the Gradio version in use.
        model_name: Entry of ``MODELS_TYPE`` chosen in the radio button.
            Currently unused: no model is loaded yet, so every choice
            produces the same kind of placeholder output.

    Returns:
        PIL.Image.Image: The predicted image (currently random noise with the
        same size and channel count as the RGB input).

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running the prediction.")

    # Let Pillow infer the mode from the array and normalise to RGB with
    # ``convert`` instead of forcing ``mode='RGB'``: grayscale (H x W) and
    # RGBA (H x W x 4) arrays are then converted correctly rather than being
    # misread or rejected, and the ``mode`` argument of ``Image.fromarray``
    # (deprecated in recent Pillow releases) is avoided. For H x W x 3 uint8
    # input the result is identical to the previous code.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')

    # transform image to torch: ToTensor yields a float (C, H, W) tensor in [0, 1]
    torch_image = torchvision.transforms.ToTensor()(pil_image)

    # model predict
    # TODO: replace this random placeholder with the model selected by
    # ``model_name``.
    prediction = torch.rand(torch_image.shape)

    # transform torch back to an image; ToPILImage rescales floats to 0-255
    predicted_pil_image = torchvision.transforms.ToPILImage()(prediction)
    return predicted_pil_image
# Build and launch the Gradio UI.
# Uses the top-level ``gr.Radio`` / ``gr.Image`` components, which exist in
# both Gradio 3.x and 4.x. The ``gr.inputs`` namespace and the ``shape=``
# argument of ``gr.Image`` were removed in Gradio 4 and fail at start-up there.
# NOTE: without ``shape=`` inputs are no longer resized to 512x512 by Gradio;
# if the real model needs a fixed size, resize inside predict().
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Input image"),
        # Pre-select the first model so predict() never receives None here.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0], label="Model"),
    ],
    outputs=gr.Image(label="Predicted DTM"),
    examples=[
        # TODO: replace the placeholder file with a real sample image.
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]],
    ],
    title="DTM Estimation",
    description="This demo predicts a DTM...",
)
iface.launch()