Staticaliza commited on
Commit
48c2ad8
·
verified ·
1 Parent(s): ecc0861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -100,8 +100,8 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
100
 
101
  print(steps, guidance)
102
 
103
- #repo_nsfw_classifier.to(DEVICE)
104
- #processor_nsfw_classifier.to(DEVICE)
105
  repo.to(DEVICE)
106
 
107
  parameters = {
@@ -121,9 +121,10 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
121
 
122
  print(image_paths)
123
 
124
- nsfw_prediction = repo_nsfw_classifier(**processor_nsfw_classifier(images=Image.open(image_paths[0]), return_tensors="pt")).logits.argmax(-1).item()
125
 
126
- print(nsfw_prediction)
 
127
  print(repo_nsfw_classifier.config.id2label[nsfw_prediction])
128
 
129
  return image_paths, {item['label']: round(item['score'], 3) for item in nsfw_prediction}
 
100
 
101
  print(steps, guidance)
102
 
103
+ repo_nsfw_classifier.to(DEVICE)
104
+ processor_nsfw_classifier.to(DEVICE)
105
  repo.to(DEVICE)
106
 
107
  parameters = {
 
121
 
122
  print(image_paths)
123
 
124
+ nsfw_prediction = repo_nsfw_classifier(**processor_nsfw_classifier(images=Image.open(image_paths[0]), return_tensors="pt")).logits
125
 
126
+ print(nsfw_prediction.argmax(-1))
127
+ print(nsfw_prediction.argmax(-1).item())
128
  print(repo_nsfw_classifier.config.id2label[nsfw_prediction])
129
 
130
  return image_paths, {item['label']: round(item['score'], 3) for item in nsfw_prediction}