DHEIVER's picture
Update app.py
bc44ddb
raw
history blame
1.93 kB
import gradio as gr
import torch
import cv2
from transformers import ViTFeatureExtractor, ViTForImageClassification
import numpy as np
# Carregar o modelo
modelo = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
extrator = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# Dicionário de mapeamento de índice de classe para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem e adicionar rótulo de previsão
def classificar_imagem(image):
# Realizar a inferência usando o modelo
inputs = extrator(image, return_tensors="pt")
outputs = modelo(**inputs)
logits = outputs.logits
# Obter a classe prevista
classe_prevista = torch.argmax(logits, dim=1)
rotulo_previsto = id2label[str(classe_prevista.item())]
# Adicionar o rótulo de previsão à imagem
image_with_text = image.copy()
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_color = (255, 255, 255)
font_thickness = 1
text_position = (10, 30)
cv2.putText(image_with_text, f"Previsto: {rotulo_previsto}", text_position, font, font_scale, font_color, font_thickness)
return image_with_text
# Configurar a interface Gradio com uma altura de saída maior
iface = gr.Interface(
fn=classificar_imagem, # Esta função é chamada quando uma imagem é carregada na interface
inputs="image",
outputs="image",
title="Classificação de Imagem com ViT",
description="Carregue uma imagem e obtenha a imagem de entrada com o rótulo de previsão.",
output_height=600 # Defina a altura desejada para a saída
)
# Lançar a interface Gradio sem a opção 'live'
iface.launch(share=False)