ruidanwang commited on
Commit
a989319
verified
1 Parent(s): 74ab170

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -1,27 +1,23 @@
1
- # prompt: gradio image 鍒嗙被
2
- import fastai
3
- from fastai.vision import *
4
- from fastai.vision.all import load_learner,PILImage
5
  import gradio as gr
6
- from transformers import AutoModelForSequenceClassification
7
- # Load the model
 
8
 
9
- model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
10
- model = load_learner(model)
11
-
12
- # Define an image classification function
13
- def classify_image(image):
14
- img = PILImage.create(image)
15
-
16
- # Make a prediction
17
- pred_class, pred_idx, probs = model.predict(img)
18
-
19
- # Return the prediction as a dictionary
20
- return {model.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
21
 
22
  # Create the Gradio interface
23
- image_input = gr.Image()
24
- label_output = gr.Label(num_top_classes=2)
25
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
26
 
27
  # Launch the interface
 
1
+ # prompt: gradio image 鍒嗙被 not safe for work
2
+ from PIL import Image
 
 
3
  import gradio as gr
4
+ from transformers import pipeline
5
+ # Load the image classification pipeline
6
+ classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
7
 
8
+ # Define a function to classify the image and return the results
9
+ def classify_image(img):
10
+ # Convert the Gradio image input to a PIL image
11
+ pil_image = Image.fromarray(img.astype('uint8'), 'RGB')
12
+ # Classify the image using the pipeline
13
+ results = classifier(pil_image)
14
+ # Format the results for display in Gradio
15
+ formatted_results = {result['label']: result['score'] for result in results}
16
+ return formatted_results
 
 
 
17
 
18
  # Create the Gradio interface
19
+ image_input = gr.inputs.Image(shape=(256, 256))
20
+ label_output = gr.outputs.Label(num_top_classes=3)
21
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
22
 
23
  # Launch the interface