Spaces:
Runtime error
Runtime error
File size: 1,915 Bytes
8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 b4eade4 8c753d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from collections import OrderedDict

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image
from PIL import Image
from torchvision import transforms

from models.model import GLPDepth
# Torch device shared by the model weights and every input tensor below.
DEVICE = 'cpu'
def load_mde_model(path):
    """Build a GLPDepth network and load trained weights into it.

    Args:
        path: Path of a checkpoint readable by ``torch.load``. Either a dict
            holding the weights under ``'model_state_dict'`` or a bare
            state dict is accepted.

    Returns:
        The ``GLPDepth`` model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    # NOTE(review): newer torch releases default torch.load to
    # weights_only=True; a checkpoint that pickles non-tensor objects may
    # need weights_only=False (only ever for trusted files) — confirm.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    # Training checkpoints wrap the weights; a bare state dict is used as-is.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    # Weights saved from an nn.DataParallel / DistributedDataParallel wrapper
    # carry a 'module.' prefix on their keys. Strip exactly that prefix, and
    # only from keys that have it. The previous check looked for 'module'
    # anywhere in the first key, then cut 7 chars off *every* key. It also
    # raised StopIteration on an empty dict.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    model.load_state_dict(state_dict)
    model.eval()
    return model
# Build the network once at start-up; every request handled below reuses it.
model = load_mde_model(path='best_model.ckpt')

# Per-request input pipeline: resize to a fixed 512x512, then convert the
# PIL image into a tensor with torchvision's ToTensor.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
    ]
)
def predict(input_image):
    """Run the depth/DTM model on one image and render the result as a plot.

    Args:
        input_image: Image array handed over by the Gradio ``Image`` input
            (presumably HxWxC uint8 RGB — confirm against the component's
            ``image_mode``). Grayscale and RGBA arrays are converted to RGB.

    Returns:
        numpy.ndarray: ``(H, W, 3)`` uint8 RGB rendering of the predicted
        map (``jet`` colour map, scale starting at 0) with a colour bar.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image.")

    # convert('RGB') also accepts 2-D grayscale and RGBA arrays, which the
    # explicit 'RGB' mode of Image.fromarray would reject or misread.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')

    # Shared resize/to-tensor pipeline, plus a leading batch dimension.
    torch_img = preprocess(pil_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(torch_img)
    # squeeze() drops singleton (batch/channel) dims so the map can be drawn
    # as a 2-D image; assumes 'pred_d' holds a single-channel map — confirm.
    depth = output['pred_d'].squeeze().cpu().numpy()

    # Render on a standalone Figure with an Agg canvas rather than pyplot.
    # pyplot keeps every figure alive in global state until it is closed, so
    # the old code leaked one figure per request. buffer_rgba() replaces
    # tostring_rgb(), which Matplotlib deprecated in 3.8 and later removed.
    fig = Figure()
    canvas = FigureCanvasAgg(fig)
    ax = fig.subplots()
    im = ax.imshow(depth, cmap='jet', vmin=0, vmax=np.max(depth))
    fig.colorbar(im, ax=ax)
    canvas.draw()

    rgba = np.asarray(canvas.buffer_rgba())
    # Drop the alpha channel and copy so the result owns its memory.
    return np.ascontiguousarray(rgba[..., :3])
# Gradio UI: one image in, the colour-mapped DTM rendering out.
# Gradio 4 removed the `shape=` argument of gr.Image, so passing it fails at
# start-up. Inputs are already resized to 512x512 by `preprocess`, so the
# components need no fixed shape. On Gradio 3 this means the upload is no
# longer pre-cropped by the component before that resize.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Image(),
    examples=[
        # NOTE(review): original note said "use real image" — presumably
        # this placeholder should be swapped for a real sample.
        ["demo_imgs/fake.jpg"],
    ],
    title="DTM Estimation",
    description="This demo predict a DTM...",
)
iface.launch()