yeftakun commited on
Commit
ceb898d
·
verified ·
1 Parent(s): b9ed281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -62
app.py CHANGED
@@ -1,59 +1,13 @@
1
- from flask import Flask, request, jsonify, url_for
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import requests
5
- import threading
6
- import gradio as gr
7
-
8
- # Initialize the Flask app
9
- app = Flask(__name__)
10
-
11
- # Load the processor and model outside of the route to avoid reloading it with each request
12
- processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
13
- model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
14
-
15
- @app.route('/classify', methods=['POST'])
16
- def classify_image():
17
- try:
18
- # Get the image URL from the POST request
19
- data = request.get_json()
20
- image_url = data.get('image_url')
21
-
22
- if not image_url:
23
- return jsonify({"error": "Image URL not provided"}), 400
24
-
25
- # Fetch the image from the URL
26
- image = Image.open(requests.get(image_url, stream=True).raw)
27
-
28
- # Preprocess the image
29
- inputs = processor(images=image, return_tensors="pt")
30
-
31
- # Run the image through the model
32
- outputs = model(**inputs)
33
- logits = outputs.logits
34
-
35
- # Get the predicted class
36
- predicted_class_idx = logits.argmax(-1).item()
37
- predicted_class = model.config.id2label[predicted_class_idx]
38
-
39
- # Return the classification result
40
- return jsonify({
41
- "image_url": image_url,
42
- "predicted_class": predicted_class
43
- })
44
-
45
- except Exception as e:
46
- return jsonify({"error": str(e)}), 500
47
 
48
- # Function to run the Flask app in a separate thread
49
- def run_flask():
50
- app.run(port=5000, debug=False, use_reloader=False)
51
 
52
- # Launch Flask in a separate thread
53
- flask_thread = threading.Thread(target=run_flask)
54
- flask_thread.start()
55
-
56
- # Gradio interface
57
  def predict_image(image_url):
58
  try:
59
  # Load image from URL
@@ -72,18 +26,13 @@ def predict_image(image_url):
72
  except Exception as e:
73
  return str(e)
74
 
75
- # Construct API endpoint URL
76
- api_url = "http://127.0.0.1:5000/classify"
77
-
78
- # Create Gradio interface with API info
79
  iface = gr.Interface(
80
  fn=predict_image,
81
- inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
82
- outputs=gr.Textbox(label="Predicted Class"),
83
- title="NSFW Image Detection",
84
- description=f"You can get your image classification by sending an API request to: {api_url}. Example:\n"
85
- f"curl -X POST {api_url} -H 'Content-Type: application/json' -d '{{\"image_url\": \"YOUR_IMAGE_URL\"}}'"
86
  )
87
 
88
- # Launch Gradio interface
89
- iface.launch()
 
1
+ import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Load the model and processor
7
+ processor = ViTImageProcessor.from_pretrained('yeftakun/vit-base-nsfw-detector')
8
+ model = AutoModelForImageClassification.from_pretrained('yeftakun/vit-base-nsfw-detector')
9
 
10
+ # Define prediction function
 
 
 
 
11
  def predict_image(image_url):
12
  try:
13
  # Load image from URL
 
26
  except Exception as e:
27
  return str(e)
28
 
29
+ # Create Gradio interface
 
 
 
30
  iface = gr.Interface(
31
  fn=predict_image,
32
+ inputs=gr.inputs.Textbox(label="Image URL"),
33
+ outputs=gr.outputs.Textbox(label="Predicted Class"),
34
+ title="NSFW Image Detection"
 
 
35
  )
36
 
37
+ # Launch the interface
38
+ iface.launch()