|
import gradio as gr
import numpy as np
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
from PIL import Image, ImageDraw, ImageFont
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
|
|
|
|
# Checkpoint identifier handed to both from_pretrained calls below.
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"


def _load_classifier(checkpoint):
    """Build the inference wrapper from the extractor and model at *checkpoint*."""
    extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
    vit_model = ViTForImageClassification.from_pretrained(checkpoint)
    return VisionClassifierInference(feature_extractor=extractor, model=vit_model)


# Built once at import time and shared by classify_image.
classifier = _load_classifier(path)
|
|
|
|
|
def classify_image(image_file):
    """Classify an image and return it annotated with the prediction.

    Args:
        image_file: Path of the image to classify, either as a string or as
            a file-like object that exposes the path through ``.name``.

    Returns:
        A ``(annotated, caption)`` tuple. ``annotated`` is the image as an
        RGB NumPy array with the caption drawn near its top-left corner;
        ``caption`` is the same prediction summary as plain text.
    """
    # The Gradio input is declared with type="filepath", which supplies a
    # plain string; temp-file wrappers carry the path in ``.name`` instead.
    image_path = image_file if isinstance(image_file, str) else image_file.name

    label = classifier.predict(img_path=image_path)
    # NOTE(review): assumes predict_proba returns per-image scores that can be
    # indexed by the value predict() returns -- confirm against hugsvision.
    confidence = classifier.predict_proba(img_path=image_path)[0][label]
    caption = f"Predicted class: {label} (confidence: {confidence:.2f})"

    # Convert to RGB so the RGB fill colour below is valid whatever the
    # source mode is (palette, greyscale, RGBA, ...). The context manager
    # releases the file handle once the pixels have been copied.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not installed on every host; use Pillow's built-in font.
        font = ImageFont.load_default()

    ImageDraw.Draw(image).text((10, 10), caption, font=font, fill=(255, 255, 255))

    # The Gradio output component is declared with type="numpy", so hand
    # back an array rather than an encoded byte stream.
    return np.array(image), caption
|
|
|
|
|
# Single image upload, declared with type="filepath".
image_input = gr.inputs.Image(type="filepath", label="Upload an image")

# Output components in the order classify_image returns its values:
# the image first, then the text summary.
result_outputs = [gr.outputs.Image(type="numpy"), "text"]

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=result_outputs,
    title="Image Classifier",
    description="Classify images using a pre-trained ViT model",
)

# Launch the interface as soon as this module is executed.
iface.launch()
|
|