File size: 2,536 Bytes
e352c58 64e5dc2 3bb084b 8dc799b 68c4602 21e2e72 4858eb6 64e5dc2 68c4602 3bb084b 21e2e72 64e5dc2 21e2e72 64e5dc2 21e2e72 64e5dc2 21e2e72 3bb084b c26ef06 21e2e72 e352c58 64e5dc2 21e2e72 f68f995 21e2e72 3bb084b 70a751e 21e2e72 68c4602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np
# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Mapeamento de classe ID para rótulo
id2label = {
"0": "dyed-lifted-polyps",
"1": "dyed-resection-margins",
"2": "esophagitis",
"3": "normal-cecum",
"4": "normal-pylorus",
"5": "normal-z-line",
"6": "polyps",
"7": "ulcerative-colitis"
}
# Função para classificar a imagem
def classify_image(input_image):
# Pré-processar a imagem usando o extrator de características
inputs = feature_extractor(input_image, return_tensors="pt")
# Realizar inferência com o modelo
outputs = model(**inputs)
# Obter a classe prevista
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
# Converter o ID da classe em rótulo usando o mapeamento id2label
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
# Abrir a imagem usando PIL
image = Image.fromarray(input_image.astype('uint8'))
# Obter as dimensões da imagem
width, height = image.size
# Criar uma imagem com a faixa branca na parte inferior
result_image = Image.new('RGB', (width, height + 40), color='white')
result_image.paste(image, (0, 0))
# Adicionar o rótulo da previsão na faixa branca
draw = ImageDraw.Draw(result_image)
font = ImageFont.load_default()
text = f'Previsão: {predicted_class_label}'
text_width, text_height = draw.textsize(text, font=font)
x = (width - text_width) // 2
y = height + 10
draw.text((x, y), text, fill='black', font=font)
# Converter a imagem resultante de volta para numpy
result_image = np.array(result_image)
return result_image
# Criar uma interface Gradio
interface = gr.Interface(
fn=classify_image,
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
outputs=gr.outputs.Image(type="numpy", label="Resultado"),
title="Classificador de Imagem ViT",
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
)
# Iniciar a aplicação Gradio
interface.launch()
|