"""NSFW image classifier exposed two ways: a Flask JSON API and a Gradio UI.

Both front ends share one ViT model (``AdamCodd/vit-base-nsfw-detector``)
loaded once at import time. The Flask app is started on a background thread,
then the Gradio interface is launched.
"""

import threading
from io import BytesIO

import gradio as gr
import requests
import torch
from flask import Flask, jsonify, request, url_for
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_ID = "AdamCodd/vit-base-nsfw-detector"
FLASK_HOST = "127.0.0.1"
FLASK_PORT = 5000
# Seconds to wait when connecting to / reading from the remote image host.
FETCH_TIMEOUT = 10

# Initialize the Flask app
app = Flask(__name__)

# Load the processor and model once, outside the request handlers, so neither
# front end pays the loading cost per request.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)


def _fetch_image(image_url):
    """Download *image_url* and return it as an RGB PIL image.

    Raises:
        requests.RequestException: on network errors, timeouts, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # `.content` (unlike `.raw`) undoes any gzip/deflate Content-Encoding, and
    # the timeout keeps a slow host from hanging the worker indefinitely.
    with requests.get(image_url, timeout=FETCH_TIMEOUT) as response:
        response.raise_for_status()
        payload = response.content
    image = Image.open(BytesIO(payload))
    # The ViT processor expects 3-channel input; RGBA, palette and grayscale
    # files would otherwise fail in normalization or be misread.
    return image.convert("RGB")


def _classify(image):
    """Return the model's top-1 label for an RGB PIL *image*."""
    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


@app.route("/classify", methods=["POST"])
def classify_image():
    """Classify the image at the URL given in the JSON request body.

    Expects a JSON object ``{"image_url": "<url>"}``.

    Returns:
        200 with ``{"image_url": ..., "predicted_class": ...}`` on success;
        400 with ``{"error": "Image URL not provided"}`` when the body is
        missing, is not a JSON object, or lacks a non-empty ``image_url``;
        500 with ``{"error": <exception text>}`` for any other failure
        (unreachable URL, non-image content, model error).
    """
    # silent=True: a missing/malformed JSON body yields None instead of
    # raising, so it is reported as the client error it is, not a 500.
    data = request.get_json(silent=True)
    image_url = data.get("image_url") if isinstance(data, dict) else None
    if not image_url:
        return jsonify({"error": "Image URL not provided"}), 400

    try:
        predicted_class = _classify(_fetch_image(image_url))
    except Exception as e:  # top-level request boundary: log and report
        app.logger.exception("Classification failed for %s", image_url)
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "image_url": image_url,
        "predicted_class": predicted_class,
    })


def run_flask():
    """Serve the Flask API (blocking); intended to run on a background thread."""
    # The Werkzeug reloader needs the main thread, so it must stay disabled here.
    app.run(host=FLASK_HOST, port=FLASK_PORT, debug=False, use_reloader=False)


# Launch Flask in a separate thread. daemon=True so this thread does not keep
# the process alive once the Gradio server (launched below) shuts down.
flask_thread = threading.Thread(target=run_flask, name="flask-api", daemon=True)
flask_thread.start()


def predict_image(image_url):
    """Gradio callback: return the predicted label for *image_url*.

    Any failure is returned as its message text so it shows up in the output
    textbox rather than raising inside the UI.
    """
    try:
        return _classify(_fetch_image(image_url))
    except Exception as e:
        return str(e)


# Construct API endpoint URL from the same settings the Flask server uses.
api_url = f"http://{FLASK_HOST}:{FLASK_PORT}/classify"

# Create Gradio interface with API info
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="NSFW Image Detection",
    description=f"You can get your image classification by sending an API request to: {api_url}. Example:\n"
                f"curl -X POST {api_url} -H 'Content-Type: application/json' -d '{{\"image_url\": \"YOUR_IMAGE_URL\"}}'"
)

# Launch Gradio interface
iface.launch()