DHEIVER's picture
Update app.py
98d76db
raw
history blame
2.33 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Mapeamento de classe ID para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem
def classify_image(input_image):
# Pré-processar a imagem usando o extrator de características
inputs = feature_extractor(input_image, return_tensors="pt")
# Realizar inferência com o modelo
outputs = model(**inputs)
# Obter a classe prevista
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
# Converter o ID da classe em rótulo usando o mapeamento id2label
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
# Abrir a imagem usando PIL
image = Image.fromarray(input_image.astype('uint8'))
# Criar uma imagem com o rótulo de previsão sobreposta no centro
draw = ImageDraw.Draw(image)
width, height = image.size
font = ImageFont.load_default()
text = f'Previsão: {predicted_class_label}'
# Obter o tamanho do texto
text_width, text_height = draw.textsize(text, font=font)
x = (width - text_width) // 2
y = (height - text_height) // 2
draw.text((x, y), text, fill='white', font=font)
# Converter a imagem resultante de volta para numpy
result_image = np.array(image)
return result_image
# ...
# Criar uma interface Gradio
interface = gr.Interface(
fn=classify_image,
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
outputs=gr.outputs.Image(type="numpy", label="Previsão"), # Definir o tipo de saída como "numpy"
title="Classificador de Imagem ViT",
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
)
# Iniciar a aplicação Gradio
interface.launch()