import imageio
import numpy as np
from PIL import Image
import torch
from diffusers.models.attention_processor import FluxAttnProcessor2_0  # assumed source; some repos ship their own copy in transformer_flux.py
from .controlnet_flux import FluxControlNetModel
from .transformer_flux import FluxTransformer2DModel
from .pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device for I2I: {device}")
def resize_image(image, height, width):
    """Resize an image tensor to the desired height and width."""
    return torch.nn.functional.interpolate(image, size=(height, width), mode='nearest')
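# resize_image expects a 4D tensor in (N, C, H, W) layout, since interpolate
# with a 2-element size requires one; e.g.:
#   x = torch.rand(1, 3, 512, 512)
#   resize_image(x, 1024, 1024).shape  # -> torch.Size([1, 3, 1024, 1024])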
def dummy(img):
    """Save the composite image and build an inpainting mask from the first layer's alpha channel."""
    imageio.imwrite("output_image.png", img["composite"])
    # Painted pixels have non-zero alpha: map them to 255 (region to inpaint)
    alpha_channel = img["layers"][0][:, :, 3]
    mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    return img["background"], mask
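# The "composite"/"layers"/"background" keys match the dict a Gradio
# ImageEditor component returns (an assumption based on the key names).
# A minimal stand-in for exercising dummy() without the UI:
#   editor_output = {
#       "background": np.zeros((512, 512, 3), dtype=np.uint8),   # base RGB image
#       "layers": [np.zeros((512, 512, 4), dtype=np.uint8)],     # RGBA paint layer
#       "composite": np.zeros((512, 512, 3), dtype=np.uint8),    # flattened view
#   }
#   background, mask = dummy(editor_output)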
def I2I(prompt, image, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, strength=0.99):
    # Note: strength is accepted for API compatibility but is not forwarded to the pipeline.
    # Load the ControlNet, the base transformer, and the inpainting pipeline.
    # These weights are reloaded on every call; hoist them to module scope to avoid repeated loads.
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
    )
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16
    ).to(device)
    pipe.transformer.to(torch.bfloat16)
    pipe.controlnet.to(torch.bfloat16)
    pipe.transformer.set_attn_processor(FluxAttnProcessor2_0())
    background, mask = dummy(image)
    # Resize the background image and the mask to the target (width, height)
    control_image = Image.fromarray(background, mode="RGB").resize((width, height))
    control_mask = Image.fromarray(mask, mode="L").resize((width, height))
    # Fixed seed for reproducible generation
    generator = torch.Generator(device=device).manual_seed(0)
    # Generate the inpainted image
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=control_image,
        control_mask=control_mask,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=3.5
    ).images[0]
    return result
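
# A minimal sketch of how I2I might be wired up (hypothetical file names;
# the editor_output dict mirrors the structure dummy() expects):
#   editor_output = {
#       "background": np.asarray(Image.open("photo.png").convert("RGB")),
#       "layers": [np.asarray(Image.open("paint_layer.png").convert("RGBA"))],
#       "composite": np.asarray(Image.open("composite.png").convert("RGB")),
#   }
#   inpainted = I2I("a red brick wall", editor_output, width=1024, height=1024)
#   inpainted.save("inpainted.png")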