import os
import spaces
import time
import gradio as gr
import torch
from torch import Tensor, nn
from PIL import Image
from torchvision import transforms
from dataclasses import dataclass
import math
from typing import Callable
import random
from tqdm import tqdm
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from transformers import (
    pipeline,
    CLIPTextModel, CLIPTokenizer,
    T5EncoderModel, T5Tokenizer
)
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from einops import rearrange, repeat


torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

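# Korean -> English translation pipeline (runs on CPU); Korean prompts are
# translated to English before being fed to the text encoders.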
translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    framework="pt",
    device=-1
)

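# Thin wrapper around the CLIP / T5 text encoders: CLIP contributes the pooled
# prompt embedding ("pooler_output"), T5 contributes per-token features
# ("last_hidden_state").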
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]

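# Text encoders (T5-XXL for per-token features, CLIP-L for the pooled vector)
# and the FLUX.1-dev VAE, all loaded in bfloat16.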
t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(torch_device)
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(torch_device)
ae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(torch_device)

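# functional_linear_4bits runs a linear layer directly on packed NF4 weights
# via bnb.matmul_4bit and casts the result back to the input dtype.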
def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return out.to(x)

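# Copies a bitsandbytes QuantState onto a target device, including the nested
# statistics that are present when compress_statistics was used.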
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None
    device = device or state.absmax.device
    state2 = QuantState(
        absmax=state.state2.absmax.to(device),
        shape=state.state2.shape,
        code=state.state2.code.to(device),
        blocksize=state.state2.blocksize,
        quant_type=state.state2.quant_type,
        dtype=state.state2.dtype,
    ) if state.nested else None
    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )

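# Params4bit subclass whose .to() quantizes lazily on the first move to CUDA
# and otherwise copies the quant state along with the data, keeping the owning
# module's quant_state in sync.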
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        new = ForgeParams4bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module
        )
        self.module.quant_state = new.quant_state
        self.data = new.data
        self.quant_state = new.quant_state
        return new

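# Module base class that can rebuild a pre-quantized bitsandbytes weight (and
# its quant state) directly from the keys of a checkpoint's state dict.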
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        qs_keys = {k[len(prefix + "weight."):] for k in state_dict if k.startswith(prefix + "weight.")}
        if any("bitsandbytes" in k for k in qs_keys):
            qs = {k: state_dict[prefix + "weight." + k] for k in qs_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=qs,
                requires_grad=False,
                device=torch.device('cuda'),
                module=self
            )
            self.quant_state = self.weight.quant_state
            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))
            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                          missing_keys, unexpected_keys, error_msgs)

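# Drop-in replacement for nn.Linear that keeps its weight in NF4 and routes the
# forward pass through the 4-bit matmul above.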
class Linear(ForgeLoader4Bit):
    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type='nf4')

    def forward(self, x):
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)

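# Monkey-patch: every nn.Linear constructed after this point (in particular the
# linear layers inside the Flux transformer) becomes the NF4 Linear above.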
nn.Linear = Linear

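# Download the pre-quantized NF4 checkpoint and strip the
# "model.diffusion_model." prefix so the keys match the Flux module names.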
sd = load_file(
    hf_hub_download(
        repo_id="lllyasviel/flux1-dev-bnb-nf4",
        filename="flux1-dev-bnb-nf4-v2.safetensors"
    )
)
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}

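# Instantiate the Flux transformer in bfloat16 and load the NF4 weights;
# model_zero_init defers the final move to the GPU until the first call to
# generate_image.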
model = Flux().to(torch_device, dtype=torch.bfloat16)
model.load_state_dict(sd)
model_zero_init = False

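# Convert a Gradio numpy image to a [-1, 1] float tensor with a batch dimension.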
def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 2.0 * x - 1.0),
    ])
    return tf(image)[None, ...]

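# Pack the latent image into 2x2 patches (one token per patch, 2 x 2 x 16 = 64
# values), build per-token position ids (row and column for image tokens, zeros
# for text tokens), and encode the prompt with T5 and CLIP.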
def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if bs == 1 and isinstance(prompt, list):
        img = repeat(img, "1 ... -> bs ...", bs=len(prompt))

    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])

    txt = t5([prompt] if isinstance(prompt, str) else prompt)
    if txt.shape[0] == 1 and img.shape[0] > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=img.shape[0])
    txt_ids = torch.zeros(txt.size(0), txt.size(1), 3, device=img.device)

    vec = clip([prompt] if isinstance(prompt, str) else prompt)
    if vec.shape[0] == 1 and img.shape[0] > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=img.shape[0])

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }

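# Rectified-flow schedule: timesteps are linearly spaced in [1, 0] and, when
# shift is enabled, warped by a resolution-dependent mu (interpolated between
# base_shift at 256 image tokens and max_shift at 4096 tokens) via
#   t_shifted = exp(mu) / (exp(mu) + (1 / t - 1)).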
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (base_shift - 256 * (max_shift - base_shift) / (4096 - 256))
        timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1) ** 1.0)
    return timesteps.tolist()

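# Euler integration of the flow ODE: one model call per step, moving the latent
# from t_curr to t_prev along the predicted velocity.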
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    guidance_vec = torch.full((img.size(0),), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
        t_vec = torch.full((img.size(0),), t_curr, device=img.device, dtype=img.dtype)
        pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
                     y=vec, timesteps=t_vec, guidance=guidance_vec)
        img = img + (t_prev - t_curr) * pred
    return img

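# End-to-end generation: translate Korean prompts, optionally encode an init
# image for img2img, sample noise, run the denoiser, and decode with the VAE.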
@spaces.GPU
@torch.no_grad()
def generate_image(
    prompt, width, height, guidance, inference_steps, seed,
    do_img2img, init_image, image2image_strength, resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    # Translate the prompt to English if it contains Hangul characters.
    if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt):
        prompt = translator(prompt, max_length=512)[0]['translation_text']

    if seed == 0:
        seed = random.randint(1, 1_000_000)

    global model_zero_init, model
    if not model_zero_init:
        model = model.to(torch_device)
        model_zero_init = True

    # Optional img2img: encode the init image into the VAE latent space.
    if do_img2img and init_image is not None:
        init_img = get_image(init_image)
        if resize_img:
            init_img = torch.nn.functional.interpolate(init_img, (height, width))
        else:
            h0, w0 = init_img.shape[-2:]
            init_img = init_img[..., :16 * (h0 // 16), :16 * (w0 // 16)]
            height, width = init_img.shape[-2:]
        init_img = ae.encode(init_img.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
        init_img = (init_img - ae.config.shift_factor) * ae.config.scaling_factor
    else:
        init_img = None

    generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
    x = torch.randn(
        1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
        device=torch_device, dtype=torch.bfloat16, generator=generator
    )

    timesteps = get_schedule(inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)

    # For img2img, start partway through the schedule and mix the noise with the init latent.
    if do_img2img and init_img is not None:
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1 - t) * init_img.to(x.dtype)

    inp = prepare(t5, clip, x, prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

    # Unpack the tokens back into a latent image; the denoised sequence contains
    # image tokens only, so no text-token slicing is needed before rearranging.
    x = rearrange(x.float(), "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                  h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    img = Image.fromarray((127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0)).cpu().byte().numpy())

    return img, seed

css = """footer { visibility: hidden; }"""

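# Gradio UI: prompt and sampling controls on the left, optional img2img inputs
# that appear when the checkbox is ticked, and the generated image on the right.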
def create_demo():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
        gr.Markdown("# News! Multilingual version [https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual](https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt (Korean supported)", value="A cute and fluffy golden retriever puppy sitting upright...")
                width = gr.Slider(128, 2048, value=768, step=64, label="Width")
                height = gr.Slider(128, 2048, value=768, step=64, label="Height")
                guidance = gr.Slider(1.0, 5.0, value=3.5, step=0.1, label="Guidance")
                steps = gr.Slider(1, 30, value=30, step=1, label="Inference steps")
                seed = gr.Number(label="Seed (0 = random)", value=0, precision=0)
                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                init_img = gr.Image(label="Input Image", visible=False)
                strength = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="Noising strength", visible=False)
                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
                btn = gr.Button("Generate")
            with gr.Column():
                out_img = gr.Image(label="Generated Image")
                out_seed = gr.Text(label="Used Seed")

        do_i2i.change(
            fn=lambda x: [gr.update(visible=x)] * 3,
            inputs=[do_i2i],
            outputs=[init_img, strength, resize]
        )
        btn.click(
            fn=generate_image,
            inputs=[prompt, width, height, guidance, steps, seed, do_i2i, init_img, strength, resize],
            outputs=[out_img, out_seed]
        )
    return demo


if __name__ == "__main__":
    create_demo().launch()