import os
import spaces
import time
import gradio as gr
import torch
from torch import Tensor, nn
from PIL import Image
from torchvision import transforms
from dataclasses import dataclass
import math
from typing import Callable
import random
from tqdm import tqdm
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from transformers import (
    MarianTokenizer,
    MarianMTModel,
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from einops import rearrange, repeat

# 1) Device setup
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Load the translation model on CPU from its PyTorch checkpoint
trans_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ko-en")
trans_model = MarianMTModel.from_pretrained(
    "Helsinki-NLP/opus-mt-ko-en",
    torch_dtype=torch.float32,
).to(torch.device("cpu"))


def translate_ko_to_en(text: str, max_length: int = 512) -> str:
    """Korean -> English translation on CPU."""
    batch = trans_tokenizer([text], return_tensors="pt", padding=True)
    # The model lives on CPU, so the inputs do not need an explicit .to("cpu")
    gen = trans_model.generate(**batch, max_length=max_length)
    return trans_tokenizer.batch_decode(gen, skip_special_tokens=True)[0]


# ---------------- Encoders ----------------
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]


# Move T5, CLIP, and the VAE to the target device (GPU if available, else CPU)
t5 = HFEmbedder(
    "DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16
).to(torch_device)
clip = HFEmbedder(
    "openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16
).to(torch_device)
ae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(torch_device)
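# Optional sanity check for the translator and the text encoders. This block is an
# added debug hook, not part of the original app; the FLUX_DEBUG_ENCODERS variable
# is hypothetical and the block is skipped unless it is explicitly set.
if os.environ.get("FLUX_DEBUG_ENCODERS") == "1":
    _en = translate_ko_to_en("귀여운 골든 리트리버 강아지")  # Korean -> English on CPU
    print("translated prompt:", _en)
    with torch.no_grad():
        print("t5 output shape:", tuple(t5([_en]).shape))      # (1, 512, hidden) sequence embeddings
        print("clip output shape:", tuple(clip([_en]).shape))  # (1, hidden) pooled embedding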
# ---------------- NF4 support code ----------------
def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(
        x, weight.t(), bias=bias, quant_state=weight.quant_state
    )
    return out.to(x)


def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None

    device = device or state.absmax.device
    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )
    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )


class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        new = ForgeParams4bit(
            torch.nn.Parameter.to(
                self, device=device, dtype=dtype, non_blocking=non_blocking
            ),
            requires_grad=self.requires_grad,
            quant_state=copy_quant_state(self.quant_state, device),
            compress_statistics=False,
            blocksize=self.blocksize,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            bnb_quantized=self.bnb_quantized,
            module=self.module,
        )
        self.module.quant_state = new.quant_state
        self.data = new.data
        self.quant_state = new.quant_state
        return new


class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        qs_keys = {
            k[len(prefix + "weight.") :]
            for k in state_dict
            if k.startswith(prefix + "weight.")
        }
        if any("bitsandbytes" in k for k in qs_keys):
            qs = {k: state_dict[prefix + "weight." + k] for k in qs_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + "weight"],
                quantized_stats=qs,
                requires_grad=False,
                device=torch.device("cuda"),
                module=self,
            )
            self.quant_state = self.weight.quant_state
            if prefix + "bias" in state_dict:
                self.bias = torch.nn.Parameter(
                    state_dict[prefix + "bias"].to(self.dummy)
                )
            del self.dummy
        else:
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )


class Linear(ForgeLoader4Bit):
    def __init__(self, *args, device=None, dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, quant_type="nf4")

    def forward(self, x):
        self.weight.quant_state = self.quant_state
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return functional_linear_4bits(x, self.weight, self.bias)


# Replace nn.Linear globally so the Flux blocks below are built with NF4 linears
nn.Linear = Linear


# ---------------- Flux model definition (as in the original) ----------------
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # (unchanged from the original implementation)
    q, k = apply_rope(q, k, pe)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # (B, H, L, D) -> (B, L, H*D)
    x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
    return x


# apply_rope, rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm,
# SelfAttention, Modulation, DoubleStreamBlock, SingleStreamBlock, LastLayer,
# FluxParams, and the Flux class must all be included here, identical to the
# original implementation (a sketch of rope/apply_rope follows below).
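# The two helpers below are a minimal sketch of the rotary-embedding functions the
# placeholder above refers to, following the reference black-forest-labs/flux
# implementation (assumption: the original app pastes them in verbatim). The
# remaining blocks (EmbedND, the stream blocks, FluxParams, Flux, ...) still need
# to be supplied as noted above.
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Interpret the last dimension of q/k as pairs and rotate each pair by the
    # positional frequencies in freqs_cis.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)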
# ---------------- Model loading ----------------
sd = load_file(
    hf_hub_download(
        repo_id="lllyasviel/flux1-dev-bnb-nf4",
        filename="flux1-dev-bnb-nf4-v2.safetensors",
    )
)
sd = {
    k.replace("model.diffusion_model.", ""): v
    for k, v in sd.items()
    if "model.diffusion_model" in k
}
model = Flux().to(torch_device, dtype=torch.bfloat16)
model.load_state_dict(sd)
model_zero_init = False


# ---------------- Utility functions ----------------
def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")
    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    return tfm(image)[None, ...]


def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if bs == 1 and isinstance(prompt, list):
        img = repeat(img, "1 ... -> bs ...", bs=len(prompt))

    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])

    txt = t5([prompt] if isinstance(prompt, str) else prompt)
    if txt.shape[0] == 1 and img.shape[0] > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=img.shape[0])
    txt_ids = torch.zeros(txt.size(0), txt.size(1), 3, device=img.device)

    vec = clip([prompt] if isinstance(prompt, str) else prompt)
    if vec.shape[0] == 1 and img.shape[0] > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=img.shape[0])

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }


def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # Linear interpolation of the shift parameter between a 256-token and a
        # 4096-token image sequence.
        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (
            base_shift - (256 * (max_shift - base_shift) / (4096 - 256))
        )
        # Shifted time warp: t' = e^mu / (e^mu + (1/t - 1))
        timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1) ** 1.0)
    return timesteps.tolist()


def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
    guidance_vec = torch.full(
        (img.size(0),), guidance, device=img.device, dtype=img.dtype
    )
    for t_curr, t_prev in tqdm(
        zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
    ):
        t_vec = torch.full(
            (img.size(0),), t_curr, device=img.device, dtype=img.dtype
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_prev - t_curr) * pred
    return img


# ---------------- Gradio demo ----------------
@spaces.GPU
@torch.no_grad()
def generate_image(
    prompt,
    width,
    height,
    guidance,
    inference_steps,
    seed,
    do_img2img,
    init_image,
    image2image_strength,
    resize_img,
    progress=gr.Progress(track_tqdm=True),
):
    # Use the CPU translator when Korean characters are detected in the prompt
    if any("\u3131" <= c <= "\u318E" or "\uAC00" <= c <= "\uD7A3" for c in prompt):
        prompt = translate_ko_to_en(prompt)

    if seed == 0:
        seed = random.randint(1, 1_000_000)

    global model_zero_init, model
    if not model_zero_init:
        model = model.to(torch_device)
        model_zero_init = True

    if do_img2img and init_image is not None:
        init_img = get_image(init_image)
        if resize_img:
            init_img = torch.nn.functional.interpolate(init_img, (height, width))
        else:
            h0, w0 = init_img.shape[-2:]
            init_img = init_img[..., : 16 * (h0 // 16), : 16 * (w0 // 16)]
            height, width = init_img.shape[-2:]
        init_img = ae.encode(
            init_img.to(torch_device).to(torch.bfloat16)
        ).latent_dist.sample()
        init_img = (init_img - ae.config.shift_factor) * ae.config.scaling_factor
    else:
        init_img = None

    generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
    x = torch.randn(
        1,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=torch_device,
        dtype=torch.bfloat16,
        generator=generator,
    )

    timesteps = get_schedule(
        inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True
    )

    if do_img2img and init_img is not None:
        t_idx = int((1 - image2image_strength) * inference_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1 - t) * init_img.to(x.dtype)

    inp = prepare(t5, clip, x, prompt)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

    # Unpack the per-token 2x2 latent patches back into a spatial latent
    x = rearrange(
        x[:, inp["txt"].shape[1] :, ...].float(),
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )

    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    x = x.clamp(-1, 1)
    img = Image.fromarray(
        (127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0)).cpu().byte().numpy()
    )
    return img, seed
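# Size bookkeeping used above: the VAE downsamples by 8x and the transformer packs
# 2x2 latent patches into single tokens, so every token covers a 16x16 pixel area.
# That is why width/height are handled in multiples of 16 and the image sequence
# length passed to get_schedule is (latent_h * latent_w) / 4. To inspect the shifted
# schedule itself (debug-only example, not part of the original app):
#
#   print(get_schedule(4, (1024 // 16) ** 2))  # large image -> stronger shift toward t=1
#   print(get_schedule(4, (256 // 16) ** 2))   # small image -> schedule closer to linear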
css = """
footer {
    visibility: hidden;
}
"""


def create_demo():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
        gr.Markdown(
            "# News! Multilingual version "
            "[https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual]"
            "(https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)"
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt (Korean supported)",
                    value="A cute and fluffy golden retriever puppy sitting upright...",
                )
                width = gr.Slider(128, 2048, step=64, label="Width", value=768)
                height = gr.Slider(128, 2048, step=64, label="Height", value=768)
                guidance = gr.Slider(1.0, 5.0, step=0.1, label="Guidance", value=3.5)
                steps = gr.Slider(1, 30, step=1, label="Inference steps", value=30)
                seed = gr.Number(label="Seed", precision=0)
                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                init_img = gr.Image(label="Input Image", visible=False)
                strength = gr.Slider(
                    0.0, 1.0, step=0.01, label="Noising strength", value=0.8, visible=False
                )
                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
                btn = gr.Button("Generate")
            with gr.Column():
                out_img = gr.Image(label="Generated Image")
                out_seed = gr.Text(label="Used Seed")

        do_i2i.change(
            fn=lambda x: [gr.update(visible=x)] * 3,
            inputs=[do_i2i],
            outputs=[init_img, strength, resize],
        )
        btn.click(
            fn=generate_image,
            inputs=[
                prompt,
                width,
                height,
                guidance,
                steps,
                seed,
                do_i2i,
                init_img,
                strength,
                resize,
            ],
            outputs=[out_img, out_seed],
        )
    return demo


if __name__ == "__main__":
    create_demo().launch()
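# Usage note (assumption, not from the original script): when hosting on Hugging Face
# Spaces with ZeroGPU, the demo is usually launched with a request queue so that
# concurrent generations are serialized, e.g.:
#
#   create_demo().queue().launch()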