yigitekin commited on
Commit
b8231cb
·
verified ·
1 Parent(s): 58602ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -14
app.py CHANGED
@@ -9,14 +9,8 @@ import cv2
9
  import numpy as np
10
  import argparse
11
 
12
- # Parse command line arguments
13
- parser = argparse.ArgumentParser()
14
- parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file")
15
- parser.add_argument("--share", action="store_true", help="Share the interface if provided")
16
- args = parser.parse_args()
17
-
18
  # Load configuration and models
19
- config = OmegaConf.load(args.config)
20
  sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
21
  "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
22
  )
@@ -27,7 +21,7 @@ clipaway = CLIPAway(
27
  alpha_clip_path=config.alpha_clip_ckpt_pth,
28
  config=config,
29
  alpha_clip_id=config.alpha_clip_id,
30
- device=config.device,
31
  num_tokens=4
32
  )
33
 
@@ -37,7 +31,7 @@ def dilate_mask(mask, kernel_size=5, iterations=5):
37
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
38
  return Image.fromarray(mask)
39
 
40
- @spaces.GPU(duration=20)
41
  def remove_obj(image, uploaded_mask, seed):
42
  image_pil = image["image"].resize((512, 512), Image.ANTIALIAS)
43
  mask = dilate_mask(uploaded_mask)
@@ -98,8 +92,4 @@ with gr.Blocks(theme="gradio/monochrome") as demo:
98
  outputs=result_image
99
  )
100
 
101
- # Launch the interface without caching
102
- if args.share:
103
- demo.launch(share=True)
104
- else:
105
- demo.launch()
 
9
  import numpy as np
10
  import argparse
11
 
 
 
 
 
 
 
12
  # Load configuration and models
13
+ config = OmegaConf.load("config/inference_config.yaml")
14
  sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
15
  "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
16
  )
 
21
  alpha_clip_path=config.alpha_clip_ckpt_pth,
22
  config=config,
23
  alpha_clip_id=config.alpha_clip_id,
24
+ device="cuda",
25
  num_tokens=4
26
  )
27
 
 
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
+ @spaces.GPU
35
  def remove_obj(image, uploaded_mask, seed):
36
  image_pil = image["image"].resize((512, 512), Image.ANTIALIAS)
37
  mask = dilate_mask(uploaded_mask)
 
92
  outputs=result_image
93
  )
94
 
95
+ demo.launch()