replacebg / app.py
Munaf1987's picture
Update app.py
855a558 verified
raw
history blame
2.97 kB
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
import spaces
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Diffusion pipeline that synthesises the replacement people from a prompt.
# Weights are loaded in half precision on GPU and full precision on CPU.
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",  # RealVisXL V4.0 checkpoint
    use_safetensors=True,
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
)
pipe = pipe.to(device)

# DETR processor and model used to locate people in the uploaded image.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
detector = detector.to(device)
@spaces.GPU
def detect_and_replace(input_image, prompt, negative_prompt=""):
    """Overwrite every detected person in ``input_image`` with a generated one.

    People are located with the module-level DETR ``detector``. For each
    person box a new image is generated by ``pipe`` from ``prompt``, resized
    to the box, and pasted over that box in a copy of the input image.

    Args:
        input_image: PIL image to process, or ``None``.
        prompt: Text description of the replacement person.
        negative_prompt: Optional text describing what to avoid.

    Returns:
        A new PIL image with each person box overwritten, or ``None`` when no
        image or only an empty/blank prompt was supplied.

    Raises:
        gr.Error: If no person is detected. The result feeds an image
            component, so the message is raised instead of returned as a
            string.
    """
    if input_image is None or not prompt or not prompt.strip():
        return None

    img_width, img_height = input_image.size

    inputs = processor(images=input_image, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = detector(**inputs)

    # Target sizes are given as (height, width), one row per image.
    target_sizes = torch.tensor([[img_height, img_width]]).to(device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    boxes = []
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] != "person":
            continue
        x1, y1, x2, y2 = (int(v) for v in box.tolist())
        # Clamp to the image so the paste stays in bounds, and drop boxes
        # that collapse to zero area (resize() rejects empty sizes).
        x1, x2 = max(0, x1), min(img_width, x2)
        y1, y2 = max(0, y1), min(img_height, y2)
        if x2 > x1 and y2 > y1:
            boxes.append((x1, y1, x2, y2))

    if not boxes:
        raise gr.Error("No human detected.")

    output_image = input_image.copy()
    for x1, y1, x2, y2 in boxes:
        # One generation per person so each replacement is a distinct image.
        generated_image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=512,
            height=768,
            guidance_scale=7.5,
            num_inference_steps=30,
            output_type="pil",
        ).images[0]

        # The 512x768 result is stretched to the box size; the aspect ratio
        # is not preserved.
        patch = generated_image.resize((x2 - x1, y2 - y1))
        output_image.paste(patch, (x1, y1))

    return output_image
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Replace Bride and Groom with Imaginary Realistic Characters")

    # Uploaded photo and generated result share one row.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil")
        output_image = gr.Image(label="Output Image", type="pil")

    prompt_text = gr.Textbox(
        placeholder="Describe the imaginary bride/groom",
        label="Prompt",
    )
    negative_prompt_text = gr.Textbox(
        placeholder="Optional negative prompt",
        label="Negative Prompt",
    )
    submit = gr.Button("Submit")

    # Clicking the button runs the detection + replacement handler.
    submit.click(
        fn=detect_and_replace,
        inputs=[input_image, prompt_text, negative_prompt_text],
        outputs=output_image,
    )

demo.launch()