import imageio
import numpy as np
import torch
from PIL import Image
from diffusers.models.attention_processor import FluxAttnProcessor2_0

from .controlnet_flux import FluxControlNetModel
from .transformer_flux import FluxTransformer2DModel
from .pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device for I2I: {device}")


def resize_image(image, height, width):
    """Resize an image tensor to the desired height and width."""
    return torch.nn.functional.interpolate(image, size=(height, width), mode='nearest')


def dummy(img):
    """Save the composite image and generate a mask from the alpha channel."""
    imageio.imwrite("output_image.png", img["composite"])
    # Extract the alpha channel from the first layer to create the mask
    alpha_channel = img["layers"][0][:, :, 3]
    mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    return img["background"], mask


def I2I(prompt, image, width=1024, height=1024, guidance_scale=8.0, num_inference_steps=20, strength=0.99):
    # Load the inpainting pipeline: inpainting ControlNet + FLUX.1-dev transformer
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
    )
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16
    ).to(device)
    pipe.transformer.to(torch.bfloat16)
    pipe.controlnet.to(torch.bfloat16)
    pipe.transformer.set_attn_processor(FluxAttnProcessor2_0())
    # Split the editor input into the background image and the painted mask
    img_url, mask = dummy(image)
    # Resize image and mask to the target dimensions (width x height)
    img_url = Image.fromarray(img_url).convert("RGB").resize((width, height))
    mask_url = Image.fromarray(mask, mode="L").resize((width, height))
    # The pipeline converts the PIL image and mask to tensors internally
    generator = torch.Generator(device=device).manual_seed(0)  # fixed seed for reproducible outputs
    # Generate the inpainted image
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=img_url,
        control_mask=mask_url,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=3.5
    ).images[0]
    return result
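

# ---------------------------------------------------------------------------
# Illustrative usage sketch; not part of the original Space code. It assumes
# `image` is a Gradio ImageEditor-style dict of RGBA numpy arrays with
# "background", "layers", and "composite" keys, which is what dummy() expects.
# The prompt, canvas size, and painted region below are invented for
# demonstration. Because this module uses relative imports, run it as part of
# its package (e.g. `python -m <package>.<module>`), not as a standalone file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    h, w = 1024, 1024
    background = np.zeros((h, w, 4), dtype=np.uint8)   # blank RGBA canvas
    background[..., 3] = 255                           # fully opaque background
    layer = np.zeros((h, w, 4), dtype=np.uint8)
    layer[256:768, 256:768, 3] = 255                   # "painted" square becomes the inpainting mask
    editor_dict = {
        "background": background,
        "layers": [layer],
        "composite": background,
    }

    inpainted = I2I("a red apple on a wooden table", editor_dict, width=1024, height=1024)
    inpainted.save("inpainted.png")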