Munaf1987 commited on
Commit
59be1d1
·
verified ·
1 Parent(s): aa9406c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from diffusers import StableDiffusionInpaintPipeline
6
+ from segment_anything import sam_model_registry, SamPredictor
7
+ from groundingdino.util.inference import load_model, load_image, predict, annotate
8
+
9
+ # Device configuration
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ # Load Grounding DINO (human detection)
13
+ grounding_model = load_model("ShilongLiu/GroundingDINO-SwinB") # Public Hugging Face model
14
+
15
+ # Load SAM model
16
+ sam_checkpoint = "facebook/sam-vit-huge"
17
+ sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
18
+ sam.to(device)
19
+ predictor = SamPredictor(sam)
20
+
21
+ # Load Stable Diffusion Inpainting Pipeline
22
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
23
+ "stabilityai/stable-diffusion-2-inpainting",
24
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
25
+ )
26
+ pipe = pipe.to(device)
27
+
28
+ def detect_and_segment(input_image, prompt):
29
+ # Convert image to numpy
30
+ image_np = np.array(input_image)
31
+ predictor.set_image(image_np)
32
+
33
+ # Grounding DINO detection
34
+ boxes, logits, phrases = predict(
35
+ model=grounding_model,
36
+ image=input_image,
37
+ caption=prompt,
38
+ box_threshold=0.35,
39
+ text_threshold=0.25
40
+ )
41
+
42
+ if len(boxes) == 0:
43
+ return None, None, "No objects detected."
44
+
45
+ # Prepare mask
46
+ transformed_boxes = boxes * torch.tensor([input_image.width, input_image.height, input_image.width, input_image.height])
47
+ transformed_boxes = transformed_boxes.cpu().numpy()
48
+
49
+ input_points = []
50
+ input_labels = []
51
+ for box in transformed_boxes:
52
+ x_center = int((box[0] + box[2]) / 2)
53
+ y_center = int((box[1] + box[3]) / 2)
54
+ input_points.append([x_center, y_center])
55
+ input_labels.append(1)
56
+
57
+ masks, _, _ = predictor.predict(
58
+ point_coords=np.array(input_points),
59
+ point_labels=np.array(input_labels),
60
+ multimask_output=False,
61
+ )
62
+
63
+ final_mask = np.zeros_like(masks[0])
64
+ for mask in masks:
65
+ final_mask = np.logical_or(final_mask, mask)
66
+
67
+ final_mask = (final_mask * 255).astype(np.uint8)
68
+ mask_image = Image.fromarray(final_mask).convert("L")
69
+ return input_image, mask_image, "Mask generated successfully."
70
+
71
+ def inpaint(input_image, mask_image, inpaint_prompt):
72
+ if input_image is None or mask_image is None or inpaint_prompt == "":
73
+ return None
74
+
75
+ image_resized = input_image.resize((512, 512))
76
+ mask_resized = mask_image.resize((512, 512))
77
+
78
+ output = pipe(
79
+ prompt=inpaint_prompt,
80
+ image=image_resized,
81
+ mask_image=mask_resized
82
+ ).images[0]
83
+
84
+ # Resize back to original
85
+ output = output.resize(input_image.size)
86
+ return output
87
+
88
+ # Gradio UI
89
+ with gr.Blocks() as demo:
90
+ gr.Markdown("## Remove Humans and Replace with Cartoon / Imaginary Characters")
91
+
92
+ with gr.Row():
93
+ input_image = gr.Image(type="pil", label="Upload Image")
94
+ mask_display = gr.Image(type="pil", label="Generated Mask")
95
+ output_image = gr.Image(type="pil", label="Final Output")
96
+
97
+ detect_prompt = gr.Textbox(label="Detection Prompt", value="human", placeholder="What objects to detect? (e.g., human)")
98
+ inpaint_prompt = gr.Textbox(label="Inpainting Prompt", placeholder="What to replace with? (e.g., cartoon human, anime boy)")
99
+
100
+ detect_button = gr.Button("Detect and Generate Mask")
101
+ inpaint_button = gr.Button("Inpaint with Replacement")
102
+
103
+ detect_button.click(fn=detect_and_segment, inputs=[input_image, detect_prompt], outputs=[input_image, mask_display, gr.Textbox(label="Status")])
104
+ inpaint_button.click(fn=inpaint, inputs=[input_image, mask_display, inpaint_prompt], outputs=output_image)
105
+
106
+ demo.launch()