DHEIVER commited on
Commit
51e2247
·
1 Parent(s): ad05d92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -40
app.py CHANGED
@@ -1,13 +1,6 @@
1
  import gradio as gr
2
- import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
- import cv2
5
- import numpy as np
6
-
7
- # Carregue o modelo ViT
8
- model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
9
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
10
- model = ViTForImageClassification.from_pretrained(model_name)
11
 
12
  # Mapeamento de classe ID para rótulo
13
  id2label = {
@@ -21,52 +14,30 @@ id2label = {
21
  "7": "ulcerative-colitis"
22
  }
23
 
 
 
 
 
 
24
  # Função para classificar a imagem
25
  def classify_image(input_image):
26
- # Redimensionar a imagem de entrada para ser 2x maior
27
- input_image = cv2.resize(input_image, None, fx=2, fy=2)
28
-
29
  # Pré-processar a imagem usando o extrator de características
30
  inputs = feature_extractor(input_image, return_tensors="pt")
31
  # Realizar inferência com o modelo
32
  outputs = model(**inputs)
33
  # Obter a classe prevista
34
- predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
35
- # Obter a probabilidade da classe prevista
36
- predicted_class_prob = torch.softmax(outputs.logits, dim=1)[0, predicted_class_id].item()
37
- # Converter o ID da classe em rótulo usando o mapeamento id2label
38
  predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
39
-
40
- # Converter a imagem de numpy para BGR (formato OpenCV)
41
- input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
42
-
43
- # Definir cores de borda para cada classe (aqui, cores aleatórias)
44
- class_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
45
- (255, 0, 255), (0, 255, 255), (128, 128, 128), (0, 0, 0)]
46
-
47
- # Adicionar uma borda colorida à imagem
48
- border_color = class_colors[predicted_class_id]
49
- input_image_bgr = cv2.copyMakeBorder(input_image_bgr, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_color)
50
-
51
- # Adicionar o rótulo da previsão na imagem
52
- font = cv2.FONT_HERSHEY_SIMPLEX
53
- text = f'Classe: {predicted_class_label} ({predicted_class_prob:.2f})'
54
- text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
55
- text_x = (input_image_bgr.shape[1] - text_size[0]) // 2
56
- text_y = input_image_bgr.shape[0] - 30 # Ajuste da posição vertical
57
- cv2.putText(input_image_bgr, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
58
-
59
- # Converter a imagem resultante de volta para RGB (formato Pillow)
60
- result_image = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
61
- return result_image
62
 
63
  # Criar uma interface Gradio
64
  interface = gr.Interface(
65
  fn=classify_image,
66
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
67
- outputs=gr.outputs.Image(type="numpy", label="Resultado"),
68
  title="Classificador de Imagem ViT",
69
- description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na imagem de saída."
70
  )
71
 
72
  # Iniciar a aplicação Gradio
 
1
  import gradio as gr
 
2
  from transformers import ViTFeatureExtractor, ViTForImageClassification
3
+ import numpy as np
 
 
 
 
 
 
4
 
5
  # Mapeamento de classe ID para rótulo
6
  id2label = {
 
14
  "7": "ulcerative-colitis"
15
  }
16
 
17
+ # Carregue o modelo ViT
18
+ model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
19
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
20
+ model = ViTForImageClassification.from_pretrained(model_name)
21
+
22
  # Função para classificar a imagem
23
  def classify_image(input_image):
 
 
 
24
  # Pré-processar a imagem usando o extrator de características
25
  inputs = feature_extractor(input_image, return_tensors="pt")
26
  # Realizar inferência com o modelo
27
  outputs = model(**inputs)
28
  # Obter a classe prevista
29
+ predicted_class_id = np.argmax(outputs.logits)
30
+ # Obter o rótulo da classe a partir do mapeamento id2label
 
 
31
  predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
32
+ return predicted_class_label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Criar uma interface Gradio
35
  interface = gr.Interface(
36
  fn=classify_image,
37
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
38
+ outputs=gr.outputs.Label(num_top_classes=1),
39
  title="Classificador de Imagem ViT",
40
+ description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT).",
41
  )
42
 
43
  # Iniciar a aplicação Gradio