File size: 2,536 Bytes
e352c58
 
64e5dc2
3bb084b
 
8dc799b
68c4602
21e2e72
4858eb6
64e5dc2
 
68c4602
3bb084b
 
 
 
 
 
 
 
 
 
 
 
21e2e72
64e5dc2
21e2e72
64e5dc2
21e2e72
64e5dc2
21e2e72
 
3bb084b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26ef06
21e2e72
e352c58
64e5dc2
21e2e72
f68f995
21e2e72
3bb084b
70a751e
 
21e2e72
68c4602
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np 

# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Mapeamento de classe ID para rótulo
id2label = {
    "0": "dyed-lifted-polyps",
    "1": "dyed-resection-margins",
    "2": "esophagitis",
    "3": "normal-cecum",
    "4": "normal-pylorus",
    "5": "normal-z-line",
    "6": "polyps",
    "7": "ulcerative-colitis"
}

# Função para classificar a imagem
def classify_image(input_image):
    # Pré-processar a imagem usando o extrator de características
    inputs = feature_extractor(input_image, return_tensors="pt")
    # Realizar inferência com o modelo
    outputs = model(**inputs)
    # Obter a classe prevista
    predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
    # Converter o ID da classe em rótulo usando o mapeamento id2label
    predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
    
    # Abrir a imagem usando PIL
    image = Image.fromarray(input_image.astype('uint8'))
    
    # Obter as dimensões da imagem
    width, height = image.size
    
    # Criar uma imagem com a faixa branca na parte inferior
    result_image = Image.new('RGB', (width, height + 40), color='white')
    result_image.paste(image, (0, 0))
    
    # Adicionar o rótulo da previsão na faixa branca
    draw = ImageDraw.Draw(result_image)
    font = ImageFont.load_default()
    text = f'Previsão: {predicted_class_label}'
    text_width, text_height = draw.textsize(text, font=font)
    x = (width - text_width) // 2
    y = height + 10
    draw.text((x, y), text, fill='black', font=font)
    
    # Converter a imagem resultante de volta para numpy
    result_image = np.array(result_image)
    return result_image

# Criar uma interface Gradio
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
    outputs=gr.outputs.Image(type="numpy", label="Resultado"),
    title="Classificador de Imagem ViT",
    description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
)

# Iniciar a aplicação Gradio
interface.launch()