File size: 2,333 Bytes
e352c58 64e5dc2 a0eea21 24fc76f 8dc799b 68c4602 21e2e72 4858eb6 64e5dc2 68c4602 21e2e72 64e5dc2 21e2e72 64e5dc2 21e2e72 64e5dc2 21e2e72 845f19f daf5ee2 07c6b10 98d76db daf5ee2 845f19f daf5ee2 98d76db daf5ee2 845f19f 98d76db c26ef06 21e2e72 e352c58 64e5dc2 21e2e72 0666586 21e2e72 70a751e 21e2e72 68c4602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Mapeamento de classe ID para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem
def classify_image(input_image):
# Pré-processar a imagem usando o extrator de características
inputs = feature_extractor(input_image, return_tensors="pt")
# Realizar inferência com o modelo
outputs = model(**inputs)
# Obter a classe prevista
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
# Converter o ID da classe em rótulo usando o mapeamento id2label
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
# Abrir a imagem usando PIL
image = Image.fromarray(input_image.astype('uint8'))
# Criar uma imagem com o rótulo de previsão sobreposta no centro
draw = ImageDraw.Draw(image)
width, height = image.size
font = ImageFont.load_default()
text = f'Previsão: {predicted_class_label}'
# Obter o tamanho do texto
text_width, text_height = draw.textsize(text, font=font)
x = (width - text_width) // 2
y = (height - text_height) // 2
draw.text((x, y), text, fill='white', font=font)
# Converter a imagem resultante de volta para numpy
result_image = np.array(image)
return result_image
# ...
# Criar uma interface Gradio
interface = gr.Interface(
fn=classify_image,
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
outputs=gr.outputs.Image(type="numpy", label="Previsão"), # Definir o tipo de saída como "numpy"
title="Classificador de Imagem ViT",
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
)
# Iniciar a aplicação Gradio
interface.launch()
|