import os
import torch
import spaces
import safetensors.torch
import gradio as gr
from PIL import Image
from loguru import logger
from torchvision import transforms
from huggingface_hub import hf_hub_download, login
from diffusers import FluxPipeline, FluxTransformer2DModel
from projection import ImageEncoder
from transformer_flux_custom import FluxTransformer2DModel as FluxTransformer2DModelWithIP
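# Configuration: the base FLUX.1-dev weights, the character-consistency adapter
# checkpoint hosted on the Hub, and the EVA-02 backbone used to encode reference images.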
model_config = './config.json'
pretrained_model_name = 'black-forest-labs/FLUX.1-dev'
adapter_path = 'model-v0.2.safetensors'
adapter_repo_id = "ashen0209/Flux-Character-Consitancy"
conditioner_base_model = 'eva02_large_patch14_448.mim_in22k_ft_in1k'
conditioner_layer_num = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
output_dim = 4096
logger.info(f"pretrained_model_name: {pretrained_model_name}, adapter_repo_id: {adapter_repo_id}, adapter_path: {adapter_path}, conditioner_layer: {conditioner_layer_num}, output_dim {output_dim}, device: {device}")
logger.info("init model")
model = FluxTransformer2DModelWithIP.from_config(model_config, torch_dtype=torch.bfloat16) # type: ignore
logger.info("load model")
copy = FluxTransformer2DModel.from_pretrained(pretrained_model_name, subfolder='transformer', torch_dtype=torch.bfloat16)
model.load_state_dict(copy.state_dict(), strict=False)
del copy
logger.info("load proj")
extra_embedder = ImageEncoder(output_dim, layer_num=conditioner_layer_num, seq_len=2, device=device, base_model=conditioner_base_model).to(device=device, dtype=torch.bfloat16)
logger.info("load pipe")
pipe = FluxPipeline.from_pretrained(pretrained_model_name, transformer=model, torch_dtype=torch.bfloat16)
pipe.to(dtype=torch.bfloat16, device=device)
logger.info("download adapter")
login(token=os.environ['HF_TOKEN'])
file_path = hf_hub_download(repo_id=adapter_repo_id, filename=adapter_path)
logger.info("load adapter")
state_dict = safetensors.torch.load_file(file_path)
state_dict = {'.'.join(k.split('.')[1:]): state_dict[k] for k in state_dict.keys()}
diff = model.load_state_dict(state_dict, strict=False)
diff = extra_embedder.load_state_dict(state_dict, strict=False)
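# Preprocessing for reference images: resize to the 448x448 input expected by
# eva02_large_patch14_448 and normalize with the backbone's expected mean/std
# (these appear to be the rounded OpenAI CLIP statistics).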
IMAGE_PROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.276])
])
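# Runs on a ZeroGPU worker (the @spaces.GPU decorator requests a GPU per call):
# encodes the text prompt and the reference image(s), appends the visual embeddings
# to the text embeddings, and runs the FLUX pipeline. Note that the seed parameter
# is accepted but not currently wired into generation.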
@spaces.GPU
def generate_image(ref_image, prompt="", height=512, width=512, ref_image2=None, num_steps=25, guidance_scale=3.5, seed=0, ip_scale=1.0):
print(f"ref_image: {ref_image.size if ref_image is not None else None}, "
f"ref_image2: {ref_image2.size if ref_image2 is not None else None}, "
f"prompt: {prompt}, height: {height}, width: {width}, num_steps: {num_steps}, guidance_scale: {guidance_scale}, ip_scale: {ip_scale}")
with torch.no_grad():
image_refs = map(torch.stack, [
[IMAGE_PROCESS_TRANSFORM(i) for i in [ref_image, ref_image2] if i is not None]
])
image_refs = [i.to(dtype=torch.bfloat16, device='cuda') for i in image_refs]
prompt_embeds, pooled_prompt_embeds, txt_ids = pipe.encode_prompt(prompt, prompt)
visual_prompt_embeds = extra_embedder(image_refs)
prompt_embeds_with_ref = torch.cat([prompt_embeds, visual_prompt_embeds], dim=1)
pipe.transformer.ip_scale = ip_scale
image = pipe(
prompt_embeds=prompt_embeds_with_ref,
pooled_prompt_embeds=pooled_prompt_embeds,
# negative_prompt_embeds=negative_prompt_embeds,
# negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=guidance_scale,
).images[0]
return image
examples = [
["assets/ref1.jpg", "A woman dancing in the dessert", 512, 768],
["assets/ref2.jpg", "A woman having dinner at a table", 512, 768],
["assets/ref3.jpg", "A woman walking in a park with trees and flowers", 512, 768],
["assets/ref4.jpg", "A woman run across a busy street", 512, 768],
]
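# Gradio UI: reference image inputs and prompt on the left, advanced sliders in an
# accordion, the generated image on the right, plus clickable examples and usage tips.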
with gr.Blocks() as demo:
    # Top-level inputs that are always visible
    with gr.Row():
gr.Markdown("""
## Character Consistancy Image Generation based on Flux
- The model can be downloaded at https://huggingface.co/ashen0209/Flux-Character-Consitancy
- The model is currently only good at generating consistent images of single human subject, multi-subjects and common object are not as satisfactory, but it will improved soon
""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                ref_image = gr.Image(type="pil", label="Upload Reference Subject Image", width=300)
                ref_image2 = gr.Image(type="pil", label="[Optional] Complementary image, or an additional image from a different category", width=200)
            description = gr.Textbox(lines=2, placeholder="Describe the desired contents", label="Description Text")
            generate_btn = gr.Button("Generate Image")
            # Advanced options hidden inside an accordion (click to expand)
            with gr.Accordion("Advanced Options", open=False):
                height_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
                width_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
                steps_slider = gr.Slider(minimum=20, maximum=50, value=25, step=1, label="Number of Steps")
                guidance_slider = gr.Slider(minimum=1.0, maximum=8.0, value=3.5, step=0.1, label="Guidance Scale")
                ref_scale_slider = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Reference Image Scale")
        with gr.Column():
            output = gr.Image(type="pil", label="Generated Image")
            # with gr.Row():
            #     with gr.Group():
            #         with gr.Row(equal_height=True):
            #             with gr.Column(scale=1, min_width=50):
            #                 randomize_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
            #             with gr.Column(scale=3, min_width=100):
            #                 seed_io = gr.Number(label="Seed (if not randomizing)", value=0, interactive=True)
    with gr.Row():
        gr.Examples(
            label='Click one of the following examples to load and try it',
            examples=examples,
            inputs=[ref_image, description, height_slider, width_slider],
            fn=generate_image,
            outputs=output,
            # example_labels=['Reference Subject', 'Additional Reference', 'Prompt', 'Height', 'Width'],
            cache_examples=True,
            cache_mode='lazy'
        )
    with gr.Row():
        gr.Markdown("""
        ### Tips:
        - Images with human subjects tend to perform better than other categories.
        - Images where the subject occupies most of the frame with a clean, uncluttered background yield improved results.
        - Including multiple subjects of the same category may cause blending issues.
        """)
    # When the button is clicked, pass all inputs to generate_image
    generate_btn.click(
        fn=generate_image,
        inputs=[ref_image, description, height_slider, width_slider, ref_image2, steps_slider, guidance_slider, ref_scale_slider],
        outputs=output,
    )
if __name__ == "__main__":
    demo.launch()