Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
-
import
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
# Carregue o modelo ViT
|
8 |
-
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
9 |
-
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
10 |
-
model = ViTForImageClassification.from_pretrained(model_name)
|
11 |
|
12 |
# Mapeamento de classe ID para rótulo
|
13 |
id2label = {
|
@@ -21,52 +14,30 @@ id2label = {
|
|
21 |
"7": "ulcerative-colitis"
|
22 |
}
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
# Função para classificar a imagem
|
25 |
def classify_image(input_image):
|
26 |
-
# Redimensionar a imagem de entrada para ser 2x maior
|
27 |
-
input_image = cv2.resize(input_image, None, fx=2, fy=2)
|
28 |
-
|
29 |
# Pré-processar a imagem usando o extrator de características
|
30 |
inputs = feature_extractor(input_image, return_tensors="pt")
|
31 |
# Realizar inferência com o modelo
|
32 |
outputs = model(**inputs)
|
33 |
# Obter a classe prevista
|
34 |
-
predicted_class_id =
|
35 |
-
# Obter
|
36 |
-
predicted_class_prob = torch.softmax(outputs.logits, dim=1)[0, predicted_class_id].item()
|
37 |
-
# Converter o ID da classe em rótulo usando o mapeamento id2label
|
38 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
39 |
-
|
40 |
-
# Converter a imagem de numpy para BGR (formato OpenCV)
|
41 |
-
input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
|
42 |
-
|
43 |
-
# Definir cores de borda para cada classe (aqui, cores aleatórias)
|
44 |
-
class_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
|
45 |
-
(255, 0, 255), (0, 255, 255), (128, 128, 128), (0, 0, 0)]
|
46 |
-
|
47 |
-
# Adicionar uma borda colorida à imagem
|
48 |
-
border_color = class_colors[predicted_class_id]
|
49 |
-
input_image_bgr = cv2.copyMakeBorder(input_image_bgr, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_color)
|
50 |
-
|
51 |
-
# Adicionar o rótulo da previsão na imagem
|
52 |
-
font = cv2.FONT_HERSHEY_SIMPLEX
|
53 |
-
text = f'Classe: {predicted_class_label} ({predicted_class_prob:.2f})'
|
54 |
-
text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
|
55 |
-
text_x = (input_image_bgr.shape[1] - text_size[0]) // 2
|
56 |
-
text_y = input_image_bgr.shape[0] - 30 # Ajuste da posição vertical
|
57 |
-
cv2.putText(input_image_bgr, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
58 |
-
|
59 |
-
# Converter a imagem resultante de volta para RGB (formato Pillow)
|
60 |
-
result_image = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
|
61 |
-
return result_image
|
62 |
|
63 |
# Criar uma interface Gradio
|
64 |
interface = gr.Interface(
|
65 |
fn=classify_image,
|
66 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
67 |
-
outputs=gr.outputs.
|
68 |
title="Classificador de Imagem ViT",
|
69 |
-
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT).
|
70 |
)
|
71 |
|
72 |
# Iniciar a aplicação Gradio
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
3 |
+
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Mapeamento de classe ID para rótulo
|
6 |
id2label = {
|
|
|
14 |
"7": "ulcerative-colitis"
|
15 |
}
|
16 |
|
17 |
+
# Carregue o modelo ViT
|
18 |
+
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
19 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
20 |
+
model = ViTForImageClassification.from_pretrained(model_name)
|
21 |
+
|
22 |
# Função para classificar a imagem
|
23 |
def classify_image(input_image):
|
|
|
|
|
|
|
24 |
# Pré-processar a imagem usando o extrator de características
|
25 |
inputs = feature_extractor(input_image, return_tensors="pt")
|
26 |
# Realizar inferência com o modelo
|
27 |
outputs = model(**inputs)
|
28 |
# Obter a classe prevista
|
29 |
+
predicted_class_id = np.argmax(outputs.logits)
|
30 |
+
# Obter o rótulo da classe a partir do mapeamento id2label
|
|
|
|
|
31 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
32 |
+
return predicted_class_label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
# Criar uma interface Gradio
|
35 |
interface = gr.Interface(
|
36 |
fn=classify_image,
|
37 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
38 |
+
outputs=gr.outputs.Label(num_top_classes=1),
|
39 |
title="Classificador de Imagem ViT",
|
40 |
+
description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT).",
|
41 |
)
|
42 |
|
43 |
# Iniciar a aplicação Gradio
|