ahmetyaylalioglu's picture
Update app.py
9b2337b verified
raw
history blame
2.35 kB
import gradio as gr
from PIL import Image
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import AutoPipelineForInpainting
import torch
# Inference for the segmentation model is pinned to the CPU, even when CUDA
# is available on the host.
device = "cpu"

# A single checkpoint id is shared by the SAM model and its processor so the
# two can never drift apart.
model_name = "facebook/sam-vit-huge"

# The processor prepares images/point prompts and post-processes raw masks.
processor = SamProcessor.from_pretrained(model_name)

# Load the SAM weights, then place the model on the selected device.
model = SamModel.from_pretrained(model_name)
model = model.to(device)
def mask_to_rgb(mask):
    """Render a mask as an RGBA overlay.

    Every element of ``mask`` equal to 1 becomes a half-transparent green
    pixel ``(0, 255, 0, 127)``; every other element stays fully transparent
    black ``(0, 0, 0, 0)``.

    Args:
        mask: Array whose shape gives the overlay's leading dimensions.

    Returns:
        A ``uint8`` NumPy array of shape ``mask.shape + (4,)``.
    """
    highlight = np.array([0, 255, 0, 127], dtype=np.uint8)
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    overlay[selected] = highlight
    return overlay
def get_processed_inputs(image, points):
    """Segment ``image`` with SAM from point prompts and return the inverted mask.

    Args:
        image: Image handed to the SAM processor.
        points: Prompt coordinates encoded as ``"x,y|x,y|..."``. Empty
            segments (for example from a trailing ``|``) are ignored.

    Returns:
        NumPy array holding the bitwise NOT of the predicted mask that has
        the highest IoU score, i.e. everything *except* the prompted object
        is selected.

    Raises:
        ValueError: If ``points`` contains no coordinates, a coordinate is
            not an integer, or a point does not have exactly two values.
    """
    prompt_points = [
        [int(value) for value in point.split(',')]
        for point in points.split('|')
        if point.strip()
    ]
    if not prompt_points:
        raise ValueError("points must contain at least one 'x,y' pair")
    if any(len(point) != 2 for point in prompt_points):
        raise ValueError("each point must have exactly two coordinates: 'x,y'")

    # One image -> one outer entry holding every point. Wrapping each point
    # in its own outer list would be read as one batch entry per point, which
    # does not match the single image or the masks[0][0] indexing below.
    input_points = [prompt_points]

    inputs = processor(image, input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # First image, first prompt; pick the candidate with the best IoU score.
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    # Invert so the caller receives the background rather than the object.
    return ~best_mask.cpu().numpy()
# Hugging Face checkpoint used for SDXL inpainting.
_INPAINT_CHECKPOINT = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
# Lazily created pipeline, shared by every inpaint() call.
_inpaint_pipeline = None


def _get_inpaint_pipeline():
    """Return the shared inpainting pipeline, loading it on first use."""
    global _inpaint_pipeline
    if _inpaint_pipeline is None:
        has_cuda = torch.cuda.is_available()
        pipeline = AutoPipelineForInpainting.from_pretrained(
            _INPAINT_CHECKPOINT,
            # Half precision is only requested when a GPU is present; on a
            # CPU-only host the pipeline is loaded in float32 instead.
            torch_dtype=torch.float16 if has_cuda else torch.float32,
        )
        if has_cuda:
            # Model CPU offload moves sub-models onto the accelerator on
            # demand, so it is only enabled when one exists.
            pipeline.enable_model_cpu_offload()
        _inpaint_pipeline = pipeline
    return _inpaint_pipeline


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Inpaint the masked region of ``raw_image`` according to ``prompt``.

    Args:
        raw_image: Source image passed to the pipeline as ``image``.
        input_mask: Array converted to a PIL image and used as ``mask_image``.
        prompt: Text describing the content to generate.
        negative_prompt: Optional text describing what to avoid.
        seed: Seed applied through ``torch.manual_seed`` for reproducibility.
        cfgs: Guidance scale forwarded to the pipeline.

    Returns:
        The first image produced by the inpainting pipeline.
    """
    mask_image = Image.fromarray(input_mask)
    rand_gen = torch.manual_seed(seed)
    # Reuse the already-loaded pipeline instead of reloading the checkpoint
    # from disk on every request.
    pipeline = _get_inpaint_pipeline()
    result = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    )
    return result.images[0]
# Gradio Interface with Click Events
def gradio_interface(image, points):
    """Run segmentation and inpainting for one Gradio request.

    Args:
        image: Array converted to a PIL image, forced to RGB and resized to
            512x512 before processing.
        points: Point prompts forwarded unchanged to ``get_processed_inputs``.

    Returns:
        A ``(inpainted_image, mask_overlay)`` tuple: the pipeline output and
        the RGBA rendering of the mask that was used.
    """
    prompt = "a car driving on Mars. Studio lights, 1970s"
    negative_prompt = "artifacts, low quality, distortion"

    source = Image.fromarray(image)
    resized = source.convert("RGB").resize((512, 512))

    segmentation = get_processed_inputs(resized, points)
    inpainted = inpaint(resized, segmentation, prompt, negative_prompt)
    overlay = mask_to_rgb(segmentation)
    return inpainted, overlay
# Build the UI: an image upload plus a text field for the point prompts.
# The handler parses the prompts with str.split, so the second input must be
# a component that delivers a string; a sketch canvas delivers image data and
# would fail on the first request.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        "image",
        gr.Textbox(
            label="Points (x,y|x,y|...)",
            placeholder="150,170|300,250",
        ),
    ],
    outputs=["image", "image"],
)
# share=True asks Gradio to also create a public share link.
iface.launch(share=True)