jiuface commited on
Commit
fc2d50f
·
1 Parent(s): d74c6d6
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -2,9 +2,8 @@ import spaces
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
- from diffusers import DiffusionPipeline
6
  import torch
7
- import random
8
  from diffusers import (
9
  ControlNetModel,
10
  DiffusionPipeline,
@@ -62,10 +61,10 @@ def get_depth_map(image):
62
  image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
63
  with torch.no_grad(), torch.autocast("cuda"):
64
  depth_map = depth_estimator(image).predicted_depth
65
-
66
  depth_map = torch.nn.functional.interpolate(
67
  depth_map.unsqueeze(1),
68
- size=(1024, 1024),
69
  mode="bicubic",
70
  align_corners=False,
71
  )
@@ -80,20 +79,21 @@ def get_depth_map(image):
80
 
81
 
82
  @spaces.GPU(enable_queue=True)
83
- def process(orginal_image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed):
84
 
85
  if image_url:
86
  orginal_image = load_image(image_url)
87
-
88
- width = 1024
89
- height = 1024
90
- depth_image = get_depth_map(orginal_image.resize((1024, 1024)))
 
91
  generator = torch.Generator().manual_seed(seed)
92
- generated_image = self.pipe(
93
  prompt=prompt,
94
  negative_prompt=n_prompt,
95
- width=width,
96
- height=height,
97
  guidance_scale=guidance_scale,
98
  num_inference_steps=num_steps,
99
  strength=control_strength,
 
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
+ import PIL.Image
6
  import torch
 
7
  from diffusers import (
8
  ControlNetModel,
9
  DiffusionPipeline,
 
61
  image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
62
  with torch.no_grad(), torch.autocast("cuda"):
63
  depth_map = depth_estimator(image).predicted_depth
64
+ size = (image.shape[-2], image.shape[-1])
65
  depth_map = torch.nn.functional.interpolate(
66
  depth_map.unsqueeze(1),
67
+ size=size,
68
  mode="bicubic",
69
  align_corners=False,
70
  )
 
79
 
80
 
81
  @spaces.GPU(enable_queue=True)
82
+ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed):
83
 
84
  if image_url:
85
  orginal_image = load_image(image_url)
86
+ else:
87
+ orginal_image = PIL.Image.fromarray(image)
88
+
89
+ size = (orginal_image.size[0], orginal_image.size[1])
90
+ depth_image = get_depth_map(orginal_image)
91
  generator = torch.Generator().manual_seed(seed)
92
+ generated_image = pipe(
93
  prompt=prompt,
94
  negative_prompt=n_prompt,
95
+ width=size[0],
96
+ height=size[1],
97
  guidance_scale=guidance_scale,
98
  num_inference_steps=num_steps,
99
  strength=control_strength,