Spaces:
Sleeping
Sleeping
File size: 3,981 Bytes
bf601e4 fb964ec 0976e91 bf601e4 d1c29b6 3237429 cb54c63 bf601e4 7751dfb bf601e4 7751dfb 56744a5 7751dfb 7026019 bf601e4 3237429 9758099 3237429 bf601e4 7751dfb d1c29b6 0976e91 aad7f4e 0976e91 bf601e4 3237429 bf601e4 3237429 0976e91 bf601e4 0976e91 bf601e4 7751dfb bf601e4 334cac1 a5218e8 3237429 7751dfb d427ed3 7751dfb bf601e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import os
import random
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
# examples
def _fetch_example_images():
    """Download the demo's example images into the working directory with wget.

    Each image is saved under its bare file name (e.g. ``073.png``), which is
    how the Gradio ``examples`` list refers to it. The two remote folders
    ("buenos_resultados" / "malos_resultados" -- Spanish for "good results" /
    "bad results") are fetched in that order. The shell exit status of each
    download is ignored.
    """
    base_url = (
        "https://huggingface.co/spaces/alkzar90/rock-glacier-segmentation"
        "/resolve/main/example_images"
    )
    groups = (
        ("buenos_resultados", ("073", "356", "599", "630", "673")),
        ("malos_resultados", ("019", "261", "524", "716", "898")),
    )
    for folder, stems in groups:
        for stem in stems:
            os.system(f"wget -O {stem}.png {base_url}/{folder}/{stem}.png")


_fetch_example_images()
# model-setting
MODEL_PATH = "./best_model_mixto/"
device = torch.device("cpu")

# Input pipeline: torchvision Resize with an int (128) resizes the image before
# ToTensor converts it to a tensor for the model.
preprocessor = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

# Load the fine-tuned SegFormer checkpoint from MODEL_PATH and switch it to
# inference mode; Module.eval() returns the module itself, so it is chained.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).eval()
# inference-functions
def upscale_logits(logit_outputs, size):
    """Resize a batch of logits to ``size`` using bilinear interpolation.

    Args:
        logit_outputs: 4-D tensor, presumably ``(batch, classes, height, width)``
            (bilinear mode in ``interpolate`` requires 4-D input).
        size: target spatial size -- an int or a ``(height, width)`` pair --
            forwarded unchanged to ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor; batch and channel dimensions are left as they are.
    """
    resize_options = {"mode": "bilinear", "align_corners": False}
    return nn.functional.interpolate(logit_outputs, size=size, **resize_options)
def visualize_instance_seg_mask(mask):
    """Colour a label mask by assigning each distinct label a random RGB colour.

    Args:
        mask: integer label array, e.g. the ``(H, W)`` argmax map built in
            ``query_image``. Any shape is accepted; the colour axis is
            appended as the last dimension.

    Returns:
        ``float64`` array of shape ``mask.shape + (3,)`` with values in
        ``[0, 1]``. Pixels that share a label share a colour. Colours come from
        the ``random`` module, so they differ between calls unless it is seeded.
    """
    mask = np.asarray(mask)
    # `inverse` gives, for every pixel, the index of its label in `labels`;
    # one fancy-index lookup then replaces the per-pixel Python loop.
    labels, inverse = np.unique(mask, return_inverse=True)
    # One colour per label, in np.unique order. All three channels span
    # 0..255 (the red channel used to be limited to 0..1, i.e. always dark).
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)  # keeps a (0, 3) shape when the mask is empty
    # np.unique's inverse is flat on some NumPy versions; restore mask's shape.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
def query_image(img):
    """Segment ``img`` and return a colour-coded mask of the predicted classes.

    Args:
        img: the image handed over by the Gradio ``Image(type="pil")`` input
            (presumably a PIL image; anything ``preprocessor`` accepts works).

    Returns:
        RGB float array in ``[0, 1]`` from ``visualize_instance_seg_mask``, at
        the spatial resolution of the preprocessed model input.
    """
    inputs = preprocessor(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(inputs)["logits"]
        # SegFormer emits logits at a lower spatial resolution than its input;
        # resize them to the input tensor's (height, width). Passing
        # logits.shape[2] (the logits' own height) made this a no-op for
        # square inputs and squashed non-square ones into a square.
        logits = upscale_logits(logits, tuple(inputs.shape[-2:]))
        predicted = logits.argmax(dim=1)
    # Take the single image of the batch as a (H, W) label map.
    return visualize_instance_seg_mask(predicted[0].cpu().numpy())
# demo
# NOTE(review): Component.style() is deprecated in later Gradio 3.x releases and
# removed in Gradio 4 -- confirm the pinned gradio version before upgrading.
input_image = gr.Image(type="pil").style(full_width=True, height=256)
output_image = gr.Image().style(full_width=True, height=256)

# Example files fetched at start-up, one single-input example per image.
example_files = [
    "073.png", "356.png", "599.png", "630.png", "673.png",
    "019.png", "261.png", "524.png", "716.png", "898.png",
]

demo = gr.Interface(
    fn=query_image,
    inputs=[input_image],
    outputs=[output_image],
    title="Skyguard: segmentador de glaciares de roca 🛰️ +️ 🛡️ ️",
    description="Modelo de segmentación de imágenes para detectar glaciares de roca.<br> Se entrenó un modelo [nvidia/SegFormer](https://huggingface.co/nvidia/mit-b0) con _fine-tuning_ en el [rock-glacier-dataset](https://huggingface.co/datasets/alkzar90/rock-glacier-dataset)",
    examples=[[file_name] for file_name in example_files],
    cache_examples=False,
)
demo.launch()
|