import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
import spaces

device = "cuda" if torch.cuda.is_available() else "cpu"
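
# Text-to-image pipeline that synthesizes the replacement characters.
# fp16 keeps GPU memory in check; fp32 is the safe fallback on CPU.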
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
).to(device)
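
# DETR object detector used to locate people in the uploaded photo.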
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)


@spaces.GPU
def detect_and_replace(input_image, prompt, negative_prompt=""):
    if input_image is None or not prompt.strip():
        return None
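
    # Run DETR on the input image and keep only confident detections.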
    image_np = np.array(input_image)
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = detector(**inputs)
    target_sizes = torch.tensor([image_np.shape[:2]]).to(device)

    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Binary mask of the detected person boxes (white = person). The paste
    # step below does not consume it, but it is useful if this app is later
    # switched to an inpainting pipeline.
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)

    boxes = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] == "person":
            box = [int(i) for i in box.tolist()]
            boxes.append(box)
            draw.rectangle(box, fill=255)

    if not boxes:
        raise gr.Error("No human detected.")
    output_image = input_image.copy()

    for box in boxes:
        x1, y1, x2, y2 = box
        width, height = x2 - x1, y2 - y1
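
        # Generate at a fixed 512x768 resolution, then resize to the detected
        # box; very wide or very narrow boxes may therefore look stretched.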
        generated_image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=512,
            height=768,
            guidance_scale=7.5,
            num_inference_steps=30,
            output_type="pil",
        ).images[0]

        resized_generated = generated_image.resize((width, height))
        output_image.paste(resized_generated, (x1, y1))

    return output_image
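

# Minimal Gradio UI: upload a photo, describe the replacement, get the edited image.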
with gr.Blocks() as demo:
    gr.Markdown("## Replace Bride and Groom with Imaginary Realistic Characters")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    prompt_text = gr.Textbox(label="Prompt", placeholder="Describe the imaginary bride/groom")
    negative_prompt_text = gr.Textbox(label="Negative Prompt", placeholder="Optional negative prompt")
    submit = gr.Button("Submit")

    submit.click(
        detect_and_replace,
        inputs=[input_image, prompt_text, negative_prompt_text],
        outputs=output_image,
    )

demo.launch()