import gradio as gr import torch import cv2 from transformers import ViTFeatureExtractor, ViTForImageClassification import numpy as np # Carregar o modelo modelo = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k") extrator = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") # Dicionário de mapeamento de índice de classe para rótulo id2label = { "0": "dyed-lifted-polyps", "1": "dyed-resection-margins", "2": "esophagitis", "3": "normal-cecum", "4": "normal-pylorus", "5": "normal-z-line", "6": "polyps", "7": "ulcerative-colitis" } # Função para classificar a imagem e adicionar rótulo de previsão def classificar_imagem(image): # Realizar a inferência usando o modelo inputs = extrator(image, return_tensors="pt") outputs = modelo(**inputs) logits = outputs.logits # Obter a classe prevista classe_prevista = torch.argmax(logits, dim=1) rotulo_previsto = id2label[str(classe_prevista.item())] # Adicionar o rótulo de previsão à imagem image_with_text = image.copy() font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 font_color = (255, 255, 255) font_thickness = 1 text_position = (10, 30) cv2.putText(image_with_text, f"Previsto: {rotulo_previsto}", text_position, font, font_scale, font_color, font_thickness) return image_with_text # Configurar a interface Gradio com uma altura de saída maior iface = gr.Interface( fn=classificar_imagem, # Esta função é chamada quando uma imagem é carregada na interface inputs="image", outputs="image", title="Classificação de Imagem com ViT", description="Carregue uma imagem e obtenha a imagem de entrada com o rótulo de previsão.", output_height=600 # Defina a altura desejada para a saída ) # Lançar a interface Gradio sem a opção 'live' iface.launch(share=False)