Staticaliza commited on
Commit
fde98ce
·
verified ·
1 Parent(s): 4df851c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -7,7 +7,6 @@ import uuid
7
  import os
8
 
9
  from diffusers import StableDiffusionXLPipeline, StableDiffusion3Pipeline
10
- from transformers import AutoModelForImageClassification, ViTImageProcessor
11
  from PIL import Image
12
 
13
  # Pre-Initialize
@@ -17,6 +16,8 @@ if DEVICE == "auto":
17
  print(f"[SYSTEM] | Using {DEVICE} type compute device.")
18
 
19
  # Variables
 
 
20
  MAX_SEED = 9007199254740991
21
  DEFAULT_INPUT = ""
22
  DEFAULT_NEGATIVE_INPUT = "(bad, ugly, amputation, abstract, blur, blurry, deformed, distorted, disfigured, disconnected, mutation, mutated, low quality, lowres), unfinished, title, text, signature, watermark, (limbs, legs, feet, arms, hands), (porn, nude, naked, nsfw)"
@@ -24,6 +25,8 @@ DEFAULT_MODEL = "Default"
24
  DEFAULT_HEIGHT = 1024
25
  DEFAULT_WIDTH = 1024
26
 
 
 
27
  css = '''
28
  .gradio-container{max-width: 560px !important}
29
  h1{text-align:center}
@@ -32,9 +35,6 @@ footer {
32
  }
33
  '''
34
 
35
- repo_nsfw_classifier = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
36
- processor_nsfw_classifier = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
37
-
38
  repo_default = StableDiffusionXLPipeline.from_pretrained("fluently/Fluently-XL-Final", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
39
  #repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
40
  #repo_default.set_adapters(["base"], adapter_weights=[0.7])
@@ -65,6 +65,12 @@ def get_seed(seed):
65
  else:
66
  return random.randint(0, MAX_SEED)
67
 
 
 
 
 
 
 
68
  @spaces.GPU(duration=60)
69
  def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATIVE_INPUT, model=DEFAULT_MODEL, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
70
 
@@ -100,7 +106,6 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
100
 
101
  print(steps, guidance)
102
 
103
- repo_nsfw_classifier.to(DEVICE)
104
  repo.to(DEVICE)
105
 
106
  parameters = {
@@ -120,11 +125,9 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
120
 
121
  print(image_paths)
122
 
123
- nsfw_prediction = repo_nsfw_classifier(**processor_nsfw_classifier(images=Image.open(image_paths[0]), return_tensors="pt")).logits
124
 
125
- print(nsfw_prediction.argmax(-1))
126
- print(nsfw_prediction.argmax(-1).item())
127
- print(repo_nsfw_classifier.config.id2label[nsfw_prediction])
128
 
129
  return image_paths, {item['label']: round(item['score'], 3) for item in nsfw_prediction}
130
 
 
7
  import os
8
 
9
  from diffusers import StableDiffusionXLPipeline, StableDiffusion3Pipeline
 
10
  from PIL import Image
11
 
12
  # Pre-Initialize
 
16
  print(f"[SYSTEM] | Using {DEVICE} type compute device.")
17
 
18
  # Variables
19
+ HF_TOKEN = os.environ.get("HF_TOKEN")
20
+
21
  MAX_SEED = 9007199254740991
22
  DEFAULT_INPUT = ""
23
  DEFAULT_NEGATIVE_INPUT = "(bad, ugly, amputation, abstract, blur, blurry, deformed, distorted, disfigured, disconnected, mutation, mutated, low quality, lowres), unfinished, title, text, signature, watermark, (limbs, legs, feet, arms, hands), (porn, nude, naked, nsfw)"
 
25
  DEFAULT_HEIGHT = 1024
26
  DEFAULT_WIDTH = 1024
27
 
28
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {HF_TOKEN}" }
29
+
30
  css = '''
31
  .gradio-container{max-width: 560px !important}
32
  h1{text-align:center}
 
35
  }
36
  '''
37
 
 
 
 
38
  repo_default = StableDiffusionXLPipeline.from_pretrained("fluently/Fluently-XL-Final", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
39
  #repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
40
  #repo_default.set_adapters(["base"], adapter_weights=[0.7])
 
65
  else:
66
  return random.randint(0, MAX_SEED)
67
 
68
+ def api_classification_request(url, filename, headers):
69
+ with open(filename, "rb") as file:
70
+ data = file.read()
71
+ response = requests.request("POST", url, headers=headers or {}, data=data)
72
+ return json.loads(response.content.decode("utf-8"))
73
+
74
  @spaces.GPU(duration=60)
75
  def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATIVE_INPUT, model=DEFAULT_MODEL, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None):
76
 
 
106
 
107
  print(steps, guidance)
108
 
 
109
  repo.to(DEVICE)
110
 
111
  parameters = {
 
125
 
126
  print(image_paths)
127
 
128
+ nsfw_prediction = api_classification_request("https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection", image_paths[0], headers)
129
 
130
+ print(nsfw_prediction)
 
 
131
 
132
  return image_paths, {item['label']: round(item['score'], 3) for item in nsfw_prediction}
133