# Vishaltiwari2019 — commit d240e67 "Create app.py" (verified), 1.26 kB.
# (Header copied from the Hugging Face file viewer; its "raw" / "history" /
# "blame" links are omitted.)
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load a pre-trained *image* classification checkpoint and its matching image
# processor. The previous checkpoint ("distilbert/distilbert-base-uncased") is
# a text model: its tokenizer cannot consume pixel arrays and its
# sequence-classification head is not trained, so the app could never work.
model_name = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)


# Function to preprocess the image
def preprocess_image(image, size=(256, 256)):
    """Normalise an input image to an RGB NumPy array of the given size.

    Args:
        image: Encoded image bytes, a ``PIL.Image.Image``, or a NumPy array
            (Gradio's ``Image`` component delivers the latter two).
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A NumPy array of the RGB image, shape ``(height, width, 3)``.
    """
    if isinstance(image, (bytes, bytearray)):
        image = Image.open(BytesIO(image))
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Grayscale / RGBA / palette uploads would otherwise yield a channel count
    # the model does not expect.
    image = image.convert("RGB").resize(size)
    return np.array(image)


# Function to make predictions
def classify_image(image):
    """Classify an image and return a probability for every class.

    Args:
        image: Anything ``preprocess_image`` accepts (bytes, PIL image or
            NumPy array).

    Returns:
        dict mapping class label -> probability (Python floats summing to ~1),
        the format ``gr.Label`` expects.
    """
    pixels = preprocess_image(image)
    # The image processor applies the checkpoint's own resize/normalisation.
    inputs = image_processor(images=pixels, return_tensors="pt")
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(**inputs).logits[0]
    # torch.softmax is numerically stable, unlike exp(logits)/sum(exp(logits)).
    probabilities = torch.softmax(logits, dim=-1).tolist()
    id2label = model.config.id2label
    # Human-readable class names; fall back to the index if a label is missing.
    return {id2label.get(i, str(i)): p for i, p in enumerate(probabilities)}
# Create a Gradio interface.
# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4 (as was `Image(shape=...)`); the top-level components below work in
# both versions.
input_image = gr.Image(type="pil")  # hand uploads to classify_image as PIL images
output_label = gr.Label(num_top_classes=3)
demo = gr.Interface(classify_image, inputs=input_image, outputs=output_label)

if __name__ == "__main__":
    demo.launch()