import numpy as np
import cv2
import gradio as gr
import torch
from ade20k_colors import colors
from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

beit_models = ['microsoft/beit-base-finetuned-ade-640-640',
               'microsoft/beit-large-finetuned-ade-640-640']

# Load both checkpoints up front so switching models in the UI is instant.
models = [BeitForSemanticSegmentation.from_pretrained(m) for m in beit_models]
extractors = [BeitFeatureExtractor.from_pretrained(m) for m in beit_models]

def apply_colors(img):
    # Map per-pixel class scores (H, W, num_classes) to the ADE20K palette
    # via the argmax over the class axis. Vectorized equivalent of a
    # per-pixel loop over colors[np.argmax(img[y, x])].
    palette = np.asarray(colors, dtype=np.uint8)
    return palette[np.argmax(img, axis=-1)]

def inference(image, chosen_model):
    feature_extractor = extractors[chosen_model]
    model = models[chosen_model]
    inputs = feature_extractor(images=image, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # sigmoid is monotonic, so it does not change the argmax; it only bounds
    # the scores. Drop the batch axis and move channels last for numpy.
    output = torch.sigmoid(logits).numpy()[0]
    output = np.transpose(output, (1, 2, 0))
    output = apply_colors(output)
    # The model predicts at reduced resolution; resize back to the input
    # size (cv2.resize expects (width, height), hence the reversed slice).
    return cv2.resize(output, image.shape[1::-1])

# Legacy gr.inputs / gr.outputs API (removed in Gradio 4). type='index'
# passes the Radio selection as an integer index into models/extractors.
inputs = [gr.inputs.Image(label='Input Image'),
          gr.inputs.Radio(['Base', 'Large'], label='BEiT Model', type='index')]

gr.Interface(
    inference,
    inputs,
    gr.outputs.Image(label='Output'),
    title='BEiT - Semantic Segmentation',
    description='BEIT: BERT Pre-Training of Image Transformers',
    examples=[['images/armchair.jpg', 'Base'],
              ['images/cat.jpg', 'Base'],
              ['images/plant.jpg', 'Large']]
).launch()
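
# ---------------------------------------------------------------------------
# A minimal sketch of exercising inference() without the web UI, assuming a
# test image exists at images/armchair.jpg (one of the example paths above).
# cv2.imread returns BGR, so the channels are flipped to RGB first, matching
# the RGB arrays Gradio would pass in:
#
#   img = cv2.cvtColor(cv2.imread('images/armchair.jpg'), cv2.COLOR_BGR2RGB)
#   seg = inference(img, 0)  # 0 selects the 'Base' model (Radio type='index')
#   cv2.imwrite('out.png', cv2.cvtColor(seg, cv2.COLOR_RGB2BGR))
# ---------------------------------------------------------------------------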