ginipick committed
Commit 618d32d · verified · 1 Parent(s): 00d591a

Update app.py

Files changed (1)
  1. app.py +247 -104
app.py CHANGED
@@ -14,7 +14,8 @@ from tqdm import tqdm
 import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from transformers import (
-    pipeline,
     CLIPTextModel, CLIPTokenizer,
     T5EncoderModel, T5Tokenizer
 )
@@ -23,17 +24,27 @@ from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from einops import rearrange, repeat
 
-# 1) Device setup: use CUDA if a GPU is available, otherwise CPU
 torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# 2) Translation pipeline: force the TF checkpoint to load as PyTorch, run on the CPU
-translator = pipeline(
-    "translation",
-    model="Helsinki-NLP/opus-mt-ko-en",
-    framework="pt",
-    from_tf=True,
-    device=-1
 )
 
 # ---------------- Encoders ----------------
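For context, the `translator` object removed above was consumed further down in the old `generate_image` (see the `@@ -223,34` hunk below) roughly as sketched here; `korean_prompt` is only an illustrative placeholder, not a name from app.py:

    korean_prompt = "귀여운 강아지"  # any Hangul string
    translated = translator(korean_prompt, max_length=512)[0]["translation_text"]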
 
@@ -45,13 +56,20 @@ class HFEmbedder(nn.Module):
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         if self.is_clip:
-            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
-            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
         else:
-            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
-            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
 
-        # Freeze the parameters
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
     def forward(self, text: list[str]) -> Tensor:
@@ -69,30 +87,47 @@ class HFEmbedder(nn.Module):
         )
         return outputs[self.output_key]
 
-# Move the embedders and the VAE to torch_device
-t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(torch_device)
-clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(torch_device)
 ae = AutoencoderKL.from_pretrained(
-    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
 ).to(torch_device)
 
-# ---------------- NF4 logic (unchanged) ----------------
 def functional_linear_4bits(x, weight, bias):
-    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
     return out.to(x)
 
 def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
     if state is None:
         return None
     device = device or state.absmax.device
-    state2 = QuantState(
-        absmax=state.state2.absmax.to(device),
-        shape=state.state2.shape,
-        code=state.state2.code.to(device),
-        blocksize=state.state2.blocksize,
-        quant_type=state.state2.quant_type,
-        dtype=state.state2.dtype,
-    ) if state.nested else None
     return QuantState(
         absmax=state.absmax.to(device),
         shape=state.shape,
@@ -110,7 +145,9 @@ class ForgeParams4bit(Params4bit):
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
             return self._quantize(device)
         new = ForgeParams4bit(
-            torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
             requires_grad=self.requires_grad,
             quant_state=copy_quant_state(self.quant_state, device),
             compress_statistics=False,
@@ -118,7 +155,7 @@ class ForgeParams4bit(Params4bit):
             quant_type=self.quant_type,
             quant_storage=self.quant_storage,
             bnb_quantized=self.bnb_quantized,
-            module=self.module
         )
         self.module.quant_state = new.quant_state
         self.data = new.data
@@ -134,29 +171,53 @@ class ForgeLoader4Bit(torch.nn.Module):
         self.bias = None
         self.quant_type = quant_type
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        qs_keys = {k[len(prefix + "weight."):] for k in state_dict if k.startswith(prefix + "weight.")}
         if any("bitsandbytes" in k for k in qs_keys):
-            qs = {k: state_dict[prefix + "weight." + k] for k in qs_keys}
             self.weight = ForgeParams4bit.from_prequantized(
                 data=state_dict[prefix + "weight"],
                 quantized_stats=qs,
                 requires_grad=False,
-                device=torch.device('cuda'),
-                module=self
             )
             self.quant_state = self.weight.quant_state
             if prefix + "bias" in state_dict:
-                self.bias = torch.nn.Parameter(state_dict[prefix + "bias"].to(self.dummy))
             del self.dummy
         else:
-            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
-                                          missing_keys, unexpected_keys, error_msgs)
 
 class Linear(ForgeLoader4Bit):
     def __init__(self, *args, device=None, dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, quant_type='nf4')
     def forward(self, x):
         self.weight.quant_state = self.quant_state
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -165,44 +226,61 @@ class Linear(ForgeLoader4Bit):
 
 nn.Linear = Linear
 
-# ---------------- Flux model definition (unchanged) ----------------
-# (Attention, RoPE, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm,
-#  SelfAttention, Modulation, DoubleStreamBlock, SingleStreamBlock, LastLayer, FluxParams, Flux classes)
-# (Omitted here for length, but completely identical to the existing code.)
 
 # ---------------- Model loading ----------------
 sd = load_file(
     hf_hub_download(
         repo_id="lllyasviel/flux1-dev-bnb-nf4",
-        filename="flux1-dev-bnb-nf4-v2.safetensors"
     )
 )
-sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
 
 model = Flux().to(torch_device, dtype=torch.bfloat16)
 model.load_state_dict(sd)
 model_zero_init = False
 
-# ---------------- Image generation functions ----------------
 
 def get_image(image) -> torch.Tensor | None:
     if image is None:
         return None
     image = Image.fromarray(image).convert("RGB")
-    tf = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Lambda(lambda x: 2.0 * x - 1.0),
-    ])
-    return tf(image)[None, ...]
 
 def prepare(t5, clip, img, prompt):
     bs, c, h, w = img.shape
-    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     if bs == 1 and isinstance(prompt, list):
         img = repeat(img, "1 ... -> bs ...", bs=len(prompt))
-    img_ids = torch.zeros(h//2, w//2, 3, device=img.device)
-    img_ids[...,1] = torch.arange(h//2, device=img.device)[:,None]
-    img_ids[...,2] = torch.arange(w//2, device=img.device)[None,:]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])
 
     txt = t5([prompt] if isinstance(prompt, str) else prompt)
@@ -223,34 +301,57 @@ def prepare(t5, clip, img, prompt):
     }
 
 def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
-    timesteps = torch.linspace(1, 0, num_steps+1)
     if shift:
-        mu = ((max_shift-base_shift)/(4096-256))*(image_seq_len) + (base_shift - (256*(max_shift-base_shift)/(4096-256)))
-        timesteps = timesteps.exp().div((1/timesteps-1)**1 + mu)
     return timesteps.tolist()
 
 def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
-    guidance_vec = torch.full((img.size(0),), guidance, device=img.device, dtype=img.dtype)
-    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps)-1):
-        t_vec = torch.full((img.size(0),), t_curr, device=img.device, dtype=img.dtype)
-        pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
-                     y=vec, timesteps=t_vec, guidance=guidance_vec)
         img = img + (t_prev - t_curr) * pred
     return img
 
 @spaces.GPU
 @torch.no_grad()
 def generate_image(
-    prompt, width, height, guidance, inference_steps, seed,
-    do_img2img, init_image, image2image_strength, resize_img,
     progress=gr.Progress(track_tqdm=True),
 ):
     # Use the CPU translator when Korean is detected
-    if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt):
-        translated = translator(prompt, max_length=512)[0]['translation_text']
-        prompt = translated
 
-    # Random seed
     if seed == 0:
         seed = random.randint(1, 1_000_000)
@@ -259,29 +360,38 @@ def generate_image(
         model = model.to(torch_device)
         model_zero_init = True
 
-    # Prepare img2img
     if do_img2img and init_image is not None:
         init_img = get_image(init_image)
         if resize_img:
-            init_img = torch.nn.functional.interpolate(init_img, (height, width))
        else:
             h0, w0 = init_img.shape[-2:]
-            init_img = init_img[..., :16*(h0//16), :16*(w0//16)]
             height, width = init_img.shape[-2:]
-        init_img = ae.encode(init_img.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
-        init_img = (init_img - ae.config.shift_factor) * ae.config.scaling_factor
     else:
         init_img = None
 
-    # Generate the noise sample
     generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
     x = torch.randn(
-        1, 16, 2*math.ceil(height/16), 2*math.ceil(width/16),
-        device=torch_device, dtype=torch.bfloat16, generator=generator
     )
-
-    timesteps = get_schedule(inference_steps, (x.shape[-1]*x.shape[-2])//4, shift=True)
-
 
     if do_img2img and init_img is not None:
         t_idx = int((1 - image2image_strength) * inference_steps)
         t = timesteps[t_idx]
@@ -291,50 +401,83 @@ def generate_image(
     inp = prepare(t5, clip, x, prompt)
     x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
 
-    x = rearrange(x[:, inp["txt"].shape[1]:, ...].float(), "b (h w) (c ph pw) -> b c (h ph) (w pw)",
-                  h=math.ceil(height/16), w=math.ceil(width/16), ph=2, pw=2)
     with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
         x = (x / ae.config.scaling_factor) + ae.config.shift_factor
         x = ae.decode(x).sample
 
-    x = x.clamp(-1,1)
-    img = Image.fromarray((127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0)).cpu().byte().numpy())
 
     return img, seed
 
-# ---------------- Gradio demo ----------------
-
-css = """footer { visibility: hidden; }"""
 
 def create_demo():
     with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
-        gr.Markdown("# News! Multilingual version [https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual](https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)")
         with gr.Row():
             with gr.Column():
-                prompt = gr.Textbox(label="Prompt(한글 가능)", value="A cute and fluffy golden retriever puppy sitting upright...")
-                width = gr.Slider(128, 2048, 64, label="Width", value=768)
-                height = gr.Slider(128, 2048, 64, label="Height", value=768)
-                guidance = gr.Slider(1.0, 5.0, 0.1, label="Guidance", value=3.5)
-                steps = gr.Slider(1, 30, 1, label="Inference steps", value=30)
-                seed = gr.Number(label="Seed", precision=0)
-                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                 init_img = gr.Image(label="Input Image", visible=False)
-                strength = gr.Slider(0.0, 1.0, 0.01, label="Noising strength", value=0.8, visible=False)
-                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
-                btn = gr.Button("Generate")
             with gr.Column():
-                out_img = gr.Image(label="Generated Image")
                 out_seed = gr.Text(label="Used Seed")
 
         do_i2i.change(
-            fn=lambda x: [gr.update(visible=x)]*3,
             inputs=[do_i2i],
-            outputs=[init_img, strength, resize]
        )
         btn.click(
             fn=generate_image,
-            inputs=[prompt, width, height, guidance, steps, seed, do_i2i, init_img, strength, resize],
-            outputs=[out_img, out_seed]
         )
         return demo
 import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from transformers import (
+    MarianTokenizer,
+    MarianMTModel,
     CLIPTextModel, CLIPTokenizer,
     T5EncoderModel, T5Tokenizer
 )
 
 from safetensors.torch import load_file
 from einops import rearrange, repeat
 
+# 1) Device setup
 torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# 2) Load the translation model on the CPU, strictly as a PyTorch checkpoint
+trans_tokenizer = MarianTokenizer.from_pretrained(
+    "Helsinki-NLP/opus-mt-ko-en"
 )
+trans_model = MarianMTModel.from_pretrained(
+    "Helsinki-NLP/opus-mt-ko-en",
+    from_tf=True,  # load as PyTorch even if the checkpoint is TF
+    torch_dtype=torch.float32,
+).to(torch.device("cpu"))
+
+def translate_ko_to_en(text: str, max_length: int = 512) -> str:
+    """Korean → English translation (CPU)."""
+    batch = trans_tokenizer([text], return_tensors="pt", padding=True)
+    # The model already lives on the CPU, so no .to("cpu") is needed
+    gen = trans_model.generate(
+        **batch, max_length=max_length
+    )
+    return trans_tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
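A quick standalone sanity check of the helper above, as a minimal sketch: the hypothetical `contains_hangul` wrapper below reuses the same Unicode-range test that `generate_image` applies later before calling the translator, and the Hangul sample string and printed output are only illustrative.

    def contains_hangul(text: str) -> bool:
        # Hangul compatibility jamo (U+3131-U+318E) or precomposed syllables (U+AC00-U+D7A3)
        return any("\u3131" <= c <= "\u318E" or "\uAC00" <= c <= "\uD7A3" for c in text)

    sample = "귀여운 강아지 한 마리"
    if contains_hangul(sample):
        print(translate_ko_to_en(sample))  # e.g. "A cute puppy"; exact wording depends on the model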
 
 # ---------------- Encoders ----------------
 
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         if self.is_clip:
+            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
+                version, max_length=max_length
+            )
+            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
+                version, **hf_kwargs
+            )
         else:
+            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
+                version, max_length=max_length
+            )
+            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
+                version, **hf_kwargs
+            )
 
         self.hf_module = self.hf_module.eval().requires_grad_(False)
 
     def forward(self, text: list[str]) -> Tensor:
 
         )
         return outputs[self.output_key]
 
+# Move T5, CLIP, and the VAE to the target device (GPU or CPU)
+t5 = HFEmbedder(
+    "DeepFloyd/t5-v1_1-xxl",
+    max_length=512,
+    torch_dtype=torch.bfloat16
+).to(torch_device)
+clip = HFEmbedder(
+    "openai/clip-vit-large-patch14",
+    max_length=77,
+    torch_dtype=torch.bfloat16
+).to(torch_device)
 ae = AutoencoderKL.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="vae",
+    torch_dtype=torch.bfloat16
 ).to(torch_device)
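Rough usage sketch for the three modules above; the shapes are assumptions based on the model configurations and on `HFEmbedder.forward` padding or truncating to `max_length` (its body sits in an unchanged hunk not shown on this page):

    with torch.no_grad():
        txt = t5(["a photo of a cat"])    # last_hidden_state, roughly (1, 512, 4096), bfloat16
        vec = clip(["a photo of a cat"])  # pooler_output, roughly (1, 768), bfloat16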
 
+# ---------------- NF4 support code ----------------
+
 def functional_linear_4bits(x, weight, bias):
+    out = bnb.matmul_4bit(
+        x, weight.t(), bias=bias, quant_state=weight.quant_state
+    )
     return out.to(x)
 
 def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
     if state is None:
         return None
     device = device or state.absmax.device
+    state2 = (
+        QuantState(
+            absmax=state.state2.absmax.to(device),
+            shape=state.state2.shape,
+            code=state.state2.code.to(device),
+            blocksize=state.state2.blocksize,
+            quant_type=state.state2.quant_type,
+            dtype=state.state2.dtype,
+        )
+        if state.nested
+        else None
+    )
     return QuantState(
         absmax=state.absmax.to(device),
         shape=state.shape,
 
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
             return self._quantize(device)
         new = ForgeParams4bit(
+            torch.nn.Parameter.to(
+                self, device=device, dtype=dtype, non_blocking=non_blocking
+            ),
             requires_grad=self.requires_grad,
             quant_state=copy_quant_state(self.quant_state, device),
             compress_statistics=False,
 
             quant_type=self.quant_type,
             quant_storage=self.quant_storage,
             bnb_quantized=self.bnb_quantized,
+            module=self.module,
         )
         self.module.quant_state = new.quant_state
         self.data = new.data
 
         self.bias = None
         self.quant_type = quant_type
 
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        qs_keys = {
+            k[len(prefix + "weight.") :]
+            for k in state_dict
+            if k.startswith(prefix + "weight.")
+        }
         if any("bitsandbytes" in k for k in qs_keys):
+            qs = {
+                k: state_dict[prefix + "weight." + k] for k in qs_keys
+            }
             self.weight = ForgeParams4bit.from_prequantized(
                 data=state_dict[prefix + "weight"],
                 quantized_stats=qs,
                 requires_grad=False,
+                device=torch.device("cuda"),
+                module=self,
             )
             self.quant_state = self.weight.quant_state
             if prefix + "bias" in state_dict:
+                self.bias = torch.nn.Parameter(
+                    state_dict[prefix + "bias"].to(self.dummy)
+                )
             del self.dummy
         else:
+            super()._load_from_state_dict(
+                state_dict,
+                prefix,
+                local_metadata,
+                strict,
+                missing_keys,
+                unexpected_keys,
+                error_msgs,
+            )
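A toy illustration of the key handling in `_load_from_state_dict` above: quantization statistics live under `<prefix>weight.<suffix>` keys, and a `bitsandbytes` marker in any suffix signals a prequantized NF4 tensor. The module path, key names, and placeholder values below are illustrative assumptions, not taken from the real checkpoint:

    prefix = "double_blocks.0.img_mlp.0."          # hypothetical module prefix
    state_dict = {
        prefix + "weight": "packed 4-bit data",
        prefix + "weight.absmax": "absmax tensor",
        prefix + "weight.quant_state.bitsandbytes__nf4": "serialized quant state",
        prefix + "bias": "bias tensor",
    }
    qs_keys = {
        k[len(prefix + "weight."):]
        for k in state_dict
        if k.startswith(prefix + "weight.")
    }
    print(qs_keys)                                    # {'absmax', 'quant_state.bitsandbytes__nf4'}
    print(any("bitsandbytes" in k for k in qs_keys))  # True -> load via from_prequantized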
 
 class Linear(ForgeLoader4Bit):
     def __init__(self, *args, device=None, dtype=None, **kwargs):
+        super().__init__(device=device, dtype=dtype, quant_type="nf4")
+
     def forward(self, x):
         self.weight.quant_state = self.quant_state
         if self.bias is not None and self.bias.dtype != x.dtype:
 
 nn.Linear = Linear
 
+# ---------------- Flux model definition (identical to the original) ----------------
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+    # ... (the original code as-is, nothing omitted)
+    q, k = apply_rope(q, k, pe)
+    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+    x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
+    return x
+
+# apply_rope, rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm,
+# SelfAttention, Modulation, DoubleStreamBlock, SingleStreamBlock,
+# LastLayer, FluxParams, and the Flux class: include them all exactly as in the original.
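The `nn.Linear = Linear` patch above works because the Flux blocks are defined after it, so every linear layer they create is the NF4-aware subclass. A self-contained sketch of the same trick with a toy layer (the names here are hypothetical, not from app.py):

    import torch
    import torch.nn as nn

    class LoggingLinear(nn.Linear):
        def forward(self, x):
            print("custom linear, input shape:", tuple(x.shape))
            return super().forward(x)

    _original_linear = nn.Linear
    nn.Linear = LoggingLinear          # patch before the model classes are defined/instantiated

    class TinyMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)  # resolves to LoggingLinear

    mlp = TinyMLP()                    # instantiated while the patch is active
    nn.Linear = _original_linear       # restore for anything defined later
    print(isinstance(mlp.fc, LoggingLinear))  # True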
 
 # ---------------- Model loading ----------------
+
 sd = load_file(
     hf_hub_download(
         repo_id="lllyasviel/flux1-dev-bnb-nf4",
+        filename="flux1-dev-bnb-nf4-v2.safetensors",
     )
 )
+sd = {
+    k.replace("model.diffusion_model.", ""): v
+    for k, v in sd.items()
+    if "model.diffusion_model" in k
+}
 
 model = Flux().to(torch_device, dtype=torch.bfloat16)
 model.load_state_dict(sd)
 model_zero_init = False
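The dict comprehension above keeps only the diffusion-model weights and strips their prefix so the keys line up with the submodule names of `Flux`, which is what `load_state_dict` expects. A minimal sketch with placeholder keys and values:

    raw = {
        "model.diffusion_model.img_in.weight": 1,
        "model.diffusion_model.img_in.bias": 2,
        "vae.encoder.conv_in.weight": 3,          # unrelated entry, dropped
    }
    sd_small = {
        k.replace("model.diffusion_model.", ""): v
        for k, v in raw.items()
        if "model.diffusion_model" in k
    }
    print(sd_small)  # {'img_in.weight': 1, 'img_in.bias': 2}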
 
+# ---------------- Utility functions ----------------
 
 def get_image(image) -> torch.Tensor | None:
     if image is None:
         return None
     image = Image.fromarray(image).convert("RGB")
+    tfm = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: 2.0 * x - 1.0),
+        ]
+    )
+    return tfm(image)[None, ...]
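The Compose above maps `ToTensor`'s [0, 1] range to [-1, 1]; the inverse mapping `127.5 * (x + 1)` is applied at the end of `generate_image`. A quick numeric check:

    import torch
    x01 = torch.tensor([0.0, 0.5, 1.0])   # ToTensor output range
    signed = 2.0 * x01 - 1.0              # tensor([-1.,  0.,  1.])
    back = 127.5 * (signed + 1.0)         # tensor([  0.0, 127.5, 255.0])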
 
 def prepare(t5, clip, img, prompt):
     bs, c, h, w = img.shape
+    img = rearrange(
+        img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
+    )
     if bs == 1 and isinstance(prompt, list):
         img = repeat(img, "1 ... -> bs ...", bs=len(prompt))
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
+    img_ids[..., 1] = torch.arange(h // 2, device=img.device)[:, None]
+    img_ids[..., 2] = torch.arange(w // 2, device=img.device)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=img.shape[0])
 
     txt = t5([prompt] if isinstance(prompt, str) else prompt)
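Shape sketch for the patchify step above, using the latent size that a 768x768 request produces (`generate_image` samples a (1, 16, 96, 96) latent, and the VAE downsamples by 8):

    import torch
    from einops import rearrange

    lat = torch.zeros(1, 16, 96, 96)
    seq = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    print(seq.shape)  # torch.Size([1, 2304, 64]) -> 48*48 tokens, each holding 16*2*2 values
    # 2304 also equals (96 * 96) // 4, the image_seq_len later passed to get_schedule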
 
     }
 
 def get_schedule(num_steps, image_seq_len, base_shift=0.5, max_shift=1.15, shift=True):
+    timesteps = torch.linspace(1, 0, num_steps + 1)
     if shift:
+        mu = ((max_shift - base_shift) / (4096 - 256)) * image_seq_len + (
+            base_shift - (256 * (max_shift - base_shift) / (4096 - 256))
+        )
+        timesteps = timesteps.exp().div((1 / timesteps - 1) ** 1 + mu)
     return timesteps.tolist()
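Worked example for the shift above: `mu` is a linear interpolation in `image_seq_len`, equal to `base_shift` at 256 tokens and `max_shift` at 4096 tokens:

    base_shift, max_shift = 0.5, 1.15
    m = (max_shift - base_shift) / (4096 - 256)
    mu = lambda seq_len: m * seq_len + (base_shift - 256 * m)
    print(mu(256))             # 0.5  (= base_shift)
    print(mu(4096))            # 1.15 (= max_shift)
    print(round(mu(2304), 4))  # 0.8467 for the 2304-token sequence of a 768x768 image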
 
 def denoise(model, img, img_ids, txt, txt_ids, vec, timesteps, guidance):
+    guidance_vec = torch.full(
+        (img.size(0),), guidance, device=img.device, dtype=img.dtype
+    )
+    for t_curr, t_prev in tqdm(
+        zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
+    ):
+        t_vec = torch.full(
+            (img.size(0),), t_curr, device=img.device, dtype=img.dtype
+        )
+        pred = model(
+            img=img,
+            img_ids=img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec,
+        )
         img = img + (t_prev - t_curr) * pred
     return img
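The loop above is a plain Euler integrator over the timestep list: x <- x + (t_prev - t_curr) * v(x, t). A toy check with a constant velocity field, which must move x by exactly t_end - t_start = -1:

    import torch
    x = torch.zeros(3)
    v = torch.ones(3)
    ts = [1.0, 0.75, 0.5, 0.25, 0.0]
    for t_curr, t_prev in zip(ts[:-1], ts[1:]):
        x = x + (t_prev - t_curr) * v
    print(x)  # tensor([-1., -1., -1.])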
 
+# ---------------- Gradio demo ----------------
+
 @spaces.GPU
 @torch.no_grad()
 def generate_image(
+    prompt,
+    width,
+    height,
+    guidance,
+    inference_steps,
+    seed,
+    do_img2img,
+    init_image,
+    image2image_strength,
+    resize_img,
     progress=gr.Progress(track_tqdm=True),
 ):
     # Use the CPU translator when Korean is detected
+    if any("\u3131" <= c <= "\u318E" or "\uAC00" <= c <= "\uD7A3" for c in prompt):
+        prompt = translate_ko_to_en(prompt)
 
  if seed == 0:
356
  seed = random.randint(1, 1_000_000)
357
 
 
360
  model = model.to(torch_device)
361
  model_zero_init = True
362
 
 
363
  if do_img2img and init_image is not None:
364
  init_img = get_image(init_image)
365
  if resize_img:
366
+ init_img = torch.nn.functional.interpolate(
367
+ init_img, (height, width)
368
+ )
369
  else:
370
  h0, w0 = init_img.shape[-2:]
371
+ init_img = init_img[..., : 16 * (h0 // 16), : 16 * (w0 // 16)]
372
  height, width = init_img.shape[-2:]
373
+ init_img = ae.encode(
374
+ init_img.to(torch_device).to(torch.bfloat16)
375
+ ).latent_dist.sample()
376
+ init_img = (
377
+ init_img - ae.config.shift_factor
378
+ ) * ae.config.scaling_factor
379
  else:
380
  init_img = None
381
 
 
     generator = torch.Generator(device=str(torch_device)).manual_seed(seed)
     x = torch.randn(
+        1,
+        16,
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=torch_device,
+        dtype=torch.bfloat16,
+        generator=generator,
+    )
+    timesteps = get_schedule(
+        inference_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True
     )
 
     if do_img2img and init_img is not None:
         t_idx = int((1 - image2image_strength) * inference_steps)
         t = timesteps[t_idx]
 
     inp = prepare(t5, clip, x, prompt)
     x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
 
+    x = rearrange(
+        x[:, inp["txt"].shape[1] :, ...].float(),
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=math.ceil(height / 16),
+        w=math.ceil(width / 16),
+        ph=2,
+        pw=2,
+    )
     with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
         x = (x / ae.config.scaling_factor) + ae.config.shift_factor
         x = ae.decode(x).sample
 
+    x = x.clamp(-1, 1)
+    img = Image.fromarray(
+        (127.5 * (rearrange(x[0], "c h w -> h w c") + 1.0))
+        .cpu()
+        .byte()
+        .numpy()
+    )
 
     return img, seed
+css = """
+footer {
+    visibility: hidden;
+}
+"""
 
 def create_demo():
     with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+        gr.Markdown(
+            "# News! Multilingual version "
+            "[https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual]"
+            "(https://huggingface.co/spaces/ginigen/FLUXllama-Multilingual)"
+        )
         with gr.Row():
             with gr.Column():
+                prompt = gr.Textbox(
+                    label="Prompt(한글 가능)",
+                    value="A cute and fluffy golden retriever puppy sitting upright...",
+                )
+                width = gr.Slider(128, 2048, 64, label="Width", value=768)
+                height = gr.Slider(128, 2048, 64, label="Height", value=768)
+                guidance = gr.Slider(1.0, 5.0, 0.1, label="Guidance", value=3.5)
+                steps = gr.Slider(1, 30, 1, label="Inference steps", value=30)
+                seed = gr.Number(label="Seed", precision=0)
+                do_i2i = gr.Checkbox(label="Image to Image", value=False)
                 init_img = gr.Image(label="Input Image", visible=False)
+                strength = gr.Slider(
+                    0.0, 1.0, 0.01, label="Noising strength", value=0.8, visible=False
+                )
+                resize = gr.Checkbox(label="Resize image", value=True, visible=False)
+                btn = gr.Button("Generate")
             with gr.Column():
+                out_img = gr.Image(label="Generated Image")
                 out_seed = gr.Text(label="Used Seed")
 
         do_i2i.change(
+            fn=lambda x: [gr.update(visible=x)] * 3,
             inputs=[do_i2i],
+            outputs=[init_img, strength, resize],
         )
         btn.click(
             fn=generate_image,
+            inputs=[
+                prompt,
+                width,
+                height,
+                guidance,
+                steps,
+                seed,
+                do_i2i,
+                init_img,
+                strength,
+                resize,
+            ],
+            outputs=[out_img, out_seed],
         )
         return demo
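The tail of app.py (new line 483 onward) is outside the hunks shown on this page, so the snippet below is an assumption rather than part of the diff; a typical Gradio Space simply builds and launches the demo:

    if __name__ == "__main__":
        demo = create_demo()   # assumed usage; not shown in the diff
        demo.launch()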