Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
2 |
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
|
3 |
import gradio as gr
|
|
|
|
|
4 |
|
5 |
# Load the pretrained ViT model and feature extractor
|
6 |
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
@@ -13,18 +15,38 @@ classifier = VisionClassifierInference(
|
|
13 |
model=model,
|
14 |
)
|
15 |
|
16 |
-
# Define a
|
17 |
-
def
|
|
|
18 |
label = classifier.predict(img_path=img)
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
iface = gr.Interface(
|
22 |
-
fn=
|
23 |
inputs=gr.inputs.Image(),
|
24 |
-
outputs=gr.outputs.
|
25 |
live=True,
|
26 |
-
title="ViT Image Classifier",
|
27 |
-
description="Upload an image for classification.",
|
28 |
)
|
29 |
|
30 |
if __name__ == "__main__":
|
|
|
1 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
2 |
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
|
3 |
import gradio as gr
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
|
7 |
# Load the pretrained ViT model and feature extractor
|
8 |
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
|
|
15 |
model=model,
|
16 |
)
|
17 |
|
18 |
+
# Define a function to classify and overlay the label on the image
|
19 |
+
def classify_image_with_overlay(img):
    """Classify an image and return it with the predicted label drawn on it.

    Args:
        img: Path of the image file. It is handed to ``classifier.predict``
            as ``img_path`` and read again with ``cv2.imread``.

    Returns:
        The image with the label rendered in black on a white box near the
        top-left corner, converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If OpenCV cannot read an image from ``img``.
    """
    # Predict the label
    label = classifier.predict(img_path=img)

    # Load the image using OpenCV
    image = cv2.imread(img)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cv2.rectangle.
        raise ValueError(f"Could not read image: {img!r}")

    # Text layout settings
    font = cv2.FONT_HERSHEY_SIMPLEX
    org = (10, 30)  # bottom-left corner of the text (its baseline origin)
    font_scale = 1
    color = (255, 255, 255)  # White color
    thickness = 2
    padding = 10

    # getTextSize returns ((width, height), baseline); the baseline is the
    # extent below ``org`` used by descenders such as "p" or "y".
    (text_width, text_height), baseline = cv2.getTextSize(
        label, font, font_scale, thickness
    )

    # Add a white rectangle that fully covers the label, descenders included
    top_left = (org[0] - padding, org[1] - text_height - padding)
    bottom_right = (org[0] + text_width + padding, org[1] + baseline)
    cv2.rectangle(image, top_left, bottom_right, color, cv2.FILLED)

    # Put the label text on the white rectangle
    cv2.putText(image, label, org, font, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)

    # Convert the image to RGB format for Gradio
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return image_rgb
|
42 |
|
43 |
# The callback reads the upload from disk (classifier.predict(img_path=...)
# and cv2.imread), so request a file path from Gradio instead of the default
# in-memory numpy array.
iface = gr.Interface(
    fn=classify_image_with_overlay,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=gr.outputs.Image(),
    live=True,
    title="ViT Image Classifier with Overlay",
    description="Upload an image for classification with label overlay.",
)
|
51 |
|
52 |
if __name__ == "__main__":
|