import os

import torch
import spaces
import safetensors.torch
import gradio as gr
from PIL import Image
from loguru import logger
from torchvision import transforms
from huggingface_hub import hf_hub_download, login
from diffusers import FluxPipeline, FluxTransformer2DModel

from projection import ImageEncoder
from transformer_flux_custom import FluxTransformer2DModel as FluxTransformer2DModelWithIP


model_config = './config.json'
pretrained_model_name = 'black-forest-labs/FLUX.1-dev'
adapter_path = 'model-v0.2.safetensors'
adapter_repo_id = "ashen0209/Flux-Character-Consitancy"

conditioner_base_model = 'eva02_large_patch14_448.mim_in22k_ft_in1k'
conditioner_layer_num = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
output_dim = 4096
logger.info(f"pretrained_model_name: {pretrained_model_name}, adapter_repo_id: {adapter_repo_id}, adapter_path: {adapter_path}, conditioner_layer: {conditioner_layer_num}, output_dim {output_dim}, device: {device}")

logger.info("init model")
model = FluxTransformer2DModelWithIP.from_config(model_config, torch_dtype=torch.bfloat16) # type: ignore
logger.info("load model")
copy = FluxTransformer2DModel.from_pretrained(pretrained_model_name, subfolder='transformer', torch_dtype=torch.bfloat16)
model.load_state_dict(copy.state_dict(), strict=False)
del copy

logger.info("load proj")
extra_embedder = ImageEncoder(output_dim, layer_num=conditioner_layer_num, seq_len=2, device=device, base_model=conditioner_base_model).to(device=device, dtype=torch.bfloat16)

logger.info("load pipe")
pipe = FluxPipeline.from_pretrained(pretrained_model_name, transformer=model, torch_dtype=torch.bfloat16)
pipe.to(dtype=torch.bfloat16, device=device)
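# Optional alternative (a hedged sketch, not part of the original flow): on GPUs with
# limited VRAM, diffusers' model CPU offload could be used instead of moving the whole
# pipeline to the device at once.
# pipe.enable_model_cpu_offload()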

logger.info("download adapter")
login(token=os.environ['HF_TOKEN'])
file_path = hf_hub_download(repo_id=adapter_repo_id, filename=adapter_path)

logger.info("load adapter")
state_dict = safetensors.torch.load_file(file_path)
# Checkpoint keys carry a leading module prefix; strip the first dotted component so the
# keys match the parameter names of the transformer and the image encoder.
state_dict = {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}
model_diff = model.load_state_dict(state_dict, strict=False)
embedder_diff = extra_embedder.load_state_dict(state_dict, strict=False)
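# Hedged sanity check (not in the original script): with strict=False, unmatched keys are
# silently ignored, so logging them can surface adapter/config mismatches.
# logger.info(f"model: missing={len(model_diff.missing_keys)}, unexpected={len(model_diff.unexpected_keys)}")
# logger.info(f"embedder: missing={len(embedder_diff.missing_keys)}, unexpected={len(embedder_diff.unexpected_keys)}")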


# Preprocessing for the EVA02 conditioner: resize to its 448x448 input size and normalize
# with the statistics the pretrained backbone expects.
IMAGE_PROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.276])
])

@spaces.GPU
def generate_image(ref_image, prompt="", height=512, width=512, ref_image2=None, num_steps=25, guidance_scale=3.5, seed=0, ip_scale=1.0):
    print(f"ref_image: {ref_image.size if ref_image is not None else None}, "
          f"ref_image2: {ref_image2.size if ref_image2 is not None else None}, "
          f"prompt: {prompt}, height: {height}, width: {width}, num_steps: {num_steps}, guidance_scale: {guidance_scale}, ip_scale: {ip_scale}")
    with torch.no_grad():
        # Preprocess whichever reference images were provided and stack them into one batch.
        refs = [IMAGE_PROCESS_TRANSFORM(i) for i in [ref_image, ref_image2] if i is not None]
        image_refs = [torch.stack(refs).to(dtype=torch.bfloat16, device='cuda')]
        prompt_embeds, pooled_prompt_embeds, txt_ids = pipe.encode_prompt(prompt, prompt)
        visual_prompt_embeds = extra_embedder(image_refs)
        # Append the projected visual tokens to the text tokens so the transformer attends to the reference subject.
        prompt_embeds_with_ref = torch.cat([prompt_embeds, visual_prompt_embeds], dim=1)
        pipe.transformer.ip_scale = ip_scale
        image = pipe(
            prompt_embeds=prompt_embeds_with_ref,
            pooled_prompt_embeds=pooled_prompt_embeds,
            # negative_prompt_embeds=negative_prompt_embeds,
            # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            height=height,
            width=width,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
        ).images[0]
        return image    
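
# Hedged usage sketch (not part of the app flow): calling generate_image directly,
# reusing one of the bundled example assets.
# img = generate_image(Image.open('assets/ref3.jpg'),
#                      "A woman walking in a park with trees and flowers",
#                      height=512, width=768, num_steps=25, guidance_scale=3.5)
# img.save('output.png')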



examples = [
    ["assets/ref1.jpg", "A woman dancing in the dessert", 512, 768],
    ["assets/ref2.jpg", "A woman having dinner at a table", 512, 768],
    ["assets/ref3.jpg", "A woman walking in a park with trees and flowers", 512, 768],
    ["assets/ref4.jpg", "A woman run across a busy street", 512, 768],

]

with gr.Blocks() as demo:
    # Top-level inputs that are always visible
    with gr.Row():
        gr.Markdown("""
## Character Consistency Image Generation based on Flux
- The model can be downloaded at https://huggingface.co/ashen0209/Flux-Character-Consitancy
- The model currently works best at generating consistent images of a single human subject; multi-subject and common-object generation is not yet as satisfactory, but this will be improved soon
""")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                ref_image = gr.Image(type="pil", label="Upload Reference Subject Image", width=300)
                ref_image2 = gr.Image(type="pil", label="[Optional] Complementary image, or an additional image from a different category", width=200)
            description = gr.Textbox(lines=2, placeholder="Describe the desired contents", label="Description Text")
            generate_btn = gr.Button("Generate Image")

            # Advanced options hidden inside an accordion (click to expand)
            with gr.Accordion("Advanced Options", open=False):
                height_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
                width_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
                steps_slider = gr.Slider(minimum=20, maximum=50, value=25, step=1, label="Number of Steps")
                guidance_slider = gr.Slider(minimum=1.0, maximum=8.0, value=3.5, step=0.1, label="Guidance Scale")
                ref_scale_slider = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Reference Image Scale")
    
        with gr.Column():
            output = gr.Image(type="pil", label="Generated Image")
            # with gr.Row():
            # with gr.Group():
            #     with gr.Row(equal_height=True):
            #         with gr.Column(scale=1, min_width=50, ):
            #             randomize_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
            #         with gr.Column(scale=3, min_width=100):
            #             seed_io = gr.Number(label="Seed (if not randomizing)", value=0, interactive=True, )

    with gr.Row():
        gr.Examples(
            label='Click on the following examples to load and try them',
            examples=examples,
            inputs=[ref_image, description, height_slider, width_slider],
            fn=generate_image,
            outputs=output,
            # example_labels=['Reference Subject', 'Additional Reference', 'Prompt', 'Height', 'Width'],
            cache_examples=True, 
            cache_mode='lazy' 
        )
    
    with gr.Row():
        gr.Markdown("""
### Tips:
- Images with human subjects tend to perform better than other categories.
- Images where the subject occupies most of the frame with a clean, uncluttered background yield improved results.
- Including multiple subjects of the same category may cause blending issues.
""")    
    # When the button is clicked, pass all inputs to generate_image
    generate_btn.click(
        fn=generate_image,
        inputs=[ref_image, description, height_slider, width_slider, ref_image2, steps_slider, guidance_slider, ref_scale_slider],
        outputs=output,
    )



if __name__ == "__main__":
    demo.launch()