File size: 2,333 Bytes
e352c58
 
64e5dc2
a0eea21
24fc76f
8dc799b
68c4602
21e2e72
4858eb6
64e5dc2
 
68c4602
21e2e72
 
 
 
 
 
 
 
 
 
 
 
 
64e5dc2
21e2e72
64e5dc2
21e2e72
64e5dc2
21e2e72
 
 
 
845f19f
 
daf5ee2
07c6b10
98d76db
daf5ee2
845f19f
daf5ee2
98d76db
 
 
 
daf5ee2
 
 
 
845f19f
 
 
98d76db
 
c26ef06
21e2e72
e352c58
64e5dc2
21e2e72
0666586
21e2e72
 
70a751e
 
21e2e72
68c4602
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image, ImageDraw, ImageFont
import io
import numpy as np 

# Carregue o modelo ViT
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Mapeamento de classe ID para rótulo
id2label = {
    "0": "dyed-lifted-polyps",
    "1": "dyed-resection-margins",
    "2": "esophagitis",
    "3": "normal-cecum",
    "4": "normal-pylorus",
    "5": "normal-z-line",
    "6": "polyps",
    "7": "ulcerative-colitis"
}

# Função para classificar a imagem
def classify_image(input_image):
    # Pré-processar a imagem usando o extrator de características
    inputs = feature_extractor(input_image, return_tensors="pt")
    # Realizar inferência com o modelo
    outputs = model(**inputs)
    # Obter a classe prevista
    predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
    # Converter o ID da classe em rótulo usando o mapeamento id2label
    predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
    # Abrir a imagem usando PIL
    image = Image.fromarray(input_image.astype('uint8'))
    
    # Criar uma imagem com o rótulo de previsão sobreposta no centro
    draw = ImageDraw.Draw(image)
    width, height = image.size
    font = ImageFont.load_default()
    text = f'Previsão: {predicted_class_label}'
    
    # Obter o tamanho do texto
    text_width, text_height = draw.textsize(text, font=font)
    
    x = (width - text_width) // 2
    y = (height - text_height) // 2
    draw.text((x, y), text, fill='white', font=font)
    
    # Converter a imagem resultante de volta para numpy
    result_image = np.array(image)
    return result_image
# ...


# Criar uma interface Gradio
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
    outputs=gr.outputs.Image(type="numpy", label="Previsão"),  # Definir o tipo de saída como "numpy"
    title="Classificador de Imagem ViT",
    description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
)

# Iniciar a aplicação Gradio
interface.launch()