eienmojiki commited on
Commit
ac6932a
·
verified ·
1 Parent(s): 35dc249

Create utils/t2i.py

Browse files
Files changed (1) hide show
  1. utils/t2i.py +67 -0
utils/t2i.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import argparse
4
+ import os
5
+ import datetime
6
+ from diffusers import FluxPipeline
7
+ from lib_layerdiffuse.pipeline_flux_img2img import FluxImg2ImgPipeline
8
+ from lib_layerdiffuse.vae import TransparentVAE, pad_rgb
9
+ import numpy as np
10
+ from torchvision import transforms
11
+ from safetensors.torch import load_file
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import spaces
14
+
15
+ HF_TOKEN = os.getenv("HF_TOKEN")
16
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
+
18
+ def seed_everything(seed: int) -> torch.Generator:
19
+ torch.manual_seed(seed)
20
+ torch.cuda.manual_seed_all(seed)
21
+ np.random.seed(seed)
22
+ generator = torch.Generator()
23
+ generator.manual_seed(seed)
24
+ return generator
25
+
26
+ t2i_pipe = FluxPipeline.from_pretrained(
27
+ "black-forest-labs/FLUX.1-dev",
28
+ torch_dtype=torch.bfloat16,
29
+ use_auth_token=HF_TOKEN
30
+ ).to(device)
31
+
32
+ trans_vae = TransparentVAE(t2i_pipe.vae, t2i_pipe.vae.dtype)
33
+ trans_vae.load_state_dict(torch.load("./models/TransparentVAE.pth"), strict=False)
34
+ trans_vae.to(device)
35
+
36
+ @spaces.GPU(duration=75)
37
+ def t2i_gen(
38
+ prompt: str,
39
+ # negative_prompt: str = None,
40
+ seed: int = 1111,
41
+ width: int = 1024,
42
+ height: int = 1024,
43
+ guidance_scale: float = 3.5,
44
+ num_inference_steps: int = 50,
45
+ ):
46
+ t2i_pipe.load_lora_weights("RedAIGC/Flux-version-LayerDiffuse", weight_name="layerlora.safetensors")
47
+ latents = t2i_pipe(
48
+ prompt=prompt,
49
+ height=height,
50
+ width=width,
51
+ num_inference_steps=num_inference_steps,
52
+ output_type="latent",
53
+ generator=seed_everything(seed),
54
+ guidance_scale=guidance_scale,
55
+ ).images
56
+
57
+ latents = t2i_pipe._unpack_latents(latents, height, width, t2i_pipe.vae_scale_factor)
58
+ latents = (latents / t2i_pipe.vae.config.scaling_factor) + t2i_pipe.vae.config.shift_factor
59
+
60
+ with torch.no_grad():
61
+ original_x, x = trans_vae.decode(latents)
62
+
63
+ x = x.clamp(0, 1)
64
+ x = x.permute(0, 2, 3, 1)
65
+ img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
66
+
67
+ return img