Staticaliza commited on
Commit
85f70fe
·
verified ·
1 Parent(s): 457213e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -33,6 +33,8 @@ footer {
33
  }
34
  '''
35
 
 
 
36
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
37
  controlnet = ControlNetModel.from_pretrained("MakiPan/controlnet-encoded-hands-130k", torch_dtype=torch.float16)
38
 
@@ -125,11 +127,15 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
125
 
126
  print(image_paths)
127
 
128
- classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")(Image.open(image_paths[0]))
 
 
 
 
129
 
130
- print(classifier)
131
 
132
- return image_paths, [f"{index['label']}: {index['score']:.3%}" for index in classifier]
133
 
134
  def cloud():
135
  print("[CLOUD] | Space maintained.")
@@ -156,9 +162,9 @@ with gr.Blocks(css=css) as main:
156
 
157
  with gr.Column():
158
  images = gr.Gallery(columns=1, label="Image")
159
- classifier = gr.Label()
160
 
161
- submit.click(generate, inputs=[input, filter_input, negative_input, model, height, width, steps, guidance, number, seed], outputs=[images, classifier], queue=False)
162
  maintain.click(cloud, inputs=[], outputs=[], queue=False)
163
 
164
  main.launch(show_api=True)
 
33
  }
34
  '''
35
 
36
+ repo_nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
37
+
38
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
39
  controlnet = ControlNetModel.from_pretrained("MakiPan/controlnet-encoded-hands-130k", torch_dtype=torch.float16)
40
 
 
127
 
128
  print(image_paths)
129
 
130
+ nsfw_prediction = repo_nsfw_classifier(Image.open(image_paths[0]))
131
+
132
+ print(nsfw_prediction)
133
+
134
+ nsfw_format = [f"{index['label']}: {index['score']:.3%}" for index in nsfw_prediction]
135
 
136
+ print(nsfw_format)
137
 
138
+ return image_paths, nsfw_format
139
 
140
  def cloud():
141
  print("[CLOUD] | Space maintained.")
 
162
 
163
  with gr.Column():
164
  images = gr.Gallery(columns=1, label="Image")
165
+ nsfw_classifier = gr.Label()
166
 
167
+ submit.click(generate, inputs=[input, filter_input, negative_input, model, height, width, steps, guidance, number, seed], outputs=[images, nsfw_classifier], queue=False)
168
  maintain.click(cloud, inputs=[], outputs=[], queue=False)
169
 
170
  main.launch(show_api=True)