import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image

model_name = 'e1010101/vit-384-tongue-image'

# Image processor comes from the base ViT-384 checkpoint; the fine-tuned weights come from model_name
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-384")
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=3,
    problem_type="multi_label_classification",
    ignore_mismatched_sizes=True,
    id2label={0: 'Crack', 1: 'Red-Dots', 2: 'Toothmark'},
    label2id={'Crack': 0, 'Red-Dots': 1, 'Toothmark': 2}
)

def classify_image(image, threshold):
    # Preprocess the image and run a forward pass without tracking gradients
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Apply sigmoid for multi-label classification (each label is scored independently)
    probs = torch.sigmoid(logits)[0].numpy()
    # Get label names
    labels = model.config.id2label.values()
    # Keep only labels whose probability meets the threshold from the slider
    result = {label: float(prob) for label, prob in zip(labels, probs) if prob >= threshold}
    # Sort results by probability, highest first
    result = dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
    return result

interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=0, maximum=1, value=0.5, label="Probability Threshold")
    ],
    outputs=gr.Label(num_top_classes=None),
    title="Multi-Label Image Classification",
    description="Upload an image; labels with probability at or above the threshold are shown."
)

if __name__ == "__main__":
    interface.launch()
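
# --- Optional local sanity check (a minimal sketch, not part of the deployed app) ---
# Assumes a hypothetical sample image "sample_tongue.jpg" next to this script; it calls
# classify_image() directly, bypassing the Gradio UI, to confirm the model loads and predicts.
#
#     from PIL import Image
#     img = Image.open("sample_tongue.jpg").convert("RGB")
#     print(classify_image(img, threshold=0.5))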