"""Image-to-image inpainting helpers built on the FLUX.1-dev ControlNet inpainting pipeline."""

import functools

import imageio
import numpy as np
from PIL import Image
import torch

from .controlnet_flux import FluxControlNetModel
from .transformer_flux import FluxTransformer2DModel
from .pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device for I2I: {device}")


def resize_image(image, height, width):
    """Resize image tensor to the desired height and width.

    Uses nearest-neighbour interpolation. ``torch.nn.functional.interpolate``
    with a two-element ``size`` expects a 4-D ``(N, C, H, W)`` tensor.
    """
    return torch.nn.functional.interpolate(image, size=(height, width), mode='nearest')


def dummy(img):
    """Save the composite image and generate a mask from the alpha channel.

    Args:
        img: Mapping with ``"composite"``, ``"background"`` and ``"layers"``
            entries (presumably a Gradio ImageEditor value -- confirm with the
            caller). ``img["layers"][0]`` must be an array with at least four
            channels on its last axis; channel 3 is read as alpha.

    Returns:
        ``(img["background"], mask)`` where ``mask`` is a ``uint8`` array that
        is 255 wherever the first layer's alpha is non-zero and 0 elsewhere.

    Side effects:
        Writes ``img["composite"]`` to ``output_image.png`` in the current
        working directory.
    """
    imageio.imwrite("output_image.png", img["composite"])
    # Extract alpha channel from the first layer to create the mask:
    # any painted (non-transparent) pixel marks a region to inpaint.
    alpha_channel = img["layers"][0][:, :, 3]
    mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    return img["background"], mask


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Load the FLUX ControlNet inpainting pipeline once and reuse it across calls.

    Loading FLUX.1-dev plus the ControlNet is very expensive, so the pipeline
    is built on first use and kept resident on ``device`` afterwards.
    """
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,  # previously misspelled "torch_dytpe" and ignored
    )
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to(device)
    # Defensive casts so both sub-models are guaranteed to run in bfloat16.
    pipe.transformer.to(torch.bfloat16)
    pipe.controlnet.to(torch.bfloat16)
    # NOTE: no custom attention processor is installed; the models keep their
    # defaults. (`FluxAttnProcessor2_0` was referenced here without ever being
    # imported, which raised NameError.)
    return pipe


def I2I(prompt, image, width=1024, height=1024, guidance_scale=8.0,
        num_inference_steps=20, strength=0.99):
    """Inpaint the masked region of an editor image guided by ``prompt``.

    Args:
        prompt: Text prompt describing the desired content.
        image: Editor value passed to :func:`dummy`; its background is the
            control image and the first layer's alpha defines the mask.
        width: Output width in pixels; the image and mask are resized to it.
        height: Output height in pixels; the image and mask are resized to it.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        strength: Accepted for backward compatibility but not forwarded.
            NOTE(review): confirm whether the pipeline supports it.

    Returns:
        The first image in the pipeline's ``.images`` output.
    """
    pipe = _load_pipeline()

    background, mask = dummy(image)

    # Resize image and mask to the target dimensions (height x width).
    # convert("RGB") handles RGBA or greyscale backgrounds; passing mode="RGB"
    # to fromarray would misinterpret a 4-channel buffer.
    control_image = Image.fromarray(background).convert("RGB").resize((width, height))
    # NEAREST keeps the 0/255 mask strictly binary after resizing.
    control_mask = Image.fromarray(mask).resize((width, height), Image.NEAREST)

    # Fixed seed for reproducible results.
    generator = torch.Generator(device=device).manual_seed(0)

    # Generate the inpainted image.
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=control_image,
        control_mask=control_mask,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=3.5,
    ).images[0]
    return result