|
import os

import gradio as gr
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
|
|
|
|
# Hugging Face Hub checkpoint id; judging by its name, a ViT-Base/16 (224px)
# fine-tuned on Kvasir-v2 colonoscopy images — confirm on the model card.
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"


def _load_classifier(checkpoint):
    """Build a hugsvision inference wrapper from one Hub checkpoint id.

    The feature extractor and the classification model are both loaded from
    the same checkpoint, in that order.
    """
    return VisionClassifierInference(
        feature_extractor=ViTFeatureExtractor.from_pretrained(checkpoint),
        model=ViTForImageClassification.from_pretrained(checkpoint),
    )


classifier = _load_classifier(path)
|
|
|
|
|
def classify_image(image_file):
    """Classify an uploaded image using the pre-trained ViT model.

    Args:
        image_file: The value Gradio passes for the image input. With
            ``type="filepath"`` this is a path string; a legacy
            ``type="file"`` input passes a temp-file object whose path is in
            ``.name``. ``None`` means the user submitted without an image.

    Returns:
        A human-readable string with the predicted class and its confidence,
        or a prompt to upload an image when none was provided.
    """
    if image_file is None:
        return "Please upload an image."

    # type="filepath" delivers a plain str, which has no `.name` attribute;
    # only file-like objects carry their path in `.name`.
    if isinstance(image_file, (str, os.PathLike)):
        img_path = os.fspath(image_file)
    else:
        img_path = image_file.name

    label = classifier.predict(img_path=img_path)

    # NOTE(review): confirm that the installed hugsvision
    # VisionClassifierInference exposes `predict_proba`, and that indexing its
    # first row by `label` is valid. This also runs the model a second time
    # on the same image.
    confidence = classifier.predict_proba(img_path=img_path)[0][label]

    return f"Predicted class: {label} (confidence: {confidence:.2f})"
|
|
|
# Wire classify_image into a simple "upload an image -> text" web UI.
# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4;
# `gr.Image` is the supported component and accepts the same arguments.
# With type="filepath" the handler receives the uploaded image's path.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs="text",
    title="Image Classifier",
    description="Classify images using a pre-trained ViT model",
)

# Start the Gradio web server.
iface.launch()
|
|