File size: 933 Bytes
21bf7d6
 
 
 
 
 
1513566
21bf7d6
1513566
21bf7d6
 
 
 
 
 
 
 
 
 
 
 
1513566
 
 
 
21bf7d6
 
1513566
21bf7d6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import gradio as gr
from PIL import Image
import torchvision
import torch

# Model names offered in the UI selector. No model weights are loaded here;
# predict() does not yet use the selected name.
MODELS_TYPE = ["ModelA", "ModelB", "ModelC"]

def predict(input_image, model_name):
    """Run DTM estimation on an uploaded image and return the result as an image.

    NOTE(review): the "model" below is a placeholder -- the prediction is
    random noise with the same shape as the preprocessed input, and
    ``model_name`` is not used yet. Wire the selected model in here.

    Args:
        input_image: Image array from the Gradio ``Image`` component, or
            ``None`` when the user submits without uploading an image.
            Presumably an (H, W) or (H, W, C) numeric array -- confirm
            against the Gradio version in use.
        model_name: Name chosen in the model selector (currently unused).

    Returns:
        A ``PIL.Image.Image`` holding the prediction, or ``None`` when no
        input image was provided.
    """
    if input_image is None:
        # Nothing uploaded: return an empty output instead of failing on
        # ``None.astype``.
        return None
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode='RGB' misreads grayscale (H, W) and RGBA (H, W, 4) arrays,
    # and passing ``mode`` to ``Image.fromarray`` is deprecated.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')
    # PIL image -> float tensor (C, H, W) with values in [0, 1].
    torch_image = torchvision.transforms.ToTensor()(pil_image)
    # Placeholder model: uniform random values in [0, 1) with the input shape.
    prediction = torch.rand(torch_image.shape)
    # Float tensor in [0, 1] -> PIL image for display.
    predicted_pil_image = torchvision.transforms.ToPILImage()(prediction)
    return predicted_pil_image

# Build the Gradio UI: an image plus a model selector in, a predicted image out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        # With Gradio 3.x, ``shape`` crops/resizes the upload before predict().
        gr.Image(shape=(512, 512)),
        # Use the top-level ``gr.Radio`` component, consistent with ``gr.Image``
        # (``gr.inputs.*`` is the deprecated legacy namespace, removed in
        # Gradio 4). Pre-select the first model so predict() never receives
        # model_name=None.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0], label="Model"),
    ],
    outputs=gr.Image(shape=(512, 512)),
    examples=[
        # TODO: replace with a real sample image.
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]],
    ],
    title="DTM Estimation",
    description="This demo predicts a DTM...",
)
iface.launch()