import os
import torch
import spaces
import safetensors.torch
import gradio as gr
from PIL import Image
from loguru import logger
from torchvision import transforms
from huggingface_hub import hf_hub_download, login
from diffusers import FluxPipeline, FluxTransformer2DModel
from projection import ImageEncoder
from transformer_flux_custom import FluxTransformer2DModel as FluxTransformer2DModelWithIP
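# Configuration: base FLUX model, character-consistency adapter, and the EVA02 image conditioner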
model_config = './config.json'
pretrained_model_name = 'black-forest-labs/FLUX.1-dev'
adapter_path = 'model-v0.2.safetensors'
adapter_repo_id = "ashen0209/Flux-Character-Consitancy"
conditioner_base_model = 'eva02_large_patch14_448.mim_in22k_ft_in1k'
conditioner_layer_num = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
output_dim = 4096
logger.info(f"pretrained_model_name: {pretrained_model_name}, adapter_repo_id: {adapter_repo_id}, adapter_path: {adapter_path}, conditioner_layer: {conditioner_layer_num}, output_dim {output_dim}, device: {device}")
logger.info("init model")
model = FluxTransformer2DModelWithIP.from_config(model_config, torch_dtype=torch.bfloat16) # type: ignore
logger.info("load model")
copy = FluxTransformer2DModel.from_pretrained(pretrained_model_name, subfolder='transformer', torch_dtype=torch.bfloat16)
model.load_state_dict(copy.state_dict(), strict=False)
del copy
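# Image encoder that projects reference-image features from the EVA02 backbone into the 4096-dim prompt-embedding space, so they can be concatenated with the text prompt embeddings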
logger.info("load proj")
extra_embedder = ImageEncoder(output_dim, layer_num=conditioner_layer_num, seq_len=2, device=device, base_model=conditioner_base_model).to(device=device, dtype=torch.bfloat16)
logger.info("load pipe")
pipe = FluxPipeline.from_pretrained(pretrained_model_name, transformer=model, torch_dtype=torch.bfloat16)
pipe.to(dtype=torch.bfloat16, device=device)
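# Download the character-consistency adapter weights from the Hub (requires an HF_TOKEN with access)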
logger.info("download adapter")
login(token=os.environ['HF_TOKEN'])
file_path = hf_hub_download(repo_id=adapter_repo_id, filename=adapter_path)
logger.info("load adapter")
state_dict = safetensors.torch.load_file(file_path)
state_dict = {'.'.join(k.split('.')[1:]): state_dict[k] for k in state_dict.keys()}
diff = model.load_state_dict(state_dict, strict=False)
diff = extra_embedder.load_state_dict(state_dict, strict=False)
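# Preprocessing for reference images: resize to the conditioner's 448x448 input and normalize with its expected statistics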
IMAGE_PROCESS_TRANSFORM = transforms.Compose([
transforms.Resize((448, 448)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.276])
])
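# Request a GPU from the Space on demand (ZeroGPU) for the duration of each generation call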
@spaces.GPU
def generate_image(ref_image, ref_image2, prompt, height=512, width=512, num_steps=25, guidance_scale=3.5, seed=0, ip_scale=1.0, randomize_seed=True):
    # Pick a fresh seed unless the user supplied one explicitly
    if randomize_seed:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    logger.info(f"ref_image: {ref_image.size if ref_image is not None else None}, prompt: {prompt}, height: {height}, width: {width}, num_steps: {num_steps}, guidance_scale: {guidance_scale}, seed: {seed}, ip_scale: {ip_scale}")
    with torch.no_grad():
        # Preprocess whichever reference images were provided and stack them into a single batch
        refs = [IMAGE_PROCESS_TRANSFORM(img) for img in [ref_image, ref_image2] if img is not None]
        image_refs = [torch.stack(refs).to(dtype=torch.bfloat16, device=device)]
prompt_embeds, pooled_prompt_embeds, txt_ids = pipe.encode_prompt(prompt, prompt)
visual_prompt_embeds = extra_embedder(image_refs)
prompt_embeds_with_ref = torch.cat([prompt_embeds, visual_prompt_embeds], dim=1)
        pipe.transformer.ip_scale = ip_scale
        # Seed the sampler so a fixed seed reproduces the same image
        generator = torch.Generator(device=device).manual_seed(int(seed))
        image = pipe(
            prompt_embeds=prompt_embeds_with_ref,
            pooled_prompt_embeds=pooled_prompt_embeds,
            # negative_prompt_embeds=negative_prompt_embeds,
            # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            height=height,
            width=width,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
return image
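# Example rows: [reference image, optional second reference, prompt, height, width]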
examples = [
["assets/ref1.jpg", None, "A woman dancing in the dessert", 512, 768],
["assets/ref1.jpg", "assets/ref_cat.jpg", "A woman holding a cat above her head", 768, 512],
["assets/ref2.jpg", None, "A woman sitting on the beach near the sea", 512, 768],
]
with gr.Blocks() as demo:
# Top-level inputs that are always visible
with gr.Row():
gr.Markdown("""
## Character Consistency Image Generation based on Flux
""")
with gr.Row():
with gr.Column():
with gr.Row():
ref_image = gr.Image(type="pil", label="Upload Reference Subject Image", width=300)
                ref_image2 = gr.Image(type="pil", label="[Optional] Complementary or different-category reference", width=200)
description = gr.Textbox(lines=2, placeholder="Describe the desired contents", label="Description Text")
generate_btn = gr.Button("Generate Image")
# Advanced options hidden inside an accordion (click to expand)
with gr.Accordion("Advanced Options", open=False):
height_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
width_slider = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
steps_slider = gr.Slider(minimum=20, maximum=50, value=25, step=1, label="Number of Steps")
guidance_slider = gr.Slider(minimum=1.0, maximum=8.0, value=3.5, step=0.1, label="Guidance Scale")
ref_scale_slider = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Reference Image Scale")
with gr.Column():
output = gr.Image(type="pil", label="Generated Image")
# with gr.Row():
with gr.Group():
with gr.Row(equal_height=True):
with gr.Column(scale=1, min_width=50, ):
randomize_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Column(scale=3, min_width=100):
seed_io = gr.Number(label="Seed (if not randomizing)", value=0, interactive=True, )
with gr.Row():
gr.Examples(
            label='Click one of the following examples to load and try it',
examples=examples,
inputs=[ref_image, ref_image2, description, height_slider, width_slider],
fn=generate_image,
outputs=output,
# example_labels=['Reference Subject', 'Additional Reference', 'Prompt', 'Height', 'Width'],
cache_examples=True,
cache_mode='lazy'
)
with gr.Row():
gr.Markdown("""
### Tips:
- Images with human subjects tend to perform better than other categories.
- Images where the subject occupies most of the frame with a clean, uncluttered background yield improved results.
- Including multiple subjects of the same category may cause blending issues (this is being improved).
- Despite these factors, most image inputs still produce reasonable and satisfactory results.
""")
# When the button is clicked, pass all inputs to generate_image
generate_btn.click(
fn=generate_image,
        inputs=[ref_image, ref_image2, description, height_slider, width_slider, steps_slider, guidance_slider, seed_io, ref_scale_slider, randomize_checkbox],
outputs=output,
)
if __name__ == "__main__":
demo.launch()