DHEIVER commited on
Commit
ccb8223
·
1 Parent(s): be3f9ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -1,6 +1,8 @@
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
3
  import gradio as gr
 
 
4
 
5
  # Load the pretrained ViT model and feature extractor
6
  path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
@@ -13,18 +15,38 @@ classifier = VisionClassifierInference(
13
  model=model,
14
  )
15
 
16
- # Define a Gradio interface
17
- def classify_image(img):
 
18
  label = classifier.predict(img_path=img)
19
- return label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  iface = gr.Interface(
22
- fn=classify_image,
23
  inputs=gr.inputs.Image(),
24
- outputs=gr.outputs.Textbox(),
25
  live=True,
26
- title="ViT Image Classifier",
27
- description="Upload an image for classification.",
28
  )
29
 
30
  if __name__ == "__main__":
 
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
3
  import gradio as gr
4
+ import cv2
5
+ import numpy as np
6
 
7
  # Load the pretrained ViT model and feature extractor
8
  path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
 
15
  model=model,
16
  )
17
 
18
+ # Define a function to classify and overlay the label on the image
19
+ def classify_image_with_overlay(img):
20
+ # Predict the label
21
  label = classifier.predict(img_path=img)
22
+
23
+ # Load the image using OpenCV
24
+ image = cv2.imread(img)
25
+
26
+ # Add a white rectangle for the label
27
+ font = cv2.FONT_HERSHEY_SIMPLEX
28
+ org = (10, 30)
29
+ font_scale = 1
30
+ color = (255, 255, 255) # White color
31
+ thickness = 2
32
+ text_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
33
+ cv2.rectangle(image, (org[0] - 10, org[1] - text_size[1] - 10), (org[0] + text_size[0], org[1]), color, cv2.FILLED)
34
+
35
+ # Put the label text on the white rectangle
36
+ cv2.putText(image, label, org, font, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
37
+
38
+ # Convert the image to RGB format for Gradio
39
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
40
+
41
+ return image_rgb
42
 
43
  iface = gr.Interface(
44
+ fn=classify_image_with_overlay,
45
  inputs=gr.inputs.Image(),
46
+ outputs=gr.outputs.Image(),
47
  live=True,
48
+ title="ViT Image Classifier with Overlay",
49
+ description="Upload an image for classification with label overlay.",
50
  )
51
 
52
  if __name__ == "__main__":