File size: 805 Bytes
8c753d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gradio as gr
from PIL import Image
import torchvision
import torch

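# TODO: load the DTM estimation model here; predict() currently uses random noise as a stand-in for the model output.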

def predict(input_image):
    """Run DTM estimation on one image and return the result as a PIL image.

    Args:
        input_image: Image array delivered by the ``gr.Image`` input component
            (presumably an H x W x C numpy array -- TODO confirm against the
            Gradio version in use), or ``None`` when no image was supplied.

    Returns:
        A ``PIL.Image.Image`` built from the prediction tensor, or ``None``
        when ``input_image`` is ``None``.

    Note:
        The model is currently a placeholder: the prediction is uniform random
        noise with the same shape as the preprocessed input tensor.
    """
    # Gradio passes None when the user submits without an image; returning
    # None clears the output instead of raising AttributeError on .astype.
    if input_image is None:
        return None
    # Let PIL infer the mode from the array shape and then normalise to RGB.
    # Forcing mode='RGB' fails or misreads the buffer for arrays that are not
    # exactly 3-channel (e.g. 2-D grayscale or 4-channel RGBA).
    pil_image = Image.fromarray(input_image.astype("uint8")).convert("RGB")
    # PIL image -> torch tensor (channels-first) for model preprocessing.
    torch_image = torchvision.transforms.ToTensor()(pil_image)
    # Placeholder for the model forward pass.
    # TODO: replace with the real model prediction.
    prediction = torch.rand(torch_image.shape)
    # Tensor -> PIL image so Gradio can display it.
    return torchvision.transforms.ToPILImage()(prediction)

# Gradio UI: one image in, one image out, wired to predict().
iface = gr.Interface(
    fn=predict,
    # NOTE(review): `shape` is a Gradio 3.x argument (crop/resize of the
    # uploaded image before it reaches predict); it was removed in Gradio 4
    # and likely has no effect on an output component -- confirm the pinned
    # Gradio version.
    inputs=gr.Image(shape=(512, 512)),
    outputs=gr.Image(shape=(512, 512)),
    examples=[
        # TODO: replace this placeholder with a real sample image.
        ["demo_imgs/fake.jpg"],
    ],
    title="DTM Estimation",
    description="This demo predicts a DTM...",
)

# Start the server only when the file is executed as a script, so importing
# this module does not block on launching a web server.
if __name__ == "__main__":
    iface.launch()