Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
|
|
|
|
4 |
import numpy as np
|
5 |
|
6 |
# Carregue o modelo ViT
|
@@ -8,6 +10,18 @@ model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
|
8 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
9 |
model = ViTForImageClassification.from_pretrained(model_name)
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# Função para classificar a imagem
|
12 |
def classify_image(input_image):
|
13 |
# Pré-processar a imagem usando o extrator de características
|
@@ -16,19 +30,31 @@ def classify_image(input_image):
|
|
16 |
outputs = model(**inputs)
|
17 |
# Obter a classe prevista
|
18 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
19 |
-
#
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Criar uma interface Gradio
|
34 |
interface = gr.Interface(
|
@@ -36,7 +62,7 @@ interface = gr.Interface(
|
|
36 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
37 |
outputs=gr.outputs.Image(type="numpy", label="Resultado"),
|
38 |
title="Classificador de Imagem ViT",
|
39 |
-
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT).
|
40 |
)
|
41 |
|
42 |
# Iniciar a aplicação Gradio
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
import io
|
6 |
import numpy as np
|
7 |
|
8 |
# Carregue o modelo ViT
|
|
|
10 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
11 |
model = ViTForImageClassification.from_pretrained(model_name)
|
12 |
|
13 |
+
# Mapeamento de classe ID para rótulo
|
14 |
+
id2label = {
|
15 |
+
"0": "dyed-lifted-polyps",
|
16 |
+
"1": "dyed-resection-margins",
|
17 |
+
"2": "esophagitis",
|
18 |
+
"3": "normal-cecum",
|
19 |
+
"4": "normal-pylorus",
|
20 |
+
"5": "normal-z-line",
|
21 |
+
"6": "polyps",
|
22 |
+
"7": "ulcerative-colitis"
|
23 |
+
}
|
24 |
+
|
25 |
# Função para classificar a imagem
|
26 |
def classify_image(input_image):
|
27 |
# Pré-processar a imagem usando o extrator de características
|
|
|
30 |
outputs = model(**inputs)
|
31 |
# Obter a classe prevista
|
32 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
33 |
+
# Converter o ID da classe em rótulo usando o mapeamento id2label
|
34 |
+
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
35 |
+
|
36 |
+
# Abrir a imagem usando PIL
|
37 |
+
image = Image.fromarray(input_image.astype('uint8'))
|
38 |
+
|
39 |
+
# Obter as dimensões da imagem
|
40 |
+
width, height = image.size
|
41 |
+
|
42 |
+
# Criar uma imagem com a faixa branca na parte inferior
|
43 |
+
result_image = Image.new('RGB', (width, height + 40), color='white')
|
44 |
+
result_image.paste(image, (0, 0))
|
45 |
+
|
46 |
+
# Adicionar o rótulo da previsão na faixa branca
|
47 |
+
draw = ImageDraw.Draw(result_image)
|
48 |
+
font = ImageFont.load_default()
|
49 |
+
text = f'Previsão: {predicted_class_label}'
|
50 |
+
text_width, text_height = draw.textsize(text, font=font)
|
51 |
+
x = (width - text_width) // 2
|
52 |
+
y = height + 10
|
53 |
+
draw.text((x, y), text, fill='black', font=font)
|
54 |
+
|
55 |
+
# Converter a imagem resultante de volta para numpy
|
56 |
+
result_image = np.array(result_image)
|
57 |
+
return result_image
|
58 |
|
59 |
# Criar uma interface Gradio
|
60 |
interface = gr.Interface(
|
|
|
62 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
63 |
outputs=gr.outputs.Image(type="numpy", label="Resultado"),
|
64 |
title="Classificador de Imagem ViT",
|
65 |
+
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
|
66 |
)
|
67 |
|
68 |
# Iniciar a aplicação Gradio
|