# Hugging Face file-view header (kept as a comment so this module parses):
#   author: ahmetyaylalioglu
#   commit: e2a5dbf (verified) - "Update app.py"
#   size:   3.31 kB
import functools

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import SamModel, SamProcessor
# Run on the GPU when CUDA is available; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Segment Anything checkpoint, plus the processor that prepares its inputs
# and post-processes its predicted masks.
model_name = "facebook/sam-vit-huge"
processor = SamProcessor.from_pretrained(model_name)
model = SamModel.from_pretrained(model_name).to(device)
def mask_to_rgb(mask):
    """Render a binary mask as an RGBA overlay image.

    Pixels where ``mask == 1`` become semi-transparent green, RGBA
    ``(0, 255, 0, 127)``. Every other pixel stays fully transparent
    ``(0, 0, 0, 0)``.

    Args:
        mask: numpy array of 0/1 (or boolean) values.

    Returns:
        ``uint8`` array shaped ``mask.shape + (4,)``.
    """
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    overlay[selected] = (0, 255, 0, 127)  # green at roughly half opacity
    return overlay
def _annotation_box(image, annotation):
    """Return the bounding box of the user's drawing in ``image`` pixel coordinates.

    Args:
        image: PIL image the box must refer to (only ``width``/``height`` are read).
        annotation: array-like drawing from the UI. The drawing is assumed to be
            stored in the alpha channel of an RGBA array, where pixels with
            alpha > 128 count as drawn.
            NOTE(review): confirm this matches what the Gradio editor returns.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as floats, or ``None`` when there is no
        annotation, it has no alpha channel, or nothing was drawn.
    """
    if annotation is None:
        return None
    annotation = np.asarray(annotation)
    # Without an alpha channel the strokes cannot be told apart from the picture.
    if annotation.ndim != 3 or annotation.shape[2] < 4:
        return None
    drawn = annotation[:, :, 3] > 128
    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        return None
    # numpy arrays are (rows, cols) = (height, width), whereas PIL sizes are
    # (width, height). Rescale in case the annotation and the image were
    # produced at different resolutions.
    scale_x = image.width / drawn.shape[1]
    scale_y = image.height / drawn.shape[0]
    return [
        float(cols.min() * scale_x),
        float(rows.min() * scale_y),
        float((cols.max() + 1) * scale_x),
        float((rows.max() + 1) * scale_y),
    ]


def get_processed_inputs(image, annotation):
    """Segment ``image`` with SAM, using the user's drawing as a box prompt.

    Args:
        image: RGB PIL image to segment.
        annotation: drawing from the UI (see ``_annotation_box``); may be ``None``.

    Returns:
        Boolean numpy array at the image's original size. It is ``True``
        everywhere EXCEPT inside the best-scoring SAM mask, i.e. it marks the
        region handed to the inpainter. When no usable drawing is found, SAM is
        run without any prompt.
    """
    box = _annotation_box(image, annotation)
    prompt_kwargs = {} if box is None else {"input_boxes": [[box]]}
    inputs = processor(images=image, return_tensors="pt", **prompt_kwargs).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0][0] holds the candidate masks for the first image and first
    # prompt. Pick the candidate SAM scores highest. The index is converted to
    # a Python int so a GPU tensor is never used to index a CPU tensor.
    best_idx = int(outputs.iou_scores[0, 0].argmax().item())
    best_mask = masks[0][0][best_idx]
    # Invert: True marks everything outside the selected object.
    return ~best_mask.cpu().numpy()
@functools.lru_cache(maxsize=1)
def _load_inpainting_pipeline():
    """Load the SDXL inpainting pipeline once and reuse it for every request."""
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        # Half precision only on GPU; keep full precision on CPU.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    # No enable_model_cpu_offload() here. That method offloads weights to the
    # CPU and moves submodules onto an accelerator on demand, so it is
    # meaningless (and fails) on a CPU-only host.
    return pipeline.to(device)


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Inpaint the masked area of ``raw_image`` guided by a text prompt.

    Args:
        raw_image: PIL image to edit.
        input_mask: mask selecting the area to repaint. May be a numpy array
            (boolean or integer) or a PIL image. Per the diffusers convention,
            white/True pixels are repainted.
        prompt: text describing what to paint into the masked area.
        negative_prompt: optional text describing what to avoid.
        seed: seed for a private ``torch.Generator``, making results
            reproducible without reseeding torch's global RNG.
        cfgs: guidance scale, passed to the pipeline as ``guidance_scale``.

    Returns:
        The first image produced by the inpainting pipeline.
    """
    if isinstance(input_mask, Image.Image):
        mask_image = input_mask
    else:
        mask_array = np.asarray(input_mask)
        if mask_array.dtype == np.bool_:
            # Build an explicit 0/255 greyscale ("L") mask instead of relying on
            # PIL's 1-bit mode for boolean arrays.
            mask_array = mask_array.astype(np.uint8) * 255
        mask_image = Image.fromarray(mask_array)
    rand_gen = torch.Generator().manual_seed(seed)
    pipeline = _load_inpainting_pipeline()
    result = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    )
    return result.images[0]
def gradio_interface(image, annotation, positive_prompt, negative_prompt):
    """Gradio callback: segment, inpaint, and return the result plus a mask preview.

    Args:
        image: uploaded picture as a numpy array (``None`` if nothing was uploaded).
        annotation: drawing from the editor component, forwarded to SAM.
        positive_prompt: text describing what to paint into the masked area.
        negative_prompt: text describing what to avoid.

    Returns:
        ``(inpainted image, RGBA mask overlay)`` for the two output components.

    Raises:
        gr.Error: when no input image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    # Segmentation and inpainting both operate on the same 512x512 RGB copy.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    mask = get_processed_inputs(raw_image, annotation)
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)
# Web UI: an upload, a drawing canvas, and two prompts in; the inpainted image
# and the mask overlay out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        # RGBA so the annotation arrives with an alpha channel, which
        # get_processed_inputs reads (channel 3) to locate the drawing.
        # ("output" is not a gr.Image parameter and was dropped.)
        gr.Image(
            tool="editor",
            label="Draw on the image",
            image_mode="RGBA",
            shape=(512, 512),
        ),
        gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
        gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here"),
    ],
    outputs=[
        gr.Image(label="Inpainted Image"),
        gr.Image(label="Segmentation Mask"),
    ],
    title="Interactive Image Inpainting",
    description="Draw on the image to select areas for segmentation, provide prompts, and see the inpainted result.",
)
iface.launch(share=True)