linoyts's picture
linoyts HF staff
Update app.py
e6c2d69 verified
raw
history blame
5.42 kB
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image
# Largest seed offered by the UI: the maximum value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 2048

# Base FLUX.1 [dev] pipeline, loaded in bfloat16 and then moved to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Redux prior pipeline. It is handed the base pipeline's tokenizers and text
# encoders instead of loading its own, so text prompts can be encoded
# alongside the reference image(s).
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev",
    torch_dtype=torch.bfloat16,
    tokenizer=pipe.tokenizer,
    tokenizer_2=pipe.tokenizer_2,
    text_encoder=pipe.text_encoder,
    text_encoder_2=pipe.text_encoder_2,
)
pipe_prior_redux = pipe_prior_redux.to("cuda")
@spaces.GPU
def infer(control_image, prompt, image_2, prompt_2, reference_scale=0.03,
          seed=42, randomize_seed=False, width=1024, height=1024,
          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image variation from a reference image and a text prompt.

    Args:
        control_image: Primary reference image; required.
        prompt: Text prompt paired with ``control_image``.
        image_2: Optional second reference image. When given, both images and
            both prompts are passed to the Redux prior pipeline together.
        prompt_2: Text prompt paired with ``image_2`` (only used with it).
        reference_scale: Strength of the reference-image tokens. The value is
            clamped to ``[1e-5, 1]`` and its logarithm is written into the
            attention mask columns that belong to the image tokens.
        seed: Seed for the CPU random generator handed to the pipeline.
        randomize_seed: When true, ``seed`` is replaced by a random integer in
            ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Forwarded to the FLUX pipeline.
        num_inference_steps: Forwarded to the FLUX pipeline.
        progress: Gradio progress tracker; not read directly in the body.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed that
        was actually used.

    Raises:
        gr.Error: If ``control_image`` is missing.
    """
    if control_image is None:
        # Fail with a readable UI message instead of an error from deep
        # inside the prior pipeline.
        raise gr.Error("Please provide an image to create variations from.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if image_2 is not None:
        pipe_prior_output = pipe_prior_redux([control_image, image_2], prompt=[prompt, prompt_2])
    else:
        pipe_prior_output = pipe_prior_redux(control_image, prompt=prompt)

    # The requested size is now honoured (it used to be silently ignored), so
    # the attention mask has to be sized for the actual resolution.
    width = int(width)
    height = int(height)

    # NOTE(review): 729 and 512 are presumably the number of Redux image
    # tokens and the text-token length in the prior's prompt embeddings;
    # confirm against the installed diffusers version.
    cond_size = 729
    max_sequence_length = 512
    # FLUX denoises one token per 16x16 pixel patch (8x VAE downsampling times
    # 2x2 latent packing); 1024x1024 yields the previously hard-coded 4096.
    latent_tokens = (height // 16) * (width // 16)
    full_attention_size = max_sequence_length + latent_tokens + cond_size

    # The mask is square in the full sequence length, so it grows
    # quadratically with resolution (about 17.6k x 17.6k at 2048x2048).
    attention_mask = torch.zeros(
        (full_attention_size, full_attention_size), device="cuda", dtype=torch.bfloat16
    )
    # A float mask is additive on the attention logits, so adding log(scale)
    # multiplies the un-normalised attention weight of those columns by scale.
    bias = torch.log(
        torch.tensor(reference_scale, dtype=torch.bfloat16, device="cuda").clamp(min=1e-5, max=1)
    )
    # Only the columns right after the text tokens (the image tokens) are biased.
    attention_mask[:, max_sequence_length : max_sequence_length + cond_size] = bias
    joint_attention_kwargs = dict(attention_mask=attention_mask)

    image = pipe(
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(seed),
        joint_attention_kwargs=joint_attention_kwargs,
        **pipe_prior_output,
    ).images[0]

    return image, seed
# Stylesheet handed to gr.Blocks: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 960px;\n"
    "}\n"
)
# Gradio front end: two input columns (image + prompt each), a result image,
# an "Advanced Settings" accordion, and a Run button wired to `infer`.
# NOTE(review): the nesting was reconstructed from a copy of this file that
# had lost its indentation; in particular, confirm that `result` belongs at
# row level rather than inside the second column.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 Redux [dev]
An adapter for FLUX [dev] to create image variations
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
""")

        with gr.Row():
            # First reference image, its prompt, the masking scale and Run.
            with gr.Column():
                input_image = gr.Image(type="pil", label="Image to create variations")
                prompt = gr.Text(
                    placeholder="Enter your prompt",
                    label="Prompt",
                    show_label=False,
                    container=False,
                    max_lines=1,
                )
                reference_scale = gr.Slider(
                    minimum=0.01,
                    maximum=0.08,
                    step=0.001,
                    value=0.03,
                    label="Masking Scale",
                )
                run_button = gr.Button("Run")

            # Optional second reference image and its prompt.
            with gr.Column():
                image_2 = gr.Image(type="pil", label="2nd image to create interpolated variations")
                prompt_2 = gr.Text(
                    placeholder="Enter your prompt",
                    label="2nd Prompt",
                    show_label=False,
                    container=False,
                    max_lines=1,
                )

            result = gr.Image(show_label=False, label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

            with gr.Row():
                width = gr.Slider(
                    minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Width"
                )
                height = gr.Slider(
                    minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Height"
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=1, maximum=15, step=0.1, value=3.5, label="Guidance Scale"
                )
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, step=1, value=28, label="Number of inference steps"
                )

    # The input order must match the positional parameters of `infer`.
    # The seed slider is also an output, so it shows the seed actually used.
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            input_image,
            prompt,
            image_2,
            prompt_2,
            reference_scale,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.launch()