# Spaces: Runtime error
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from transformers import SamModel, SamProcessor | |
from diffusers import AutoPipelineForInpainting | |
import torch | |
# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Segment Anything (SAM) checkpoint shared by the model and its processor.
model_name = "facebook/sam-vit-huge"

# The model weights are moved to the selected device; the processor only
# prepares inputs and post-processes outputs, so it is loaded as-is.
model = SamModel.from_pretrained(model_name)
model = model.to(device)
processor = SamProcessor.from_pretrained(model_name)
def mask_to_rgb(mask):
    """Render a binary mask as an RGBA overlay image.

    Every element of ``mask`` equal to 1 (or ``True``) becomes a
    half-transparent green pixel ``[0, 255, 0, 127]``; every other element
    stays fully transparent black ``[0, 0, 0, 0]``.

    Args:
        mask: NumPy array whose shape gives the spatial layout of the overlay.

    Returns:
        A ``uint8`` array of shape ``mask.shape + (4,)``.
    """
    overlay_color = (0, 255, 0, 127)  # green, roughly 50% alpha
    rgba = np.zeros((*mask.shape, len(overlay_color)), dtype=np.uint8)
    selected = mask == 1
    rgba[selected] = overlay_color
    return rgba
def get_processed_inputs(image, points):
    """Segment ``image`` with SAM from prompt points and return the inverted best mask.

    Args:
        image: Image handed to the SAM processor.
        points: Prompt points, forwarded unchanged as the processor's
            ``input_points`` argument.

    Returns:
        NumPy array holding the bitwise complement (``~``) of the
        highest-scoring candidate mask for the first image / first prompt
        group. Presumably a boolean array in which ``True`` marks pixels
        *outside* the selected object -- confirm against the processor's
        ``post_process_masks`` output.
    """
    inputs = processor(image, input_points=points, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing is done on CPU tensors to rescale the predicted masks
    # using the original and reshaped input sizes recorded by the processor.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # Candidate masks for the first image and first prompt group.
    candidates = masks[0][0]

    # Rank only the scores that belong to those candidates. A flat argmax over
    # *all* scores can return an index that is out of range for (or refers to
    # a different prompt group than) ``candidates`` when more than one group
    # is present.
    scores = outputs.iou_scores.reshape(-1, candidates.shape[0])[0]
    best_mask = candidates[int(scores.argmax())]

    return ~best_mask.cpu().numpy()
_INPAINT_MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

# Lazily created, process-wide inpainting pipeline (see _get_inpaint_pipeline).
_inpaint_pipeline = None


def _get_inpaint_pipeline():
    """Return the shared inpainting pipeline, loading it on first use."""
    global _inpaint_pipeline
    if _inpaint_pipeline is None:
        # Half precision only on CUDA; full precision on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        # NOTE: model CPU offload is intentionally not enabled here. It exists
        # to shuttle sub-models between CPU and an accelerator, so enabling it
        # only when running on CPU (as before) cannot work.
        _inpaint_pipeline = AutoPipelineForInpainting.from_pretrained(
            _INPAINT_MODEL_ID,
            torch_dtype=dtype,
        ).to(device)
    return _inpaint_pipeline


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Inpaint the masked area of ``raw_image`` guided by a text prompt.

    Args:
        raw_image: Image passed to the pipeline as ``image``.
        input_mask: Array converted with ``Image.fromarray`` and passed to the
            pipeline as ``mask_image``.
        prompt: Text prompt describing the desired content.
        negative_prompt: Optional text describing what to avoid.
        seed: Seed applied via ``torch.manual_seed`` before generation.
        cfgs: Value passed to the pipeline as ``guidance_scale``.

    Returns:
        The first image produced by the inpainting pipeline.
    """
    mask_image = Image.fromarray(input_mask)

    # Reuse one pipeline across calls instead of re-loading the weights for
    # every request.
    pipeline = _get_inpaint_pipeline()

    # Seed after the pipeline is available so that the first (loading) call
    # and later (cached) calls start generation from the same RNG state.
    rand_gen = torch.manual_seed(seed)

    image = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    ).images[0]
    return image
def gradio_interface(image, points_json, positive_prompt, negative_prompt):
    """Gradio callback: segment the image from the selected points, then inpaint.

    Args:
        image: Input image array (converted with ``Image.fromarray``).
        points_json: Iterable of strokes; each stroke is read as a mapping
            with a ``'points'`` list whose items have ``'x'`` and ``'y'`` keys.
        positive_prompt: Prompt forwarded to ``inpaint``.
        negative_prompt: Negative prompt forwarded to ``inpaint``.

    Returns:
        Tuple ``(inpainted_image, mask_overlay)`` where the overlay comes from
        ``mask_to_rgb``.

    Raises:
        ValueError: If no image was supplied or no point was selected.
    """
    if image is None:
        raise ValueError("An input image is required.")

    # Only one image is segmented, so the points of every stroke are gathered
    # into a single per-image list. Passing one list per stroke would present
    # each stroke as a separate batch entry (and ragged strokes cannot be
    # stacked at all).
    image_points = [
        [point['x'], point['y']]
        for stroke in (points_json or [])
        for point in stroke['points']
    ]
    if not image_points:
        raise ValueError("Select at least one point on the image.")
    points = [image_points]

    # Work on a fixed 512x512 RGB copy of the input.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    mask = get_processed_inputs(raw_image, points)
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)
# Input widgets, listed in the positional order of gradio_interface's
# parameters: image, points, positive prompt, negative prompt.
# NOTE(review): gr.Image is documented with type values such as "numpy",
# "pil" and "filepath"; type="json" may be rejected at construction time --
# confirm against the installed Gradio version.
_input_components = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Image(
        type="json",
        label="Click to select points",
        tool="sketch",
        brush_radius=1,
        shape=(512, 512),
    ),
    gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
    gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here"),
]

# Output widgets, matching the (inpainted image, mask overlay) return tuple.
_output_components = [
    gr.Image(label="Inpainted Image"),
    gr.Image(label="Segmentation Mask"),
]

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=_output_components,
    title="Interactive Image Inpainting",
    description="Click on the image to select points for segmentation, provide prompts, and see the inpainted result.",
)

# Start the app and request a public share link.
iface.launch(share=True)