DHEIVER commited on
Commit
370dc3f
·
1 Parent(s): 3bb084b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -21
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
- from PIL import Image, ImageDraw, ImageFont
5
- import io
6
  import numpy as np
7
 
8
  # Carregue o modelo ViT
@@ -33,27 +32,19 @@ def classify_image(input_image):
33
  # Converter o ID da classe em rótulo usando o mapeamento id2label
34
  predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
35
 
36
- # Abrir a imagem usando PIL
37
- image = Image.fromarray(input_image.astype('uint8'))
38
 
39
- # Obter as dimensões da imagem
40
- width, height = image.size
41
-
42
- # Criar uma imagem com a faixa branca na parte inferior
43
- result_image = Image.new('RGB', (width, height + 40), color='white')
44
- result_image.paste(image, (0, 0))
45
-
46
- # Adicionar o rótulo da previsão na faixa branca
47
- draw = ImageDraw.Draw(result_image)
48
- font = ImageFont.load_default()
49
  text = f'Previsão: {predicted_class_label}'
50
- text_width, text_height = draw.textsize(text, font=font)
51
- x = (width - text_width) // 2
52
- y = height + 10
53
- draw.text((x, y), text, fill='black', font=font)
54
 
55
- # Converter a imagem resultante de volta para numpy
56
- result_image = np.array(result_image)
57
  return result_image
58
 
59
  # Criar uma interface Gradio
@@ -62,7 +53,7 @@ interface = gr.Interface(
62
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
63
  outputs=gr.outputs.Image(type="numpy", label="Resultado"),
64
  title="Classificador de Imagem ViT",
65
- description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na parte inferior da imagem de saída em uma faixa branca."
66
  )
67
 
68
  # Iniciar a aplicação Gradio
 
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
+ import cv2
 
5
  import numpy as np
6
 
7
  # Carregue o modelo ViT
 
32
  # Converter o ID da classe em rótulo usando o mapeamento id2label
33
  predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
34
 
35
+ # Converter a imagem de numpy para BGR (formato OpenCV)
36
+ input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
37
 
38
+ # Adicionar o rótulo da previsão na imagem
39
+ font = cv2.FONT_HERSHEY_SIMPLEX
 
 
 
 
 
 
 
 
40
  text = f'Previsão: {predicted_class_label}'
41
+ text_size = cv2.getTextSize(text, font, 0.5, 1)[0]
42
+ text_x = (input_image_bgr.shape[1] - text_size[0]) // 2
43
+ text_y = input_image_bgr.shape[0] - 10
44
+ cv2.putText(input_image_bgr, text, (text_x, text_y), font, 0.5, (255, 255, 255), 1)
45
 
46
+ # Converter a imagem resultante de volta para RGB (formato Pillow)
47
+ result_image = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
48
  return result_image
49
 
50
  # Criar uma interface Gradio
 
53
  inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
54
  outputs=gr.outputs.Image(type="numpy", label="Resultado"),
55
  title="Classificador de Imagem ViT",
56
+ description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT). O rótulo da previsão está na imagem de saída."
57
  )
58
 
59
  # Iniciar a aplicação Gradio