DHEIVER commited on
Commit
9da5b91
·
1 Parent(s): bc44ddb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -49
app.py CHANGED
@@ -1,57 +1,31 @@
1
- import gradio as gr
2
- import torch
3
- import cv2
4
  from transformers import ViTFeatureExtractor, ViTForImageClassification
5
- import numpy as np
 
6
 
7
- # Carregar o modelo
8
- modelo = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
9
- extrator = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
 
10
 
11
- # Dicionário de mapeamento de índice de classe para rótulo
12
- id2label = {
13
- "0": "dyed-lifted-polyps",
14
- "1": "dyed-resection-margins",
15
- "2": "esophagitis",
16
- "3": "normal-cecum",
17
- "4": "normal-pylorus",
18
- "5": "normal-z-line",
19
- "6": "polyps",
20
- "7": "ulcerative-colitis"
21
- }
22
 
23
- # Função para classificar a imagem e adicionar rótulo de previsão
24
- def classificar_imagem(image):
25
- # Realizar a inferência usando o modelo
26
- inputs = extrator(image, return_tensors="pt")
27
- outputs = modelo(**inputs)
28
- logits = outputs.logits
29
-
30
- # Obter a classe prevista
31
- classe_prevista = torch.argmax(logits, dim=1)
32
- rotulo_previsto = id2label[str(classe_prevista.item())]
33
-
34
- # Adicionar o rótulo de previsão à imagem
35
- image_with_text = image.copy()
36
- font = cv2.FONT_HERSHEY_SIMPLEX
37
- font_scale = 0.5
38
- font_color = (255, 255, 255)
39
- font_thickness = 1
40
- text_position = (10, 30)
41
- cv2.putText(image_with_text, f"Previsto: {rotulo_previsto}", text_position, font, font_scale, font_color, font_thickness)
42
-
43
- return image_with_text
44
 
45
- # Configurar a interface Gradio com uma altura de saída maior
46
  iface = gr.Interface(
47
- fn=classificar_imagem, # Esta função é chamada quando uma imagem é carregada na interface
48
- inputs="image",
49
- outputs="image",
50
- title="Classificação de Imagem com ViT",
51
- description="Carregue uma imagem e obtenha a imagem de entrada com o rótulo de previsão.",
52
- output_height=600 # Defina a altura desejada para a saída
53
  )
54
 
55
-
56
- # Lançar a interface Gradio sem a opção 'live'
57
- iface.launch(share=False)
 
 
 
 
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
+ from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
3
+ import gradio as gr
4
 
5
+ # Load the pretrained ViT model and feature extractor
6
+ path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
7
+ feature_extractor = ViTFeatureExtractor.from_pretrained(path)
8
+ model = ViTForImageClassification.from_pretrained(path)
9
 
10
+ # Create a VisionClassifierInference instance
11
+ classifier = VisionClassifierInference(
12
+ feature_extractor=feature_extractor,
13
+ model=model,
14
+ )
 
 
 
 
 
 
15
 
16
+ # Define a Gradio interface
17
+ def classify_image(img):
18
+ label = classifier.predict(img_path=img)
19
+ return label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
21
  iface = gr.Interface(
22
+ fn=classify_image,
23
+ inputs=gr.inputs.Image(),
24
+ outputs=gr.outputs.Textbox(),
25
+ live=True,
26
+ title="ViT Image Classifier",
27
+ description="Upload an image for classification.",
28
  )
29
 
30
+ if __name__ == "__main__":
31
+ iface.launch()