pablo committed
Commit 35ffdbd · 1 Parent(s): a63d2a4

fix depth estimation

Files changed (1): app.py (+16 -8)
app.py CHANGED

@@ -6,6 +6,7 @@ import diffuserslocal.src.diffusers as diffusers
 from share_btn import community_icon_html, loading_icon_html, share_js
 from diffuserslocal.src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d_inpaint import StableDiffusionLDM3DInpaintPipeline
 from PIL import Image
+import numpy as np
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -33,7 +34,8 @@ else:
 
 
 def estimate_depth(image: Image) -> Image:
-
+    image= image.resize((384,384))
+    image = np.array(image)
     input_batch = transform(image).to(device)
 
     with torch.no_grad():
@@ -41,13 +43,18 @@ def estimate_depth(image: Image) -> Image:
 
         prediction = torch.nn.functional.interpolate(
             prediction.unsqueeze(1),
-            size=image.size,
+            size=image.shape[:2],
            mode="bicubic",
             align_corners=False,
         ).squeeze()
 
-    return Image.fromarray(prediction.cpu().numpy())
-
+    output = prediction.cpu().numpy()
+
+    output= 255 * output/np.max(output)
+
+    return Image.fromarray(output.astype("uint8"))
+
+
 def read_content(file_path: str) -> str:
     """read the content of target file
     """
@@ -60,6 +67,9 @@ def predict(dict, depth, prompt="", negative_prompt="", guidance_scale=7.5, step
     if negative_prompt == "":
         negative_prompt = None
     scheduler_class_name = scheduler.split("-")[0]
+
+    if (depth is None):
+        depth_image = estimate_depth(image)
 
     scheduler = getattr(diffusers, scheduler_class_name)
     pipe.scheduler = scheduler.from_pretrained("Intel/ldm3d-4c", subfolder="scheduler")
@@ -117,6 +127,8 @@ with image_blocks as demo:
         with gr.Column():
             image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload",height=400)
             depth = gr.Image(source='upload', elem_id="depth_upload", type="pil", label="Upload",height=400)
+            print(depth)
+
             with gr.Row(elem_id="prompt-container", mobile_collapse=False, equal_height=True):
                 with gr.Row():
                     prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
@@ -140,10 +152,6 @@ with image_blocks as demo:
                 community_icon = gr.HTML(community_icon_html)
                 loading_icon = gr.HTML(loading_icon_html)
                 share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)
-
-
-    if (depth is None):
-        depth = estimate_depth(image)
 
     btn.click(fn=predict, inputs=[image, depth, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, depth_out, share_btn_container], api_name='run')
     prompt.submit(fn=predict, inputs=[image, depth, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, depth_out, share_btn_container])
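For readers following the change, below is a minimal, self-contained sketch of the depth-estimation path this commit converges on. It assumes the Space drives a MiDaS model and its matching transform loaded through torch.hub; the names midas, transform, and the DPT_Hybrid checkpoint are assumptions for illustration, since the diff only shows the edited fragments of estimate_depth.

# A minimal sketch (not the Space's exact code) of the depth path after this fix.
# Assumption: a MiDaS backbone and its companion transform loaded via torch.hub.
import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

midas = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid").to(device).eval()
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform


def estimate_depth(image: Image.Image) -> Image.Image:
    # Resize and convert to a NumPy array so the MiDaS transform accepts the
    # input and image.shape[:2] is defined, mirroring the two lines this commit adds.
    image = image.resize((384, 384))
    image = np.array(image)

    input_batch = transform(image).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)

        # Upsample the raw prediction back to the (resized) input resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    # Scale to 0-255 and return an 8-bit PIL image, as in the new return path.
    output = prediction.cpu().numpy()
    output = 255 * output / np.max(output)
    return Image.fromarray(output.astype("uint8"))

This mirrors the guard the commit adds to predict, which falls back to estimate_depth(image) (assigned to depth_image) when no depth map is uploaded, so the LDM3D inpainting pipeline still receives a depth channel.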