Munaf1987 committed
Commit b294284 · verified · 1 Parent(s): a08121a

Update app.py

Files changed (1)
  1. app.py +70 -76
app.py CHANGED
@@ -1,93 +1,87 @@
  import gradio as gr
- import torch
  import numpy as np
- from PIL import Image
  from diffusers import StableDiffusionInpaintPipeline
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, pipeline as hf_pipeline
  from segment_anything import sam_model_registry, SamPredictor
- import spaces
-
- # Device setup
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load zero-shot detector
- dino_id = "IDEA-Research/grounding-dino-tiny"
- dino_processor = AutoProcessor.from_pretrained(dino_id)
- dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_id).to(device)
-
- # Load SAM
- sam_checkpoint = "Munaf1987/sam"
- sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
  sam.to(device)
  predictor = SamPredictor(sam)

- # Load the inpainting pipeline
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-2-inpainting",
-     torch_dtype=torch.float16 if device == "cuda" else torch.float32
- ).to(device)
-
- @spaces.GPU
- def detect_and_mask(image, prompt="a person"):
-     inputs = dino_processor(images=image, text=prompt, return_tensors="pt").to(device)
-     with torch.no_grad():
-         outputs = dino_model(**inputs)
-     results = dino_processor.post_process_grounded_object_detection(
-         outputs, inputs.input_ids, box_threshold=0.3, text_threshold=0.25,
-         target_sizes=[image.size[::-1]]
      )
-     boxes = results[0]["boxes"]
-     height, width = image.size[1], image.size[0]
-
-     if len(boxes) == 0:
-         return None, None, "No humans detected."
-
-     # Build mask from boxes
-     mask_full = Image.new("L", image.size, 0)
-     for box in boxes:
-         x1, y1, x2, y2 = map(int, box)
-         mask_full.paste(255, (x1, y1, x2, y2))
-
-     predictor.set_image(np.array(image))
-     transformed = predictor.transform.apply_boxes(boxes.cpu().numpy(), image.size[::-1])
-     sam_masks, _, _ = predictor.predict(boxes=transformed, multimask_output=False)
-     combined = np.zeros_like(sam_masks[0], dtype=np.uint8)
-     for m in sam_masks:
-         combined = np.maximum(combined, m.astype(np.uint8))
-
-     mask_image = Image.fromarray(combined * 255).convert("L")
-     return image, mask_image, "Mask ready."
-
- @spaces.GPU
- def inpaint_background(image, mask, prompt="background"):
-     orig_size = image.size
-     # Resize to the inpainting model's resolution
-     img512 = image.resize((512, 512), Image.LANCZOS)
-     m512 = mask.resize((512, 512), Image.LANCZOS)
-     result = pipe(prompt=prompt, image=img512, mask_image=m512).images[0]
-     return result.resize(orig_size, Image.LANCZOS), "Background inpainted."
-
- @spaces.GPU
- def replace_with_cartoon(image, mask, prompt="a cartoon human in place"):
-     orig_size = image.size
-     img512 = image.resize((512, 512), Image.LANCZOS)
-     m512 = mask.resize((512, 512), Image.LANCZOS)
-     result = pipe(prompt=prompt, image=img512, mask_image=m512).images[0]
-     return result.resize(orig_size, Image.LANCZOS), "Replaced with cartoon."

  # Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("### Remove or Replace Humans with a Cartoon Character")
-
-     img = gr.Image(type="pil")
-     detect_prompt = gr.Textbox(label="Detection text prompt", value="a person")
-     detect_btn = gr.Button("Detect Humans")
-     mask_out = gr.Image(type="pil", label="Detected Mask")
-     status1 = gr.Textbox(interactive=False)
-
-     bg_btn = gr.Button("Remove Humans (Background Fill)")
-     cartoon_btn = gr.Button("Replace with Cartoon")
-     out_img = gr.Image(type="pil", label="Final Output")
-     status2 = gr.Textbox(interactive=False)
-
-     detect_btn.click(detect_and_mask, inputs=[img, detect_prompt], outputs=[img, mask_out, status1])
-     bg_btn.click(inpaint_background, inputs=[img, mask_out], outputs=[out_img, status2])
-     cartoon_btn.click(replace_with_cartoon, inputs=[img, mask_out], outputs=[out_img, status2])

  demo.launch()
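
The substantive fix in this commit is the SAM load: the old version above passes the Hub repo id "Munaf1987/sam" straight to sam_model_registry["vit_h"](checkpoint=...), but that argument must be a path to a local .pth file, so the checkpoint load fails before the UI ever starts. A minimal standalone sketch of the corrected load, assuming (as the new version below does) that the file sam_vit_h_4b8939.pth exists in that repo:

import torch
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry

# Download the checkpoint into the local HF cache and get its file path;
# repo_id and filename are taken from the new version of app.py below.
checkpoint_path = hf_hub_download(
    repo_id="Munaf1987/sam",
    filename="sam_vit_h_4b8939.pth",
)

# sam_model_registry["vit_h"] builds the ViT-H variant of SAM and loads the
# weights from the local file; a bare repo-id string would fail at this point.
sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
sam.to("cuda" if torch.cuda.is_available() else "cpu")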
 
  import gradio as gr
  import numpy as np
+ import torch
  from diffusers import StableDiffusionInpaintPipeline
+ from PIL import Image
  from segment_anything import sam_model_registry, SamPredictor
+ from huggingface_hub import hf_hub_download
+
+ # Device configuration
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Step 1: Download the SAM checkpoint from the Hugging Face Hub
+ checkpoint_path = hf_hub_download(
+     repo_id="Munaf1987/sam",          # ✅ Your model repo
+     filename="sam_vit_h_4b8939.pth",  # ✅ The exact filename in your repo
+ )

+ # Step 2: Load the SAM model
+ model_type = "vit_h"
+ sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
  sam.to(device)
  predictor = SamPredictor(sam)

+ # Step 3: Load the Stable Diffusion inpainting pipeline
+ sta_diff_model = "stabilityai/stable-diffusion-2-inpainting"
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
+     sta_diff_model,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+ )
+ pipe = pipe.to(device)
+
+ # Global list of the points the user has clicked so far
+ selected_pixels = []
+
+
+ def generate_mask(input_image, evt: gr.SelectData):
+     """Generate a mask from the user-selected points."""
+     selected_pixels.append((evt.index[0], evt.index[1]))  # evt.index is (x, y)
+
+     predictor.set_image(np.array(input_image))
+     input_points = np.array(selected_pixels)
+     input_labels = np.ones(input_points.shape[0])
+
+     masks, _, _ = predictor.predict(
+         point_coords=input_points,
+         point_labels=input_labels,
+         multimask_output=False,
      )
+
+     mask = masks[0] * 255
+     mask_image = Image.fromarray(mask.astype(np.uint8)).convert("L")
+     return mask_image
+
+
+ def inpaint(input_image, mask_image, prompt):
+     """Run the inpainting model."""
+     if input_image is None or mask_image is None or prompt == "":
+         return None
+
+     # ✅ Resize the mask to match the input image; keep the image at its original size
+     mask_image_resized = mask_image.resize(input_image.size)
+
+     output = pipe(
+         prompt=prompt,
+         image=input_image,
+         mask_image=mask_image_resized,
+     ).images[0]
+
+     return output

  # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("## Stable Diffusion Inpainting with SAM Mask Selection")
+
+     with gr.Row():
+         input_image = gr.Image(type="pil", label="Input Image", interactive=True)
+         mask_display = gr.Image(type="pil", label="Generated Mask")
+         output_image = gr.Image(type="pil", label="Output Image")
+
+     prompt_text = gr.Textbox(label="Prompt", placeholder="Enter a prompt for inpainting")
+     submit = gr.Button("Submit")
+
+     input_image.select(generate_mask, inputs=input_image, outputs=mask_display)
+     submit.click(inpaint, inputs=[input_image, mask_display, prompt_text], outputs=output_image)

  demo.launch()
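
For a quick check without the browser, the same flow can be driven headlessly. This is only a sketch under stated assumptions: the model loading mirrors the new app.py above, while test.jpg, the single centre click, the prompt, and out.png are hypothetical stand-ins. gr.SelectData events exist only inside the UI, so the predictor is called directly, the same way generate_mask does:

import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM exactly as app.py does.
checkpoint_path = hf_hub_download(repo_id="Munaf1987/sam", filename="sam_vit_h_4b8939.pth")
sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path).to(device)
predictor = SamPredictor(sam)

# Load the inpainting pipeline exactly as app.py does.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

img = Image.open("test.jpg").convert("RGB")  # hypothetical test image

# Simulate one click in the centre of the image, as generate_mask would record it.
predictor.set_image(np.array(img))
points = np.array([[img.width // 2, img.height // 2]])
masks, _, _ = predictor.predict(
    point_coords=points,
    point_labels=np.ones(points.shape[0]),
    multimask_output=False,
)
mask = Image.fromarray((masks[0] * 255).astype(np.uint8)).convert("L")

# Same call inpaint() makes: mask resized to the image, prompt is a placeholder.
result = pipe(prompt="a sunny beach", image=img, mask_image=mask.resize(img.size)).images[0]
result.save("out.png")

Two behaviours of the new version worth knowing: selected_pixels is module-global and never cleared, so clicks keep accumulating even after a new image is uploaded, until the process restarts; and inpaint() sends the full-resolution image to the pipeline, which is tuned for 512×512 inputs, so very large images cost extra memory and can shift output quality.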