# Source: Hugging Face Space file view (page residue kept as comments)
# e1010101's picture
# Create app.py
# cb531f9 verified
# raw
# history blame
# 1.2 kB
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageClassification
import torch
from PIL import Image
# Hub repository id of the fine-tuned image classifier served by this app.
model_name = "e1010101/vit-384-tongue-image"

# Load the preprocessing pipeline and the model once at import time so each
# request only pays for inference, not for loading weights.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
def classify_image(image):
    """Score an image against every label of the multi-label classifier.

    Args:
        image: The image to classify. The Gradio input is declared with
            ``type="pil"``, so a PIL image is expected here.

    Returns:
        A dict mapping each label name to its sigmoid probability (a Python
        float), ordered from most to least probable.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, the form was
            submitted without an uploaded image).
    """
    # Fail early with a readable message instead of an opaque error from
    # deep inside the processor.
    if image is None:
        raise ValueError("No image provided; please upload an image to classify.")

    # Grayscale / RGBA / palette images are normalised to RGB before
    # preprocessing. NOTE(review): assumes the model expects 3-channel input.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Sigmoid (not softmax): each label is scored independently, so the
    # probabilities do not need to sum to one.
    probs = torch.sigmoid(logits)[0].tolist()

    # Look labels up by logit index rather than relying on the iteration
    # order of the id2label dict matching the output order.
    id2label = model.config.id2label
    result = {
        id2label.get(idx, f"LABEL_{idx}"): prob
        for idx, prob in enumerate(probs)
    }

    # Most probable labels first.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
# Gradio UI: one image input wired to classify_image, with the returned
# label -> probability dict rendered by a Label component.
interface = gr.Interface(
    fn=classify_image,
    # The gr.inputs / gr.outputs namespaces are deprecated in Gradio 3 and
    # removed in Gradio 4; the top-level components are the supported API.
    inputs=gr.Image(type="pil"),
    # num_top_classes=None shows every label, not just the top few.
    outputs=gr.Label(num_top_classes=None),
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()