|
import gradio as gr |
|
import torch |
|
import cv2 |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
import numpy as np |
|
|
|
|
|
modelo = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k") |
|
extrator = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") |
|
|
|
|
|
id2label = { |
|
"0": "dyed-lifted-polyps", |
|
"1": "dyed-resection-margins", |
|
"2": "esophagitis", |
|
"3": "normal-cecum", |
|
"4": "normal-pylorus", |
|
"5": "normal-z-line", |
|
"6": "polyps", |
|
"7": "ulcerative-colitis" |
|
} |
|
|
|
|
|
def classificar_imagem(image): |
|
|
|
inputs = extrator(image, return_tensors="pt") |
|
outputs = modelo(**inputs) |
|
logits = outputs.logits |
|
|
|
|
|
classe_prevista = torch.argmax(logits, dim=1) |
|
rotulo_previsto = id2label[str(classe_prevista.item())] |
|
|
|
|
|
image_with_text = image.copy() |
|
font = cv2.FONT_HERSHEY_SIMPLEX |
|
font_scale = 0.5 |
|
font_color = (255, 255, 255) |
|
font_thickness = 1 |
|
text_position = (10, 30) |
|
cv2.putText(image_with_text, f"Previsto: {rotulo_previsto}", text_position, font, font_scale, font_color, font_thickness) |
|
|
|
return image_with_text |
|
|
|
|
|
iface = gr.Interface( |
|
fn=classificar_imagem, |
|
inputs="image", |
|
outputs="image", |
|
title="Classificação de Imagem com ViT", |
|
description="Carregue uma imagem e obtenha a imagem de entrada com o rótulo de previsão." |
|
) |
|
|
|
|
|
iface.launch(share=True, live=True, debug=True, app_name="classificador_de_imagem_vit") |
|
|