Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -66,8 +66,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
66 |
return seed
|
67 |
|
68 |
|
69 |
-
def get_depth_map(image
|
70 |
original_size = (image.size[1], image.size[0])
|
|
|
71 |
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
|
72 |
with torch.no_grad(), torch.autocast("cuda"):
|
73 |
depth_map = depth_estimator(image).predicted_depth
|
@@ -83,10 +84,11 @@ def get_depth_map(image, progress):
|
|
83 |
image = torch.cat([depth_map] * 3, dim=1)
|
84 |
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
85 |
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
|
|
86 |
return image
|
87 |
|
88 |
|
89 |
-
def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name
|
90 |
print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
|
91 |
connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
|
92 |
|
@@ -119,10 +121,11 @@ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, contr
|
|
119 |
orginal_image = Image.fromarray(image)
|
120 |
|
121 |
size = (orginal_image.size[0], orginal_image.size[1])
|
122 |
-
print("image size", size)
|
123 |
depth_image = get_depth_map(orginal_image, progress)
|
124 |
generator = torch.Generator().manual_seed(seed)
|
125 |
print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
|
|
|
126 |
generated_image = pipe(
|
127 |
prompt=prompt,
|
128 |
negative_prompt=n_prompt,
|
@@ -134,14 +137,14 @@ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, contr
|
|
134 |
generator=generator,
|
135 |
image=depth_image
|
136 |
).images[0]
|
137 |
-
|
138 |
if upload_to_s3:
|
139 |
url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket, progress)
|
140 |
result = {"status": "success", "url": url}
|
141 |
else:
|
142 |
result = {"status": "success", "message": "Image generated but not uploaded"}
|
143 |
|
144 |
-
return
|
145 |
|
146 |
with gr.Blocks() as demo:
|
147 |
|
@@ -172,7 +175,7 @@ with gr.Blocks() as demo:
|
|
172 |
|
173 |
|
174 |
with gr.Column():
|
175 |
-
|
176 |
logs = gr.Textbox(label="logs")
|
177 |
|
178 |
inputs = [
|
@@ -199,8 +202,8 @@ with gr.Blocks() as demo:
|
|
199 |
).then(
|
200 |
fn=process,
|
201 |
inputs=inputs,
|
202 |
-
outputs=[
|
203 |
-
api_name=
|
204 |
)
|
205 |
|
206 |
demo.queue().launch()
|
|
|
66 |
return seed
|
67 |
|
68 |
|
69 |
+
def get_depth_map(image):
|
70 |
original_size = (image.size[1], image.size[0])
|
71 |
+
print("start generate depth", original_size)
|
72 |
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
|
73 |
with torch.no_grad(), torch.autocast("cuda"):
|
74 |
depth_map = depth_estimator(image).predicted_depth
|
|
|
84 |
image = torch.cat([depth_map] * 3, dim=1)
|
85 |
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
86 |
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
87 |
+
print("generate depth success")
|
88 |
return image
|
89 |
|
90 |
|
91 |
+
def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
|
92 |
print("upload_image_to_s3", account_id, access_key, secret_key, bucket_name)
|
93 |
connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
|
94 |
|
|
|
121 |
orginal_image = Image.fromarray(image)
|
122 |
|
123 |
size = (orginal_image.size[0], orginal_image.size[1])
|
124 |
+
print("gorinal image size", size)
|
125 |
depth_image = get_depth_map(orginal_image, progress)
|
126 |
generator = torch.Generator().manual_seed(seed)
|
127 |
print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
|
128 |
+
print("run pipe")
|
129 |
generated_image = pipe(
|
130 |
prompt=prompt,
|
131 |
negative_prompt=n_prompt,
|
|
|
137 |
generator=generator,
|
138 |
image=depth_image
|
139 |
).images[0]
|
140 |
+
print("geneate image success")
|
141 |
if upload_to_s3:
|
142 |
url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket, progress)
|
143 |
result = {"status": "success", "url": url}
|
144 |
else:
|
145 |
result = {"status": "success", "message": "Image generated but not uploaded"}
|
146 |
|
147 |
+
return generated_image, json.dumps(result)
|
148 |
|
149 |
with gr.Blocks() as demo:
|
150 |
|
|
|
175 |
|
176 |
|
177 |
with gr.Column():
|
178 |
+
result = gr.Image(label="Generated Image")
|
179 |
logs = gr.Textbox(label="logs")
|
180 |
|
181 |
inputs = [
|
|
|
202 |
).then(
|
203 |
fn=process,
|
204 |
inputs=inputs,
|
205 |
+
outputs=[result, logs],
|
206 |
+
api_name="predict"
|
207 |
)
|
208 |
|
209 |
demo.queue().launch()
|