DHEIVER's picture
Update app.py
ccb8223
raw
history blame
1.77 kB
from transformers import ViTFeatureExtractor, ViTForImageClassification
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
import gradio as gr
import cv2
import numpy as np
# Hugging Face Hub checkpoint to load; judging by the repo id it is a
# ViT-Base/16 (224px) fine-tuned on Kvasir v2 colonoscopy images.
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"

# Preprocessing config and classification model, both pulled from the same
# checkpoint identifier.
feature_extractor, model = (
    ViTFeatureExtractor.from_pretrained(path),
    ViTForImageClassification.from_pretrained(path),
)

# hugsvision wrapper that the Gradio handler calls with an image path to get
# back a predicted label.
classifier = VisionClassifierInference(feature_extractor=feature_extractor, model=model)
def classify_image_with_overlay(img):
    """Classify an image and return it with the predicted label drawn on it.

    Args:
        img: Path of the image file to classify; it is passed both to
            ``classifier.predict(img_path=...)`` and to ``cv2.imread``.
            ``None`` is accepted (a live interface can fire with an empty
            input) and yields ``None``.

    Returns:
        The image as an RGB ``numpy`` array with the label written in black
        on a white box in the top-left corner, or ``None`` if ``img`` is
        ``None``.

    Raises:
        ValueError: If OpenCV cannot read an image from ``img``.
    """
    # Nothing to classify when the input widget is empty/cleared.
    if img is None:
        return None

    # Predict the label; str() ensures the OpenCV text calls receive text.
    label = str(classifier.predict(img_path=img))

    # Load the image with OpenCV (BGR channel order). imread reports failure
    # by returning None instead of raising, which would otherwise surface as
    # an obscure error inside cv2.rectangle below.
    image = cv2.imread(img)
    if image is None:
        raise ValueError(f"Could not read image file: {img!r}")

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    pad = 10
    org = (pad, 30)  # bottom-left corner of the text, i.e. its baseline
    (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)

    # White background box. It extends `baseline` pixels below the text
    # origin so descenders (g, p, y, ...) in the label stay on white instead
    # of being drawn over the raw image, and it is padded equally on the
    # left and right of the text.
    cv2.rectangle(
        image,
        (org[0] - pad, org[1] - text_h - pad),
        (org[0] + text_w + pad, org[1] + baseline),
        (255, 255, 255),
        cv2.FILLED,
    )
    # Black label text on top of the white box.
    cv2.putText(image, label, org, font, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)

    # Convert the image to RGB format for Gradio (OpenCV works in BGR).
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Web UI: upload an image and get it back with the predicted label overlaid.
iface = gr.Interface(
    fn=classify_image_with_overlay,
    # type="filepath" makes Gradio hand the handler a path string, which is
    # what classifier.predict(img_path=...) and cv2.imread() expect; the
    # default type ("numpy") would pass a pixel array instead. gr.Image
    # replaces the gr.inputs/gr.outputs namespaces removed in Gradio 4.
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(),
    live=True,
    title="ViT Image Classifier with Overlay",
    description="Upload an image for classification with label overlay.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()