# app.py — last commit by ruidanwang: "Update app.py" (f0e0383, verified), 993 bytes.
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
import gradio as gr
# Fetch the pretrained image-classification checkpoint named below together
# with its matching ViT image preprocessor. Both live at module level so that
# classify_image uses the already-loaded objects instead of reloading per call.
model_name = "Falconsai/nsfw_image_detection"
processor = ViTImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Define a function to classify the image and return the results
def classify_image(img):
    """Classify an image with the NSFW detection model.

    Args:
        img: Image from the Gradio ``Image`` input as a NumPy array, or
            ``None`` when the user submits without providing an image.
            Values are cast to uint8; presumably already in 0-255 range
            (TODO confirm against the input component's settings).

    Returns:
        The predicted class name from ``model.config.id2label``, or ``None``
        when no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty input; there is nothing to classify.
        return None
    # Let PIL infer the mode from the array shape, then normalize to RGB.
    # Forcing mode 'RGB' in fromarray breaks on grayscale (H, W) arrays and
    # silently garbles RGBA (H, W, 4) arrays by reading their bytes as RGB.
    pil_image = Image.fromarray(img.astype("uint8")).convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits
    # Single-image batch: take the highest-scoring class index as a Python int.
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]
# Create the Gradio interface.
# The legacy gr.inputs / gr.outputs namespaces (and Image's `shape` argument)
# were removed in Gradio 4; the top-level components work on Gradio 3 and 4.
# type="numpy" matches classify_image, which expects an array (it calls
# .astype). No explicit resize here: the image processor used inside
# classify_image presumably resizes to the model's input size — confirm
# against the checkpoint's preprocessor config.
image_input = gr.Image(type="numpy")
label_output = gr.Label()
interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
# Launch the interface
interface.launch()