wanghuging commited on
Commit
bb82241
·
1 Parent(s): a13309f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -43
app.py CHANGED
@@ -88,41 +88,65 @@ def resize_crop(image, size=512):
88
  return image
89
 
90
 
91
- async def predict(init_image, prompt, strength, steps, seed=1231231):
92
- init_image = None
93
- if init_image is not None:
94
- init_image = resize_crop(init_image)
95
- generator = torch.manual_seed(seed)
96
- last_time = time.time()
97
 
98
- if int(steps * strength) < 1:
99
- steps = math.ceil(1 / max(0.10, strength))
100
 
101
- results = i2i_pipe(
102
- prompt=prompt,
103
- image=init_image,
104
- generator=generator,
105
- num_inference_steps=steps,
106
- guidance_scale=0.0,
107
- strength=strength,
108
- width=512,
109
- height=512,
110
- output_type="pil",
111
- )
112
- else:
113
- generator = torch.manual_seed(seed)
114
- last_time = time.time()
115
- t2i_pipe.safety_checker = None
116
- t2i_pipe.requires_safety_checker = False
117
- results = t2i_pipe(
118
- prompt=prompt,
119
- generator=generator,
120
- num_inference_steps=steps,
121
- guidance_scale=0.0,
122
- width=512,
123
- height=512,
124
- output_type="pil",
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  print(f"Pipe took {time.time() - last_time} seconds")
127
  nsfw_content_detected = (
128
  results.nsfw_content_detected[0]
@@ -134,7 +158,6 @@ async def predict(init_image, prompt, strength, steps, seed=1231231):
134
  return Image.new("RGB", (512, 512))
135
  return results.images[0]
136
 
137
-
138
  css = """
139
  #container{
140
  margin: 0 auto;
@@ -217,20 +240,20 @@ with gr.Blocks(css=css) as demo:
217
  # ```
218
  # """
219
  # )
220
- image_input = None
221
- inputs = [image_input, prompt, strength, steps, seed]
222
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
223
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
224
  steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
225
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
226
  strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
227
- image_input.change(
228
- fn=lambda x: x,
229
- inputs=image_input,
230
- outputs=init_image_state,
231
- show_progress=False,
232
- queue=False,
233
- )
234
 
235
  demo.queue()
236
  demo.launch()
 
88
  return image
89
 
90
 
91
+ # async def predict(init_image, prompt, strength, steps, seed=1231231):
92
+ # # init_image = None
93
+ # if init_image is not None:
94
+ # init_image = resize_crop(init_image)
95
+ # generator = torch.manual_seed(seed)
96
+ # last_time = time.time()
97
 
98
+ # if int(steps * strength) < 1:
99
+ # steps = math.ceil(1 / max(0.10, strength))
100
 
101
+ # results = i2i_pipe(
102
+ # prompt=prompt,
103
+ # image=init_image,
104
+ # generator=generator,
105
+ # num_inference_steps=steps,
106
+ # guidance_scale=0.0,
107
+ # strength=strength,
108
+ # width=512,
109
+ # height=512,
110
+ # output_type="pil",
111
+ # )
112
+ # else:
113
+ # generator = torch.manual_seed(seed)
114
+ # last_time = time.time()
115
+ # t2i_pipe.safety_checker = None
116
+ # t2i_pipe.requires_safety_checker = False
117
+ # results = t2i_pipe(
118
+ # prompt=prompt,
119
+ # generator=generator,
120
+ # num_inference_steps=steps,
121
+ # guidance_scale=0.0,
122
+ # width=512,
123
+ # height=512,
124
+ # output_type="pil",
125
+ # )
126
+ # print(f"Pipe took {time.time() - last_time} seconds")
127
+ # nsfw_content_detected = (
128
+ # results.nsfw_content_detected[0]
129
+ # if "nsfw_content_detected" in results
130
+ # else False
131
+ # )
132
+ # if nsfw_content_detected:
133
+ # gr.Warning("NSFW content detected.")
134
+ # return Image.new("RGB", (512, 512))
135
+ # return results.images[0]
136
+ async def predict(prompt, strength, steps, seed=1231231):
137
+ generator = torch.manual_seed(seed)
138
+ last_time = time.time()
139
+ t2i_pipe.safety_checker = None
140
+ t2i_pipe.requires_safety_checker = False
141
+ results = t2i_pipe(
142
+ prompt=prompt,
143
+ generator=generator,
144
+ num_inference_steps=steps,
145
+ guidance_scale=0.0,
146
+ width=512,
147
+ height=512,
148
+ output_type="pil",
149
+ )
150
  print(f"Pipe took {time.time() - last_time} seconds")
151
  nsfw_content_detected = (
152
  results.nsfw_content_detected[0]
 
158
  return Image.new("RGB", (512, 512))
159
  return results.images[0]
160
 
 
161
  css = """
162
  #container{
163
  margin: 0 auto;
 
240
  # ```
241
  # """
242
  # )
243
+ inputs = [prompt, strength, steps, seed]
244
+ # inputs = [image_input, prompt, strength, steps, seed]
245
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
246
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
247
  steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
248
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
249
  strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
250
+ # image_input.change(
251
+ # fn=lambda x: x,
252
+ # inputs=image_input,
253
+ # outputs=init_image_state,
254
+ # show_progress=False,
255
+ # queue=False,
256
+ # )
257
 
258
  demo.queue()
259
  demo.launch()