DHEIVER commited on
Commit
3bb084b
·
1 Parent(s): d7a8d78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
 
 
4
  import numpy as np
5
 
6
  # Carregue o modelo ViT
@@ -8,6 +10,18 @@ model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
8
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Função para classificar a imagem
12
  def classify_image(input_image):
13
  # Pré-processar a imagem usando o extrator de características
@@ -16,19 +30,31 @@ def classify_image(input_image):
16
  outputs = model(**inputs)
17
  # Obter a classe prevista
18
  predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
19
- # Mapeamento de classe ID para rótulo
20
- id2label = {
21
- 0: "dyed-lifted-polyps",
22
- 1: "dyed-resection-margins",
23
- 2: "esophagitis",
24
- 3: "normal-cecum",
25
- 4: "normal-pylorus",
26
- 5: "normal-z-line",
27
- 6: "polyps",
28
- 7: "ulcerative-colitis"
29
- }
30
- predicted_class_label = id2label.get(predicted_class_id, "Desconhecido")
31
- return input_image
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Criar uma interface Gradio
34
  interface = gr.Interface(
@@ -36,7 +62,7 @@ interface = gr.Interface(
36
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
37
  outputs=gr.outputs.Image(type="numpy", label="Resultado"),
38
  title="Classificador de Imagem ViT",
39
- description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). Nenhuma sobreposição de texto é aplicada."
40
  )
41
 
42
  # Iniciar a aplicação Gradio
 
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import io
6
  import numpy as np
7
 
8
  # Carregue o modelo ViT
 
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
11
  model = ViTForImageClassification.from_pretrained(model_name)
12
 
13
+ # Mapeamento de classe ID para rótulo
14
+ id2label = {
15
+ "0": "dyed-lifted-polyps",
16
+ "1": "dyed-resection-margins",
17
+ "2": "esophagitis",
18
+ "3": "normal-cecum",
19
+ "4": "normal-pylorus",
20
+ "5": "normal-z-line",
21
+ "6": "polyps",
22
+ "7": "ulcerative-colitis"
23
+ }
24
+
25
  # Função para classificar a imagem
26
  def classify_image(input_image):
27
  # Pré-processar a imagem usando o extrator de características
 
30
  outputs = model(**inputs)
31
  # Obter a classe prevista
32
  predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
33
+ # Converter o ID da classe em rótulo usando o mapeamento id2label
34
+ predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
35
+
36
+ # Abrir a imagem usando PIL
37
+ image = Image.fromarray(input_image.astype('uint8'))
38
+
39
+ # Obter as dimensões da imagem
40
+ width, height = image.size
41
+
42
+ # Criar uma imagem com a faixa branca na parte inferior
43
+ result_image = Image.new('RGB', (width, height + 40), color='white')
44
+ result_image.paste(image, (0, 0))
45
+
46
+ # Adicionar o rótulo da previsão na faixa branca
47
+ draw = ImageDraw.Draw(result_image)
48
+ font = ImageFont.load_default()
49
+ text = f'Previsão: {predicted_class_label}'
50
+ text_width, text_height = draw.textsize(text, font=font)
51
+ x = (width - text_width) // 2
52
+ y = height + 10
53
+ draw.text((x, y), text, fill='black', font=font)
54
+
55
+ # Converter a imagem resultante de volta para numpy
56
+ result_image = np.array(result_image)
57
+ return result_image
58
 
59
  # Criar uma interface Gradio
60
  interface = gr.Interface(
 
62
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
63
  outputs=gr.outputs.Image(type="numpy", label="Resultado"),
64
  title="Classificador de Imagem ViT",
65
+ description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
66
  )
67
 
68
  # Iniciar a aplicação Gradio