jiuface commited on
Commit
3594837
·
verified ·
1 Parent(s): eebdcc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -66,8 +66,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
66
  return seed
67
 
68
 
69
- def get_depth_map(image, progress):
70
  original_size = (image.size[1], image.size[0])
 
71
  image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
72
  with torch.no_grad(), torch.autocast("cuda"):
73
  depth_map = depth_estimator(image).predicted_depth
@@ -83,10 +84,11 @@ def get_depth_map(image, progress):
83
  image = torch.cat([depth_map] * 3, dim=1)
84
  image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
85
  image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
 
86
  return image
87
 
88
 
89
- def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name, progress):
90
  print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
91
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
92
 
@@ -119,10 +121,11 @@ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, contr
119
  orginal_image = Image.fromarray(image)
120
 
121
  size = (orginal_image.size[0], orginal_image.size[1])
122
- print("image size", size)
123
  depth_image = get_depth_map(orginal_image, progress)
124
  generator = torch.Generator().manual_seed(seed)
125
  print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
 
126
  generated_image = pipe(
127
  prompt=prompt,
128
  negative_prompt=n_prompt,
@@ -134,14 +137,14 @@ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, contr
134
  generator=generator,
135
  image=depth_image
136
  ).images[0]
137
-
138
  if upload_to_s3:
139
  url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket, progress)
140
  result = {"status": "success", "url": url}
141
  else:
142
  result = {"status": "success", "message": "Image generated but not uploaded"}
143
 
144
- return [orginal_image, generated_image], json.dumps(result)
145
 
146
  with gr.Blocks() as demo:
147
 
@@ -172,7 +175,7 @@ with gr.Blocks() as demo:
172
 
173
 
174
  with gr.Column():
175
- images = ImageSlider(label="Generate images", type="pil", slider_color="pink")
176
  logs = gr.Textbox(label="logs")
177
 
178
  inputs = [
@@ -199,8 +202,8 @@ with gr.Blocks() as demo:
199
  ).then(
200
  fn=process,
201
  inputs=inputs,
202
- outputs=[images, logs],
203
- api_name=False
204
  )
205
 
206
  demo.queue().launch()
 
66
  return seed
67
 
68
 
69
+ def get_depth_map(image):
70
  original_size = (image.size[1], image.size[0])
71
+ print("start generate depth", original_size)
72
  image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
73
  with torch.no_grad(), torch.autocast("cuda"):
74
  depth_map = depth_estimator(image).predicted_depth
 
84
  image = torch.cat([depth_map] * 3, dim=1)
85
  image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
86
  image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
87
+ print("generate depth success")
88
  return image
89
 
90
 
91
+ def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
92
  print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
93
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
94
 
 
121
  orginal_image = Image.fromarray(image)
122
 
123
  size = (orginal_image.size[0], orginal_image.size[1])
124
+ print("gorinal image size", size)
125
  depth_image = get_depth_map(orginal_image, progress)
126
  generator = torch.Generator().manual_seed(seed)
127
  print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
128
+ print("run pipe")
129
  generated_image = pipe(
130
  prompt=prompt,
131
  negative_prompt=n_prompt,
 
137
  generator=generator,
138
  image=depth_image
139
  ).images[0]
140
+ print("geneate image success")
141
  if upload_to_s3:
142
  url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket, progress)
143
  result = {"status": "success", "url": url}
144
  else:
145
  result = {"status": "success", "message": "Image generated but not uploaded"}
146
 
147
+ return generated_image, json.dumps(result)
148
 
149
  with gr.Blocks() as demo:
150
 
 
175
 
176
 
177
  with gr.Column():
178
+ result = gr.Image(label="Generated Image")
179
  logs = gr.Textbox(label="logs")
180
 
181
  inputs = [
 
202
  ).then(
203
  fn=process,
204
  inputs=inputs,
205
+ outputs=[result, logs],
206
+ api_name="predict"
207
  )
208
 
209
  demo.queue().launch()