"""Gradio demo: classify images with a fine-tuned ViT checkpoint.

The checkpoint id suggests it was fine-tuned on Kvasir-v2 colonoscopy images.
The app returns the uploaded image annotated with the predicted class and the
model's softmax confidence for that class, plus the same text summary.
"""

import gradio as gr
import torch
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
from PIL import Image, ImageDraw, ImageFont
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Load the pre-trained ViT model and its preprocessing once, at import time.
path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
feature_extractor = ViTFeatureExtractor.from_pretrained(path)
model = ViTForImageClassification.from_pretrained(path)

# Kept so code that imports `classifier` keeps working. It wraps the same
# extractor/model objects, so the weights are not loaded a second time.
classifier = VisionClassifierInference(
    feature_extractor=feature_extractor,
    model=model,
)

# Overlay font, loaded once. Arial is frequently absent on Linux hosts, so fall
# back to Pillow's built-in font rather than failing on every request.
try:
    FONT = ImageFont.truetype("arial.ttf", 20)
except OSError:
    FONT = ImageFont.load_default()


def classify_image(image_file):
    """Classify an image using a pre-trained ViT model and annotate it.

    Args:
        image_file: Path of the uploaded image. Gradio passes a ``str`` here
            because the input component uses ``type="filepath"``. Objects with
            a ``.name`` path attribute (e.g. temp-file wrappers) are also
            accepted. ``None`` means nothing was uploaded.

    Returns:
        A ``(image, text)`` tuple. ``image`` is an RGB ``PIL.Image.Image`` with
        the prediction drawn in the top-left corner, or ``None`` when no image
        was supplied. ``text`` is
        ``"Predicted class: <label> (confidence: <p>)"``.
    """
    if image_file is None:
        return None, "Please upload an image."
    image_path = getattr(image_file, "name", image_file)

    # convert() returns an independent, fully loaded copy, so the file can be
    # closed immediately. RGB is needed both for the 3-channel model input and
    # for the RGB text fill below (RGBA/L/P uploads would otherwise break it).
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    # One forward pass yields both the label and its probability, so the two
    # can never disagree.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = logits.softmax(dim=-1)[0]
    class_index = int(probabilities.argmax())
    label = model.config.id2label[class_index]
    confidence = probabilities[class_index].item()

    summary = f"Predicted class: {label} (confidence: {confidence:.2f})"

    # Draw the predicted label on the image.
    ImageDraw.Draw(image).text((10, 10), summary, font=FONT, fill=(255, 255, 255))

    return image, summary


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=[
        gr.Image(type="pil", label="Annotated image"),
        gr.Textbox(label="Prediction"),
    ],
    title="Image Classifier",
    description="Classify images using a pre-trained ViT model",
)

# Launch the Gradio app
iface.launch()