# modeltest / test.py
# Author: yaseengoldfinchpc — commit 43c5517 ("Git Push"), 2.91 kB
from diffusers import StableDiffusionInpaintPipeline
import torch
from PIL import Image
import os
def setup_model(model_path):
    """Load a Stable Diffusion inpainting pipeline from a single checkpoint file.

    The pipeline is moved to CUDA when it is available, otherwise to the CPU.
    Half precision (float16) is used only on CUDA. On the CPU the weights are
    loaded as float32, because float16 inference is poorly supported there.

    Args:
        model_path: Path to a single-file checkpoint (e.g. ``.safetensors``)
            accepted by ``StableDiffusionInpaintPipeline.from_single_file``.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline`` on the selected device,
        with the safety checker disabled and attention slicing enabled.
    """
    # Decide the device first so the dtype can be matched to it.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionInpaintPipeline.from_single_file(
        model_path,
        torch_dtype=dtype,
        safety_checker=None,
    )
    pipe = pipe.to(device)

    # Memory optimization: compute attention in slices to lower peak usage.
    pipe.enable_attention_slicing()
    return pipe
def prepare_images(image_path, mask_path=None):
    """Load the input image and build (or load) the matching inpainting mask.

    The image is converted to RGB and resized so both sides are multiples of
    8, which Stable Diffusion requires.

    Args:
        image_path: Path to the image to inpaint.
        mask_path: Optional path to a mask image. White (255) marks the region
            to repaint and black (0) the region to keep. When omitted, a
            centered rectangle spanning one third of the width and height is
            masked.

    Returns:
        ``(original_image, mask_image)``: the RGB image and a single-channel
        ("L") mask, both of the same multiple-of-8 size.

    Raises:
        ValueError: If either side of the image is smaller than 8 pixels, so
            no multiple-of-8 size exists.
    """
    # Use a context manager so the file handle is released. convert() loads
    # the pixel data, so the returned copy stays valid after the file closes.
    # Converting to RGB drops alpha/palette modes the pipeline does not expect.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")

    # Round each side down to a multiple of 8 (required by Stable Diffusion).
    width, height = (dim - dim % 8 for dim in rgb_image.size)
    if width == 0 or height == 0:
        raise ValueError(
            f"Image {image_path!r} is {rgb_image.width}x{rgb_image.height}; "
            "both sides must be at least 8 pixels."
        )
    original_image = rgb_image.resize((width, height))

    if mask_path:
        # Load the existing mask as grayscale and match it to the image size.
        with Image.open(mask_path) as mask_source:
            mask_image = mask_source.convert("L").resize((width, height))
    else:
        # Create a simple rectangular mask in the center.
        mask_image = Image.new("L", (width, height), 0)
        mask_width = width // 3
        mask_height = height // 3
        x1 = (width - mask_width) // 2
        y1 = (height - mask_height) // 2
        # Fill the half-open box [x1, x1+w) x [y1, y1+h) with white in one
        # call, instead of setting each pixel individually.
        mask_image.paste(255, (x1, y1, x1 + mask_width, y1 + mask_height))
    return original_image, mask_image
def main():
    """Run one inpainting pass on a hard-coded image and save the results.

    Writes ``generated_mask.png`` (the mask that was used) and
    ``inpainted_result.png`` (the pipeline output) to the current working
    directory.
    """
    import importlib.util
    import subprocess
    import sys

    # Model checkpoint and input image (raw string for the Windows path).
    model_path = "realisticVisionV60B1_v51VAE-inpainting.safetensors"
    image_path = r"C:\Users\M. Y\Downloads\t2.png"

    # Install accelerate if it is missing. Invoke pip through the running
    # interpreter (not whatever "pip" is first on PATH), without a shell, and
    # raise CalledProcessError if the install fails.
    # NOTE(review): diffusers is imported at the top of this file, and it may
    # check for accelerate at that point. If so, a freshly installed copy
    # only takes effect on the next run. Confirm for the installed version.
    if importlib.util.find_spec("accelerate") is None:
        print("Installing accelerate...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "accelerate"]
        )

    # Initialize model
    print("Loading model...")
    pipe = setup_model(model_path)

    # Prepare images
    print("Preparing images...")
    original_image, mask_image = prepare_images(image_path)
    # Save mask for verification
    mask_image.save("generated_mask.png")

    # Define your prompt
    prompt = "a realistic photo of a beautiful garden"
    negative_prompt = "blurry, low quality, distorted"

    print("Performing inpainting...")
    output_image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=original_image,
        mask_image=mask_image,
        # Without explicit dimensions the pipeline falls back to the UNet's
        # default resolution (512x512 for SD 1.5 checkpoints). That would
        # stretch the prepared image and return an output of the wrong size.
        # prepare_images already guarantees multiples of 8.
        height=original_image.height,
        width=original_image.width,
        num_inference_steps=30,
        guidance_scale=7.5,
    ).images[0]

    # Save the result
    output_image.save("inpainted_result.png")
    print("Inpainting completed! Check 'inpainted_result.png' for the result.")
if __name__ == "__main__":
    # Run the script only when executed directly, not when imported.
    main()