Munaf1987 committed on
Commit
5287b40
·
verified ·
1 Parent(s): 214f5df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -54
app.py CHANGED
@@ -1,89 +1,68 @@
1
  import gradio as gr
2
- import numpy as np
3
  import torch
 
4
  from diffusers import StableDiffusionInpaintPipeline
5
- from PIL import Image
6
- from segment_anything import sam_model_registry, SamPredictor
7
- from huggingface_hub import hf_hub_download
8
  import spaces
9
 
10
- # Device configuration
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Step 1: Download SAM model checkpoint from Hugging Face
14
- checkpoint_path = hf_hub_download(
15
- repo_id="Munaf1987/sam",
16
- filename="sam_vit_h_4b8939.pth",
17
- )
18
-
19
- # Step 2: Load SAM model
20
- model_type = "vit_h"
21
- sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
22
- sam.to(device)
23
- predictor = SamPredictor(sam)
24
-
25
- # Step 3: Load Stable Diffusion Inpainting Pipeline
26
- sta_diff_model = "stabilityai/stable-diffusion-2-inpainting"
27
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
28
- sta_diff_model,
29
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
30
- )
31
- pipe = pipe.to(device)
32
 
33
- # Global variables to store selected pixels and mask
34
- selected_pixels = []
35
- generated_mask = None
36
 
37
  @spaces.GPU
38
- def generate_mask(input_image, evt: gr.SelectData):
39
- global generated_mask
40
- selected_pixels.append((evt.index[0], evt.index[1]))
41
 
42
- predictor.set_image(np.array(input_image))
43
- input_points = np.array(selected_pixels)
44
- input_labels = np.ones(input_points.shape[0])
45
 
46
- masks, _, _ = predictor.predict(
47
- point_coords=input_points,
48
- point_labels=input_labels,
49
- multimask_output=False,
50
- )
51
 
52
- mask = masks[0] * 255
53
- mask_image = Image.fromarray(mask.astype(np.uint8)).convert("L")
54
- generated_mask = mask_image
55
 
56
- return mask_image
 
57
 
58
- @spaces.GPU
59
- def inpaint(input_image, prompt):
60
- global generated_mask
61
- if input_image is None or generated_mask is None or prompt == "":
62
- return None
63
 
64
- mask_image_resized = generated_mask.resize(input_image.size)
 
65
 
 
66
  output = pipe(
67
  prompt=prompt,
68
  image=input_image,
69
- mask_image=mask_image_resized
70
  ).images[0]
71
 
72
  return output
73
 
74
  # Gradio UI
75
  with gr.Blocks() as demo:
76
- gr.Markdown("## Stable Diffusion Inpainting with SAM Mask Selection")
77
 
78
  with gr.Row():
79
- input_image = gr.Image(type="pil", label="Input Image", interactive=True)
80
- mask_display = gr.Image(type="pil", label="Generated Mask")
81
  output_image = gr.Image(type="pil", label="Output Image")
82
 
83
- prompt_text = gr.Textbox(label="Prompt", placeholder="Enter a prompt for inpainting")
84
  submit = gr.Button("Submit")
85
 
86
- input_image.select(generate_mask, inputs=input_image, outputs=mask_display)
87
- submit.click(inpaint, inputs=[input_image, prompt_text], outputs=output_image)
88
 
89
  demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ import numpy as np
4
  from diffusers import StableDiffusionInpaintPipeline
5
+ from PIL import Image, ImageDraw
6
+ from transformers import DetrImageProcessor, DetrForObjectDetection
 
7
  import spaces
8
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # Load the Stable Diffusion Inpainting model
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
13
+ "stabilityai/stable-diffusion-2-inpainting",
14
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
15
+ ).to(device)
 
16
 
17
+ # Load the DETR object detection model
18
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
19
+ detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
20
 
21
  @spaces.GPU
22
def detect_and_remove(input_image, prompt):
    """Mask every detected person in ``input_image`` and inpaint the masked area.

    Runs the DETR detector on the image, paints a filled rectangle over each
    bounding box labelled "person" into a single-channel mask, and hands the
    image, mask and prompt to the Stable Diffusion inpainting pipeline.

    Args:
        input_image: PIL image from the Gradio image input, or ``None`` when
            nothing was uploaded.
        prompt: Text prompt forwarded to the inpainting pipeline.

    Returns:
        The first image produced by the inpainting pipeline, or ``None`` when
        the image is missing or the prompt is empty/whitespace-only.

    Raises:
        gr.Error: If no "person" box passes the 0.9 score threshold. The click
            handler's output is an image component, so a plain string return
            could not be displayed; ``gr.Error`` surfaces the message in the UI.
    """
    if input_image is None or not (prompt and prompt.strip()):
        return None

    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping for the detector forward pass.
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([input_image.size[::-1]]).to(device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # White (255) marks the pixels to repaint; everything else stays black (0).
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Draw boxes for the "person" class only.
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] == "person":
            draw.rectangle([int(coord) for coord in box.tolist()], fill=255)

    # getbbox() is None when the mask has no non-zero pixel, i.e. nothing was drawn.
    if mask.getbbox() is None:
        raise gr.Error("No human detected.")

    # Inpainting
    output = pipe(
        prompt=prompt,
        image=input_image,
        mask_image=mask,
    ).images[0]

    return output
54
 
55
  # Gradio UI
56
  with gr.Blocks() as demo:
57
+ gr.Markdown("## Automatic Human Removal and Inpainting")
58
 
59
  with gr.Row():
60
+ input_image = gr.Image(type="pil", label="Input Image")
 
61
  output_image = gr.Image(type="pil", label="Output Image")
62
 
63
+ prompt_text = gr.Textbox(label="Prompt", placeholder="Example: Replace humans with cartoon background")
64
  submit = gr.Button("Submit")
65
 
66
+ submit.click(detect_and_remove, inputs=[input_image, prompt_text], outputs=output_image)
 
67
 
68
  demo.launch()