vvaibhav commited on
Commit
452ea00
·
verified ·
1 Parent(s): b435a8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -8,16 +8,16 @@ from transformers import SamModel, SamProcessor
8
  from diffusers import StableDiffusionInpaintPipeline
9
  import io
10
 
11
- # Initialize SAM model and processor
12
- sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
13
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
14
 
15
- # Initialize Inpainting pipeline
16
  inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
17
  "runwayml/stable-diffusion-inpainting",
18
- torch_dtype=torch.float16
19
- ).to("cuda")
20
- inpaint_pipeline.enable_model_cpu_offload()
21
 
22
  def mask_to_rgba(mask):
23
  """
@@ -48,7 +48,7 @@ def generate_mask(image, input_points):
48
  points = [tuple(point) for point in input_points]
49
 
50
  # Prepare inputs for SAM
51
- inputs = sam_processor(image, points=points, return_tensors="pt").to("cuda")
52
 
53
  with torch.no_grad():
54
  outputs = sam_model(**inputs)
@@ -91,7 +91,7 @@ def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
91
 
92
  mask_image = Image.fromarray((mask * 255).astype(np.uint8))
93
 
94
- generator = torch.Generator("cuda").manual_seed(seed)
95
 
96
  try:
97
  result = inpaint_pipeline(
@@ -157,7 +157,7 @@ with gr.Blocks() as demo:
157
 
158
  with gr.Row():
159
  with gr.Column():
160
- image_input = gr.Image(label="Upload Image", type="pil")
161
  prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
162
  negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
163
  seed_input = gr.Number(label="Seed", value=42)
@@ -167,11 +167,17 @@ with gr.Blocks() as demo:
167
  masked_output = gr.Image(label="Selected Object Mask Overlay")
168
  augmented_output = gr.Image(label="Augmented Image")
169
 
170
- image_input.change(fn=lambda img: img, inputs=image_input, outputs=masked_output)
 
 
 
 
 
 
171
 
172
  process_button.click(
173
  fn=process,
174
- inputs=[image_input, gr.State(), prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
175
  outputs=[masked_output, augmented_output]
176
  )
177
 
 
8
  from diffusers import StableDiffusionInpaintPipeline
9
  import io
10
 
11
+ # Initialize SAM model and processor on CPU
12
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
13
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
14
 
15
+ # Initialize Inpainting pipeline on CPU
16
  inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
17
  "runwayml/stable-diffusion-inpainting",
18
+ torch_dtype=torch.float32
19
+ ).to("cpu")
20
+ # Removed model_cpu_offload as it's unnecessary for CPU
21
 
22
  def mask_to_rgba(mask):
23
  """
 
48
  points = [tuple(point) for point in input_points]
49
 
50
  # Prepare inputs for SAM
51
+ inputs = sam_processor(image, points=points, return_tensors="pt").to("cpu")
52
 
53
  with torch.no_grad():
54
  outputs = sam_model(**inputs)
 
91
 
92
  mask_image = Image.fromarray((mask * 255).astype(np.uint8))
93
 
94
+ generator = torch.Generator("cpu").manual_seed(seed)
95
 
96
  try:
97
  result = inpaint_pipeline(
 
157
 
158
  with gr.Row():
159
  with gr.Column():
160
+ image_input = gr.Image(label="Upload Image", type="pil", tool="point", interactive=True)
161
  prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
162
  negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
163
  seed_input = gr.Number(label="Seed", value=42)
 
167
  masked_output = gr.Image(label="Selected Object Mask Overlay")
168
  augmented_output = gr.Image(label="Augmented Image")
169
 
170
+ # Capture points selected on the image
171
+ points = gr.State([])
172
+
173
+ def update_points(selected_points):
174
+ return selected_points
175
+
176
+ image_input.select(update_points, inputs=image_input, outputs=points)
177
 
178
  process_button.click(
179
  fn=process,
180
+ inputs=[image_input, points, prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
181
  outputs=[masked_output, augmented_output]
182
  )
183