Update app.py
Browse files
app.py
CHANGED
@@ -1,89 +1,68 @@
|
|
1 |
import gradio as gr
|
2 |
-
import numpy as np
|
3 |
import torch
|
|
|
4 |
from diffusers import StableDiffusionInpaintPipeline
|
5 |
-
from PIL import Image
|
6 |
-
from
|
7 |
-
from huggingface_hub import hf_hub_download
|
8 |
import spaces
|
9 |
|
10 |
-
# Device configuration
|
11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
|
13 |
-
#
|
14 |
-
checkpoint_path = hf_hub_download(
|
15 |
-
repo_id="Munaf1987/sam",
|
16 |
-
filename="sam_vit_h_4b8939.pth",
|
17 |
-
)
|
18 |
-
|
19 |
-
# Step 2: Load SAM model
|
20 |
-
model_type = "vit_h"
|
21 |
-
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
|
22 |
-
sam.to(device)
|
23 |
-
predictor = SamPredictor(sam)
|
24 |
-
|
25 |
-
# Step 3: Load Stable Diffusion Inpainting Pipeline
|
26 |
-
sta_diff_model = "stabilityai/stable-diffusion-2-inpainting"
|
27 |
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
28 |
-
|
29 |
-
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
30 |
-
)
|
31 |
-
pipe = pipe.to(device)
|
32 |
|
33 |
-
#
|
34 |
-
|
35 |
-
|
36 |
|
37 |
@spaces.GPU
|
38 |
-
def
|
39 |
-
|
40 |
-
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
input_labels = np.ones(input_points.shape[0])
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
point_labels=input_labels,
|
49 |
-
multimask_output=False,
|
50 |
-
)
|
51 |
|
52 |
-
|
53 |
-
mask_image = Image.fromarray(mask.astype(np.uint8)).convert("L")
|
54 |
-
generated_mask = mask_image
|
55 |
|
56 |
-
|
|
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
|
|
65 |
|
|
|
66 |
output = pipe(
|
67 |
prompt=prompt,
|
68 |
image=input_image,
|
69 |
-
mask_image=
|
70 |
).images[0]
|
71 |
|
72 |
return output
|
73 |
|
74 |
# Gradio UI
|
75 |
with gr.Blocks() as demo:
|
76 |
-
gr.Markdown("##
|
77 |
|
78 |
with gr.Row():
|
79 |
-
input_image = gr.Image(type="pil", label="Input Image"
|
80 |
-
mask_display = gr.Image(type="pil", label="Generated Mask")
|
81 |
output_image = gr.Image(type="pil", label="Output Image")
|
82 |
|
83 |
-
prompt_text = gr.Textbox(label="Prompt", placeholder="
|
84 |
submit = gr.Button("Submit")
|
85 |
|
86 |
-
|
87 |
-
submit.click(inpaint, inputs=[input_image, prompt_text], outputs=output_image)
|
88 |
|
89 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import torch
|
3 |
+
import numpy as np
|
4 |
from diffusers import StableDiffusionInpaintPipeline
|
5 |
+
from PIL import Image, ImageDraw
|
6 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
|
|
7 |
import spaces
|
8 |
|
|
|
9 |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stable Diffusion 2 inpainting pipeline: float16 on GPU, float32 on CPU.
if device == "cuda":
    _sd_dtype = torch.float16
else:
    _sd_dtype = torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=_sd_dtype,
)
pipe = pipe.to(device)

# DETR object detector and its image processor, loaded from the same checkpoint.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
detector = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
detector = detector.to(device)
|
20 |
|
21 |
@spaces.GPU
def detect_and_remove(input_image, prompt):
    """Detect people in an image and inpaint over them, guided by a prompt.

    The DETR detector finds objects in the image; the bounding box of every
    detection labelled "person" (score threshold 0.9) is filled white on an
    otherwise black mask, and the Stable Diffusion inpainting pipeline then
    repaints the masked regions according to ``prompt``.

    Args:
        input_image: PIL image supplied by the Gradio image component, or
            ``None`` when nothing has been uploaded.
        prompt: Text prompt describing what should replace the people.

    Returns:
        The inpainted PIL image, or ``None`` when the image is missing or
        the prompt is empty / whitespace-only.

    Raises:
        gr.Error: If no person is detected. Returning a plain string to an
            image output would be treated as a file path by Gradio, so the
            message is surfaced through Gradio's error mechanism instead.
    """
    if input_image is None or not prompt or not prompt.strip():
        return None

    # Uploaded PNGs may be RGBA, grayscale or palette images; both models are
    # fed a 3-channel image here.
    input_image = input_image.convert("RGB")

    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        outputs = detector(**inputs)

    # Post-processing expects (height, width); PIL's ``size`` is (width, height).
    width, height = input_image.size
    target_sizes = torch.tensor([[height, width]], device=device)

    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Black mask = keep, white rectangles = regions to repaint.
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Draw boxes for the "person" class only.
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] != "person":
            continue
        draw.rectangle([int(coord) for coord in box.tolist()], fill=255)

    # getbbox() is None exactly when the mask has no non-zero pixel.
    if mask.getbbox() is None:
        raise gr.Error("No human detected.")

    # Inpainting
    output = pipe(
        prompt=prompt,
        image=input_image,
        mask_image=mask,
    ).images[0]

    return output
|
54 |
|
55 |
# Gradio UI: an image and a prompt go in, the inpainted image comes out.
with gr.Blocks() as demo:
    gr.Markdown("## Automatic Human Removal and Inpainting")

    # Source and result images are placed in the same row.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    prompt_text = gr.Textbox(
        label="Prompt",
        placeholder="Example: Replace humans with cartoon background",
    )
    submit = gr.Button("Submit")

    # Clicking the button runs detection + inpainting on the current inputs.
    submit.click(
        fn=detect_and_remove,
        inputs=[input_image, prompt_text],
        outputs=output_image,
    )

demo.launch()
|