DHEIVER's picture
Create app.py
70a751e
raw
history blame
1.84 kB
import gradio as gr
import torch
import cv2
from transformers import ViTFeatureExtractor, ViTForImageClassification
import numpy as np
# Carregar o modelo
modelo = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
extrator = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# Dicionário de mapeamento de índice de classe para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem e adicionar rótulo de previsão
def classificar_imagem(image):
# Realizar a inferência usando o modelo
inputs = extrator(image, return_tensors="pt")
outputs = modelo(**inputs)
logits = outputs.logits
# Obter a classe prevista
classe_prevista = torch.argmax(logits, dim=1)
rotulo_previsto = id2label[str(classe_prevista.item())]
# Adicionar o rótulo de previsão à imagem
image_with_text = image.copy()
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_color = (255, 255, 255)
font_thickness = 1
text_position = (10, 30)
cv2.putText(image_with_text, f"Previsto: {rotulo_previsto}", text_position, font, font_scale, font_color, font_thickness)
return image_with_text
# Configurar a interface Gradio
iface = gr.Interface(
fn=classificar_imagem,
inputs="image",
outputs="image",
title="Classificação de Imagem com ViT",
description="Carregue uma imagem e obtenha a imagem de entrada com o rótulo de previsão."
)
# Lançar a interface Gradio com um nome específico para o aplicativo
iface.launch(share=True, live=True, debug=True, app_name="classificador_de_imagem_vit")