yeftakun commited on
Commit
fb6a0a0
·
verified ·
1 Parent(s): b2a8fdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -39
app.py CHANGED
@@ -1,17 +1,18 @@
1
  import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
- import io
5
  import requests
6
- from flask import Flask, request, jsonify
7
 
8
  # Load the model and processor
9
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
10
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
11
 
12
  # Define prediction function
13
- def predict_image(image):
14
  try:
 
 
 
15
  # Process the image and make prediction
16
  inputs = processor(images=image, return_tensors="pt")
17
  outputs = model(**inputs)
@@ -28,43 +29,10 @@ def predict_image(image):
28
  # Create Gradio interface
29
  iface = gr.Interface(
30
  fn=predict_image,
31
- inputs=gr.Image(type="pil", label="Upload Image"),
32
  outputs=gr.Textbox(label="Predicted Class"),
33
  title="NSFW Image Classifier"
34
  )
35
 
36
- # Launch the Gradio interface
37
- iface.launch()
38
-
39
- # Flask app for API endpoint
40
- app = Flask(__name__)
41
-
42
- @app.route('/predict', methods=['POST'])
43
- def predict():
44
- if 'file' not in request.files:
45
- return jsonify({'error': 'No file part'}), 400
46
-
47
- file = request.files['file']
48
- if file.filename == '':
49
- return jsonify({'error': 'No selected file'}), 400
50
-
51
- try:
52
- # Load image from the uploaded file
53
- image = Image.open(file.stream)
54
-
55
- # Process the image and make prediction
56
- inputs = processor(images=image, return_tensors="pt")
57
- outputs = model(**inputs)
58
- logits = outputs.logits
59
-
60
- # Get predicted class
61
- predicted_class_idx = logits.argmax(-1).item()
62
- predicted_label = model.config.id2label[predicted_class_idx]
63
-
64
- return jsonify({'predicted_class': predicted_label})
65
- except Exception as e:
66
- return jsonify({'error': str(e)}), 500
67
-
68
- # Run Flask app
69
- if __name__ == '__main__':
70
- app.run(port=5000)
 
1
  import gradio as gr
2
  from transformers import ViTImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
 
4
  import requests
 
5
 
6
  # Load the model and processor
7
  processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
8
  model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
9
 
10
  # Define prediction function
11
+ def predict_image(image_url):
12
  try:
13
+ # Load image from URL
14
+ image = Image.open(requests.get(image_url, stream=True).raw)
15
+
16
  # Process the image and make prediction
17
  inputs = processor(images=image, return_tensors="pt")
18
  outputs = model(**inputs)
 
29
  # Create Gradio interface
30
  iface = gr.Interface(
31
  fn=predict_image,
32
+ inputs=gr.Textbox(label="Image URL", placeholder="Enter image URL here"),
33
  outputs=gr.Textbox(label="Predicted Class"),
34
  title="NSFW Image Classifier"
35
  )
36
 
37
+ # Launch the interface
38
+ iface.launch()