File size: 2,326 Bytes
e352c58
 
64e5dc2
24fc76f
 
8dc799b
68c4602
21e2e72
4858eb6
64e5dc2
 
68c4602
21e2e72
 
 
 
 
 
 
 
 
 
 
 
 
64e5dc2
21e2e72
64e5dc2
21e2e72
64e5dc2
21e2e72
 
 
 
845f19f
 
daf5ee2
 
 
845f19f
daf5ee2
 
 
 
 
 
 
 
 
845f19f
 
 
c26ef06
21e2e72
e352c58
64e5dc2
21e2e72
0666586
21e2e72
 
70a751e
 
21e2e72
68c4602
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import io
import numpy as np 

# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Mapeamento de classe ID para rótulo
id2label = {
    "0": "dyed-lifted-polyps",
    "1": "dyed-resection-margins",
    "2": "esophagitis",
    "3": "normal-cecum",
    "4": "normal-pylorus",
    "5": "normal-z-line",
    "6": "polyps",
    "7": "ulcerative-colitis"
}

# Função para classificar a imagem
def classify_image(input_image):
    # Pré-processar a imagem usando o extrator de características
    inputs = feature_extractor(input_image, return_tensors="pt")
    # Realizar inferência com o modelo
    outputs = model(**inputs)
    # Obter a classe prevista
    predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
    # Converter o ID da classe em rótulo usando o mapeamento id2label
    predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
    # Abrir a imagem usando PIL
    image = Image.fromarray(input_image.astype('uint8'))
    
    # Calcular as coordenadas para o texto no centro da imagem
    width, height = image.size
    font = ImageFont.load_default()
    text = f'Previsão: {predicted_class_label}'
    text_width, text_height = draw.textsize(text, font)
    x = (width - text_width) // 2
    y = (height - text_height) // 2
    
    # Criar uma imagem com o rótulo de previsão sobreposta no centro
    draw = ImageDraw.Draw(image)
    draw.text((x, y), text, fill='white', font=font)
    
    # Converter a imagem resultante de volta para numpy
    result_image = np.array(image)
    return result_image

# Criar uma interface Gradio
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
    outputs=gr.outputs.Image(type="numpy", label="Previsão"),  # Definir o tipo de saída como "numpy"
    title="Classificador de Imagem ViT",
    description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
)

# Iniciar a aplicação Gradio
interface.launch()