hatmanstack commited on
Commit
4cacb08
·
1 Parent(s): 1d047e4

reworked verify image

Browse files
Files changed (3) hide show
  1. app.py +14 -8
  2. functions.py +1 -0
  3. generate.py +29 -27
app.py CHANGED
@@ -31,9 +31,15 @@ with gr.Blocks() as demo:
31
  max-width: 800px;
32
  margin: 0 auto;
33
  }
 
 
 
 
 
 
34
  </style>
35
  """)
36
- gr.Markdown("# Amazon Nova Canvas Image Generation")
37
 
38
  with gr.Tab("Text to Image"):
39
  with gr.Column():
@@ -44,7 +50,7 @@ with gr.Blocks() as demo:
44
  """)
45
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
46
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
47
- error_box = gr.Markdown(visible=False, label="Error")
48
  output = gr.Image()
49
  with gr.Accordion("Advanced Options", open=False):
50
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
@@ -66,7 +72,7 @@ with gr.Blocks() as demo:
66
  mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
67
  with gr.Accordion("Mask Image", open=False):
68
  mask_image = gr.Image(type='pil', label="Mask Image")
69
- error_box = gr.Markdown(visible=False, label="Error")
70
  output = gr.Image()
71
  with gr.Accordion("Advanced Options", open=False):
72
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
@@ -90,7 +96,7 @@ with gr.Blocks() as demo:
90
  mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
91
  with gr.Accordion("Mask Image", open=False):
92
  mask_image = gr.Image(type='pil', label="Mask Image")
93
- error_box = gr.Markdown(visible=False, label="Error")
94
  output = gr.Image()
95
  with gr.Accordion("Advanced Options", open=False):
96
  outpainting_mode = gr.Radio(choices=["DEFAULT", "PRECISE"], value="DEFAULT", label="Outpainting Mode")
@@ -109,7 +115,7 @@ with gr.Blocks() as demo:
109
  with gr.Accordion("Optional Prompt", open=False):
110
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
111
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
112
- error_box = gr.Markdown(visible=False, label="Error")
113
  output = gr.Image()
114
  with gr.Accordion("Advanced Options", open=False):
115
  similarity_strength = gr.Slider(minimum=0.2, maximum=1.0, step=0.1, value=0.7, label="Similarity Strength")
@@ -129,7 +135,7 @@ with gr.Blocks() as demo:
129
  condition_image = gr.Image(type='pil', label="Condition Image")
130
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
131
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
132
- error_box = gr.Markdown(visible=False, label="Error")
133
  output = gr.Image()
134
  with gr.Accordion("Advanced Options", open=False):
135
  control_mode = gr.Radio(choices=["CANNY_EDGE", "SEGMENTATION"], value="CANNY_EDGE", label="Control Mode")
@@ -150,7 +156,7 @@ with gr.Blocks() as demo:
150
  with gr.Accordion("Optional Prompt", open=False):
151
  prompt = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
152
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
153
- error_box = gr.Markdown(visible=False, label="Error")
154
  output = gr.Image()
155
  with gr.Accordion("Advanced Options", open=False):
156
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
@@ -164,7 +170,7 @@ with gr.Blocks() as demo:
164
  </div>
165
  """)
166
  image = gr.Image(type='pil', label="Input Image")
167
- error_box = gr.Markdown(visible=False, label="Error")
168
  output = gr.Image()
169
  gr.Button("Generate").click(background_removal, inputs=image, outputs=[output, error_box])
170
 
 
31
  max-width: 800px;
32
  margin: 0 auto;
33
  }
34
+ .center-markdown {
35
+ text-align: center !important;
36
+ display: flex !important;
37
+ justify-content: center !important;
38
+ width: 100% !important;
39
+ }
40
  </style>
41
  """)
42
+ gr.Markdown("<h1>Amazon Nova Canvas Image Generation</h1>", elem_classes="center-markdown" )
43
 
44
  with gr.Tab("Text to Image"):
45
  with gr.Column():
 
50
  """)
51
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
52
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
53
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
54
  output = gr.Image()
55
  with gr.Accordion("Advanced Options", open=False):
56
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
 
72
  mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
73
  with gr.Accordion("Mask Image", open=False):
74
  mask_image = gr.Image(type='pil', label="Mask Image")
75
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
76
  output = gr.Image()
77
  with gr.Accordion("Advanced Options", open=False):
78
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
 
96
  mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
97
  with gr.Accordion("Mask Image", open=False):
98
  mask_image = gr.Image(type='pil', label="Mask Image")
99
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
100
  output = gr.Image()
101
  with gr.Accordion("Advanced Options", open=False):
102
  outpainting_mode = gr.Radio(choices=["DEFAULT", "PRECISE"], value="DEFAULT", label="Outpainting Mode")
 
115
  with gr.Accordion("Optional Prompt", open=False):
116
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
117
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
118
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
119
  output = gr.Image()
120
  with gr.Accordion("Advanced Options", open=False):
121
  similarity_strength = gr.Slider(minimum=0.2, maximum=1.0, step=0.1, value=0.7, label="Similarity Strength")
 
135
  condition_image = gr.Image(type='pil', label="Condition Image")
136
  prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
137
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
138
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
139
  output = gr.Image()
140
  with gr.Accordion("Advanced Options", open=False):
141
  control_mode = gr.Radio(choices=["CANNY_EDGE", "SEGMENTATION"], value="CANNY_EDGE", label="Control Mode")
 
156
  with gr.Accordion("Optional Prompt", open=False):
157
  prompt = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=4)
158
  gr.Button("Generate Prompt").click(generate_nova_prompt, outputs=prompt)
159
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
160
  output = gr.Image()
161
  with gr.Accordion("Advanced Options", open=False):
162
  negative_text, width, height, quality, cfg_scale, seed = create_advanced_options()
 
170
  </div>
171
  """)
172
  image = gr.Image(type='pil', label="Input Image")
173
+ error_box = gr.Markdown(visible=False, label="Error", elem_classes="center-markdown")
174
  output = gr.Image()
175
  gr.Button("Generate").click(background_removal, inputs=image, outputs=[output, error_box])
176
 
functions.py CHANGED
@@ -68,6 +68,7 @@ def text_to_image(prompt, negative_text=None, height=1024, width=1024, quality="
68
 
69
  def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
70
  images = process_images(primary=image, secondary=None)
 
71
  for value in images.values():
72
  if isinstance(value, str) and "Not Appropriate" in value:
73
  return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
 
68
 
69
  def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
70
  images = process_images(primary=image, secondary=None)
71
+
72
  for value in images.values():
73
  if isinstance(value, str) and "Not Appropriate" in value:
74
  return None, gr.update(visible=True, value="Image <b>Not Appropriate</b>")
generate.py CHANGED
@@ -65,39 +65,37 @@ class ImageProcessor:
65
 
66
  def _check_nsfw(self, attempts=1):
67
  """Check if image is NSFW using Hugging Face API."""
 
 
 
 
 
 
 
68
  try:
69
- # Save current image temporarily
70
- temp_buffer = io.BytesIO()
71
- self.image.save(temp_buffer, format='PNG')
72
- temp_buffer.seek(0)
73
-
74
- API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
75
  response = requests.request("POST", API_URL, headers=headers, data=temp_buffer.getvalue())
76
- decoded_response = response.content.decode("utf-8")
77
-
78
- json_response = json.loads(decoded_response)
79
-
80
  if "error" in json_response:
 
 
81
  time.sleep(json_response["estimated_time"])
82
- return self._check_nsfw(attempts+1)
83
 
84
- scores = {item['label']: item['score'] for item in json_response}
85
- nsfw_score = scores.get('nsfw', 0)
86
  print(f"NSFW Score: {nsfw_score}")
87
 
88
  if nsfw_score > 0.1:
89
- raise ImageError("Image <b>Not Appropriate</b>")
90
-
91
  return self
92
 
93
  except json.JSONDecodeError as e:
94
- print(f'JSON Decoding Error: {e}')
95
- raise ImageError("NSFW check failed")
96
  except Exception as e:
97
- print(f'NSFW Check Error: {e}')
98
  if attempts > 30:
99
  raise ImageError("NSFW check failed after multiple attempts")
100
- return self._check_nsfw(attempts+1)
101
 
102
  def _convert_color_mode(self):
103
  """Handle color mode conversion."""
@@ -142,12 +140,16 @@ class ImageProcessor:
142
 
143
  def process(self, min_size=320, max_size=4096, max_pixels=4194304):
144
  """Process image with all necessary transformations."""
145
- return (self
146
- ._convert_color_mode()
147
- ._resize_for_pixels(max_pixels)
148
- ._ensure_dimensions(min_size, max_size)
149
- ._check_nsfw() # Add NSFW check before encoding
150
- .encode())
 
 
 
 
151
 
152
  # Function to generate an image using Amazon Nova Canvas model
153
  class BedrockClient:
@@ -281,11 +283,11 @@ def check_rate_limit(body):
281
  # Check limits based on quality
282
  if quality == 'premium':
283
  if len(rate_data['premium']) >= 2:
284
- raise ImageError("<div style='text-align: center;'>Premium rate limit exceeded. Check back later or you use the <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a>.</div>")
285
  rate_data['premium'].append(current_time)
286
  else: # standard
287
  if len(rate_data['standard']) >= 4:
288
- raise ImageError("<div style='text-align: center;'>Standard rate limit exceeded. Check back later or you use the <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a>.</div>")
289
  rate_data['standard'].append(current_time)
290
 
291
  # Update rate limit file
 
65
 
66
  def _check_nsfw(self, attempts=1):
67
  """Check if image is NSFW using Hugging Face API."""
68
+ API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
69
+
70
+ # Prepare image data
71
+ temp_buffer = io.BytesIO()
72
+ self.image.save(temp_buffer, format='PNG')
73
+ temp_buffer.seek(0)
74
+
75
  try:
 
 
 
 
 
 
76
  response = requests.request("POST", API_URL, headers=headers, data=temp_buffer.getvalue())
77
+ json_response = json.loads(response.content.decode("utf-8"))
78
+ print(json_response)
 
 
79
  if "error" in json_response:
80
+ if attempts > 30:
81
+ raise ImageError("NSFW check failed after multiple attempts")
82
  time.sleep(json_response["estimated_time"])
83
+ return self._check_nsfw(attempts + 1)
84
 
85
+ nsfw_score = next((item['score'] for item in json_response if item['label'] == 'nsfw'), 0)
 
86
  print(f"NSFW Score: {nsfw_score}")
87
 
88
  if nsfw_score > 0.1:
89
+ return None
90
+
91
  return self
92
 
93
  except json.JSONDecodeError as e:
94
+ raise ImageError(f"NSFW check failed: Invalid response format - {str(e)}")
 
95
  except Exception as e:
 
96
  if attempts > 30:
97
  raise ImageError("NSFW check failed after multiple attempts")
98
+ return self._check_nsfw(attempts + 1)
99
 
100
  def _convert_color_mode(self):
101
  """Handle color mode conversion."""
 
140
 
141
  def process(self, min_size=320, max_size=4096, max_pixels=4194304):
142
  """Process image with all necessary transformations."""
143
+ result = (self
144
+ ._convert_color_mode()
145
+ ._resize_for_pixels(max_pixels)
146
+ ._ensure_dimensions(min_size, max_size)
147
+ ._check_nsfw()) # Add NSFW check before encoding
148
+
149
+ if result is None:
150
+ raise ImageError("Image <b>Not Appropriate</b>")
151
+
152
+ return result.encode()
153
 
154
  # Function to generate an image using Amazon Nova Canvas model
155
  class BedrockClient:
 
283
  # Check limits based on quality
284
  if quality == 'premium':
285
  if len(rate_data['premium']) >= 2:
286
+ raise ImageError("<div style='text-align: center;'>Premium rate limit exceeded. Check back later or use the <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a>.</div>")
287
  rate_data['premium'].append(current_time)
288
  else: # standard
289
  if len(rate_data['standard']) >= 4:
290
+ raise ImageError("<div style='text-align: center;'>Standard rate limit exceeded. Check back later or use the <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a>.</div>")
291
  rate_data['standard'].append(current_time)
292
 
293
  # Update rate limit file