"""Gradio demo: run a GLPDepth checkpoint on an uploaded image.

The model output (key ``'pred_d'``) is rendered as a colour-mapped plot with a
colour bar and returned to the UI as an RGB image.
"""

from collections import OrderedDict

import gradio as gr
# No longer used by predict(); retained in case other code relies on it.
import matplotlib.pyplot as plt  # noqa: F401
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms

from models.model import GLPDepth

# load model
DEVICE = 'cpu'

# Prefix that torch.nn.DataParallel-style wrappers put in front of every
# parameter name when the wrapped model's state dict is saved.
_PARALLEL_PREFIX = 'module.'


def load_mde_model(path: str) -> torch.nn.Module:
    """Build a GLPDepth model and load its weights from a checkpoint file.

    Args:
        path: Path to a checkpoint readable by ``torch.load`` that contains a
            ``'model_state_dict'`` entry.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    raw_state = checkpoint['model_state_dict']

    # Strip the wrapper prefix only from keys that actually start with it.
    # (A substring test on the first key would also match names such as
    # "encoder.module.weight" and then cut 7 characters off every key.)
    state_dict = OrderedDict(
        (
            key[len(_PARALLEL_PREFIX):] if key.startswith(_PARALLEL_PREFIX) else key,
            value,
        )
        for key, value in raw_state.items()
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_mde_model('best_model.ckpt')

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


def _render_depth_map(depth: np.ndarray) -> np.ndarray:
    """Draw ``depth`` with the 'jet' colour map plus a colour bar.

    The figure is built with the object-oriented Matplotlib API on an Agg
    canvas instead of ``pyplot``: nothing is registered in pyplot's global
    figure list, so figures do not pile up across requests and the function
    does not depend on pyplot state shared between server threads.

    Args:
        depth: 2-D array of predicted values.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        figure as RGB pixels.
    """
    fig = Figure()
    canvas = FigureCanvasAgg(fig)
    ax = fig.subplots()
    im = ax.imshow(depth, cmap='jet', vmin=0, vmax=np.max(depth))
    fig.colorbar(im, ax=ax)
    canvas.draw()

    # buffer_rgba() exposes the renderer's (height, width, 4) pixel buffer.
    # Drop alpha and copy so the result does not alias the canvas memory.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[..., :3].copy()


def predict(input_image: np.ndarray) -> np.ndarray:
    """Run the model on one image and return the rendered prediction.

    Args:
        input_image: Image array as delivered by the Gradio ``Image`` input.

    Returns:
        RGB ``uint8`` array containing the colour-mapped plot of the model's
        ``'pred_d'`` output.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Let Pillow infer the mode from the array shape, then normalise to RGB so
    # grayscale or RGBA arrays are handled too.
    pil_image = Image.fromarray(np.asarray(input_image).astype('uint8')).convert('RGB')

    # transform image to torch and do preprocessing
    torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0)

    # model predict
    with torch.no_grad():
        output_patch = model(torch_img)

    # transform torch to image
    predicted_image = output_patch['pred_d'].squeeze().cpu().numpy()

    # return correct image
    return _render_depth_map(predicted_image)


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(512, 512)),
    outputs=gr.Image(shape=(512, 512)),
    examples=[
        ["demo_imgs/fake.jpg"]  # use real image
    ],
    title="DTM Estimation",
    description="This demo predict a DTM...",
)

iface.launch()