DHEIVER's picture
Update app.py
3bb084b
raw
history blame
2.54 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Mapeamento de classe ID para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem
def classify_image(input_image):
# Pré-processar a imagem usando o extrator de características
inputs = feature_extractor(input_image, return_tensors="pt")
# Realizar inferência com o modelo
outputs = model(**inputs)
# Obter a classe prevista
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
# Converter o ID da classe em rótulo usando o mapeamento id2label
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
# Abrir a imagem usando PIL
image = Image.fromarray(input_image.astype('uint8'))
# Obter as dimensões da imagem
width, height = image.size
# Criar uma imagem com a faixa branca na parte inferior
result_image = Image.new('RGB', (width, height + 40), color='white')
result_image.paste(image, (0, 0))
# Adicionar o rótulo da previsão na faixa branca
draw = ImageDraw.Draw(result_image)
font = ImageFont.load_default()
text = f'Previsão: {predicted_class_label}'
text_width, text_height = draw.textsize(text, font=font)
x = (width - text_width) // 2
y = height + 10
draw.text((x, y), text, fill='black', font=font)
# Converter a imagem resultante de volta para numpy
result_image = np.array(result_image)
return result_image
# Criar uma interface Gradio
interface = gr.Interface(
fn=classify_image,
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
outputs=gr.outputs.Image(type="numpy", label="Resultado"),
title="Classificador de Imagem ViT",
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
)
# Iniciar a aplicação Gradio
interface.launch()