File size: 2,507 Bytes
8d9a1a3
 
 
 
065795a
 
 
 
8d9a1a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2893544
065795a
 
 
 
 
 
 
 
 
 
 
 
f963ba1
065795a
2893544
8d9a1a3
 
 
 
 
 
 
 
 
 
065795a
8d9a1a3
065795a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import imageio
import numpy as np
from PIL import Image
import torch
from .controlnet_flux import FluxControlNetModel
from .transformer_flux import FluxTransformer2DModel
from .pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline


# Pick the compute device once at import time: use the GPU when CUDA is
# usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device for I2I: {device}")

# NOTE: nothing is loaded at import time; the inpainting pipeline is built inside I2I().

def resize_image(image, height, width):
    """Return ``image`` resampled to ``height`` x ``width``.

    The tensor is passed to ``torch.nn.functional.interpolate`` with
    nearest-neighbour sampling, so no new values are blended in.

    Args:
        image: Tensor accepted by ``torch.nn.functional.interpolate`` with a
            two-element ``size`` (assumes a batched N x C x H x W layout --
            TODO confirm against callers).
        height: Target size of the first spatial dimension.
        width: Target size of the second spatial dimension.

    Returns:
        The interpolated tensor.
    """
    target_size = (height, width)
    resized = torch.nn.functional.interpolate(
        image,
        size=target_size,
        mode="nearest",
    )
    return resized


def dummy(img):
    """Dump the composite to disk and build a binary mask from the first layer.

    Args:
        img: Mapping with the keys ``"composite"``, ``"layers"`` and
            ``"background"`` (presumably an image-editor payload -- verify
            against the caller). ``img["layers"][0]`` must be indexable as
            ``[:, :, 3]``, i.e. carry a fourth (alpha) channel.

    Returns:
        A ``(background, mask)`` tuple: ``img["background"]`` unchanged, and a
        ``uint8`` array that is 255 wherever the first layer's alpha is
        non-zero and 0 elsewhere.

    Side effects:
        Writes the composite image to ``output_image.png`` in the current
        working directory.
    """
    imageio.imwrite("output_image.png", img["composite"])

    # Any pixel with non-zero alpha in the first layer belongs to the mask.
    first_layer = img["layers"][0]
    alpha = first_layer[:, :, 3]
    mask = np.where(alpha != 0, 255, 0).astype(np.uint8)

    background = img["background"]
    return background, mask


def I2I(prompt, image, width=1024, height=1024, guidance_scale=8.0, num_inference_steps=20, strength=0.99):
    """Inpaint the masked region of an edited image with FLUX.1-dev + ControlNet.

    The ControlNet, transformer and pipeline are loaded on every call, the
    editor payload is split into a background image and a mask via ``dummy``,
    both are resized to ``width`` x ``height``, and the pipeline is run with a
    fixed seed.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        image: Editor payload accepted by ``dummy`` (a mapping with the keys
            ``"composite"``, ``"layers"`` and ``"background"``).
        width: Output width in pixels; image and mask are resized to it.
        height: Output height in pixels; image and mask are resized to it.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        strength: Kept for interface compatibility; currently not forwarded to
            the pipeline.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    model_dtype = torch.bfloat16

    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=model_dtype,
    )
    # The keyword must be spelled ``torch_dtype``; the previous misspelling
    # meant the requested dtype was never passed under the right name.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=model_dtype,
    )
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=model_dtype,
    ).to(device)
    pipe.transformer.to(model_dtype)
    pipe.controlnet.to(model_dtype)
    # NOTE(review): an earlier revision called
    # ``pipe.set_attn_processor(FluxAttnProcessor2_0())`` here, but that name
    # is not imported in this module, so the call could only raise NameError.
    # The models keep the attention processor they were loaded with.

    background, mask = dummy(image)

    # Resize image and mask to the target dimensions (PIL expects (w, h)).
    # ``convert`` normalises the channel layout, so a 4-channel (RGBA)
    # background is accepted as well as a plain RGB one.
    target_size = (width, height)
    control_image = Image.fromarray(background).convert("RGB").resize(target_size)
    control_mask = Image.fromarray(mask).convert("L").resize(target_size)

    # Fixed seed so repeated calls with the same inputs are reproducible.
    generator = torch.Generator(device=device).manual_seed(0)

    # Generate the inpainted image from the *resized* image/mask pair and the
    # caller-supplied sampling settings.
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=control_image,
        control_mask=control_mask,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=3.5,
    ).images[0]

    return result