import os
import spaces
import time
import gradio as gr
import torch
from torch import Tensor, nn
from PIL import Image
from torchvision import transforms
from dataclasses import dataclass
import math
from typing import Callable
import random
from tqdm import tqdm
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from transformers import (
    pipeline,
    CLIPTextModel, CLIPTokenizer,
    T5EncoderModel, T5Tokenizer
)
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from einops import rearrange, repeat
# 1) Device setup: use CUDA if a GPU is available, otherwise CPU
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Translation pipeline: force-load the TF checkpoint as PyTorch and run it on CPU
translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    framework="pt",
    from_tf=True,
    device=-1
)
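# Note: generate_image() below routes any prompt containing Hangul characters through
# this translator before text encoding, so Korean prompts are converted to English on
# the CPU and only the translated text reaches the T5/CLIP encoders.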
# ---------------- Encoders ----------------
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
        # Freeze the encoder parameters
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
# Move both text embedders and the VAE to torch_device
t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(torch_device)
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(torch_device)
ae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(torch_device)
# ---------------- NF4 logic (unchanged) ----------------
def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    return out.to(x)

def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None
    device = device or state.absmax.device
    state2 = QuantState(
        absmax=state.state2.absmax.to(device),
        shape=state.state2.shape,
        code=state.state2.code.to(device),
        blocksize=state.state2.blocksize,
        quant_type=state.state2.quant_type,
        dtype=state.state2.dtype,
    ) if state.nested else None
    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        new = ForgeParams4bit(
            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module
        )
        self.module.quant_state = new.quant_state
        self.data = new.data
        self.quant_state = new.quant_state
        return new
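# Note: the .to() override above implements lazy quantization. The first move of an
# unquantized parameter onto a CUDA device calls bitsandbytes' _quantize(); any later
# move copies the packed data together with its QuantState to the target device, so a
# layer keeps working after being shuffled between devices.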
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        qs_keys = {k[len(prefix + "weight."):] for k in state_dict if k.startswith(prefix + "weight.")}
        if any("bitsandbytes" in k for k in qs_keys):
            qs = {k: state_dict[prefix + "weight." + k] for k in qs_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=qs,
                requires_grad=False,
                device=torch.device('cuda'),
                module=self
            )
            self.quant_state = self.weight.quant_state
            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))
            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                          missing_keys, unexpected_keys, error_msgs)
class Linear(ForgeLoader4Bit):
    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type='nf4')

    def forward(self, x):
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)

nn.Linear = Linear
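# Replacing nn.Linear globally is what makes the stock Flux block definitions below
# build NF4 linear layers without any other code changes. The text encoders and the
# VAE are unaffected because they were instantiated before this line ran.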
# ---------------- Flux model definition (unchanged) ----------------
# (Attention, RoPE, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm,
#  SelfAttention, Modulation, DoubleStreamBlock, SingleStreamBlock, LastLayer, FluxParams, Flux classes)
# (Omitted here for length; identical to the original code.)

# ---------------- Model loading ----------------
sd = load_file(
    hf_hub_download(
        repo_id="lllyasviel/flux1-dev-bnb-nf4",
        filename="flux1-dev-bnb-nf4-v2.safetensors"
    )
)
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
model = Flux().to(torch_device, dtype=torch.bfloat16)
model.load_state_dict(sd)
model_zero_init = False
# ---------------- Image generation ----------------
def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 2.0 * x - 1.0),  # scale from [0, 1] to [-1, 1]
    ])
    return tf(image)[None, ...]
def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if bs == 1 and isinstance(prompt, list):
        img = repeat(img, "1 ... -> bs ...", bs=len(prompt))
    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])
    txt = t5([prompt] if isinstance(prompt, str) else prompt)
    if txt.shape[0] == 1 and img.shape[0] > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=img.shape[0])
    txt_ids = torch.zeros(txt.size(0), txt.size(1), 3, device=img.device)
    vec = clip([prompt] if isinstance(prompt, str) else prompt)
    if vec.shape[0] == 1 and img.shape[0] > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=img.shape[0])
    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }
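# prepare() packs the 16-channel latent into 2x2 patches, so each token carries
# 16 * 2 * 2 = 64 values plus a (row, column) position on the (h/2, w/2) grid stored
# in img_ids; the (omitted) EmbedND/RoPE modules turn those ids into rotary position
# embeddings, while txt_ids stay at zero.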
def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Shift the schedule toward t = 1 for longer image token sequences
        # (mu interpolated linearly between 256 and 4096 tokens).
        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (base_shift - (256 * (max_shift - base_shift) / (4096 - 256)))
        timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1) ** 1.0)
    return timesteps.tolist()
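# Rough worked example (values approximate): at the default 768x768 resolution the
# packed sequence length is 96 * 96 // 4 = 2304 tokens, so
# mu = 0.5 + 0.65 * (2304 - 256) / (4096 - 256) ~ 0.85 and exp(mu) ~ 2.33; a uniform
# timestep of 0.5 is therefore shifted to about 2.33 / (2.33 + 1) ~ 0.70, i.e. the
# sampler spends more of its steps at high noise levels for larger images.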
def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    guidance_vec = torch.full((img.size(0),), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
        t_vec = torch.full((img.size(0),), t_curr, device=img.device, dtype=img.dtype)
        pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
                     y=vec, timesteps=t_vec, guidance=guidance_vec)
        img = img + (t_prev - t_curr) * pred
    return img
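# denoise() is a plain Euler integrator for the rectified-flow ODE: the model predicts
# a velocity at t_curr and the latent is moved by (t_prev - t_curr) * pred, walking the
# shifted schedule from t = 1 (pure noise) down to t = 0 (clean latent).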
@spaces.GPU  # required so ZeroGPU allocates a GPU for the duration of this call
def generate_image(
    prompt, width, height, guidance, inference_steps, seed,
    do_img2img, init_image, image2image_strength, resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    # If the prompt contains Hangul, translate it to English with the CPU translator
    if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt):
        translated = translator(prompt, max_length=512)[0]['translation_text']
        prompt = translated

    # Random seed
    if seed == 0:
        seed = random.randint(1, 1_000_000)

    global model_zero_init, model
    if not model_zero_init:
        model = model.to(torch_device)
        model_zero_init = True
    # img2img preparation
    if do_img2img and init_image is not None:
        init_img = get_image(init_image)
        if resize_img:
            init_img = torch.nn.functional.interpolate(init_img, (height, width))
        else:
            h0, w0 = init_img.shape[-2:]
            init_img = init_img[..., :16 * (h0 // 16), :16 * (w0 // 16)]
            height, width = init_img.shape[-2:]
        init_img = ae.encode(init_img.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
        init_img = (init_img - ae.config.shift_factor) * ae.config.scaling_factor
    else:
        init_img = None
    # Sample the initial noise
    generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
    x = torch.randn(
        1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
        device=torch_device, dtype=torch.bfloat16, generator=generator
    )
    timesteps = get_schedule(inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)

    if do_img2img and init_img is not None:
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1 - t) * init_img.to(x.dtype)
    inp = prepare(t5, clip, x, prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    x = rearrange(x[:, inp["txt"].shape[1]:, ...].float(), "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                  h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)

    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    img = Image.fromarray((127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0)).cpu().byte().numpy())
    return img, seed
# ---------------- Gradio demo ----------------
css = """footer { visibility: hidden; }"""
def create_demo():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
        gr.Markdown("# News! Multilingual version [https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual](https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt (Korean or English)", value="A cute and fluffy golden retriever puppy sitting upright...")
                width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
                height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
                guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
                steps = gr.Slider(minimum=1, maximum=30, step=1, label="Inference steps", value=30)
                seed = gr.Number(label="Seed", precision=0, value=0)
                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                init_img = gr.Image(label="Input Image", visible=False)
                strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
                btn = gr.Button("Generate")
            with gr.Column():
                out_img = gr.Image(label="Generated Image")
                out_seed = gr.Text(label="Used Seed")
        do_i2i.change(
            fn=lambda x: [gr.update(visible=x)] * 3,
            inputs=[do_i2i],
            outputs=[init_img, strength, resize]
        )
        btn.click(
            fn=generate_image,
            inputs=[prompt, width, height, guidance, steps, seed, do_i2i, init_img, strength, resize],
            outputs=[out_img, out_seed]
        )
    return demo

if __name__ == "__main__":
    create_demo().launch()