Spaces:
Yuanshi
/
Running on Zero

Yuanshi commited on
Commit
2b4e74b
·
verified ·
1 Parent(s): 0f92491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -44
app.py CHANGED
@@ -1,56 +1,44 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- import os
5
 
6
  import spaces
7
  from pipeline_flux import FluxPipeline
8
  from transformer_flux import FluxTransformer2DModel
9
  import torch
10
 
11
- flux_model = "schnell"
12
  bfl_repo = f"black-forest-labs/FLUX.1-{flux_model}"
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  dtype = torch.bfloat16
16
 
17
  transformer = FluxTransformer2DModel.from_pretrained(
18
- bfl_repo, subfolder="transformer", torch_dtype=dtype
 
19
  )
20
  pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
21
  pipe.transformer = transformer
22
  pipe.scheduler.config.use_dynamic_shifting = False
23
  pipe.scheduler.config.time_shift = 10
24
- pipe.enable_model_cpu_offload()
25
  pipe = pipe.to(device)
26
 
27
-
28
- transformer2 = FluxTransformer2DModel.from_pretrained(
29
- "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=dtype,
30
- use_auth_token=os.getenv("HF_TOKEN"),
31
- )
32
- pipe2 = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
33
- pipe2.transformer = transformer2
34
- pipe2.scheduler.config.use_dynamic_shifting = False
35
- pipe2.scheduler.config.time_shift = 10
36
- pipe2.enable_model_cpu_offload()
37
- pipe2 = pipe.to(device)
38
-
39
  pipe.load_lora_weights(
40
  "Huage001/URAE",
41
  weight_name="urae_2k_adapter.safetensors",
42
  adapter_name="2k",
43
  )
44
- # pipe.load_lora_weights(
45
- # "Huage001/URAE",
46
- # weight_name="urae_4k_adapter_lora_conversion_dev.safetensors",
47
- # adapter_name="4k_dev",
48
- # )
49
- # pipe.load_lora_weights(
50
- # "Huage001/URAE",
51
- # weight_name="urae_4k_adapter_lora_conversion_schnell.safetensors",
52
- # adapter_name="4k_schnell",
53
- # )
54
  MAX_SEED = np.iinfo(np.int32).max
55
  MAX_IMAGE_SIZE = 4096
56
  USE_ZERO_GPU = True
@@ -64,17 +52,16 @@ def infer(
64
  width,
65
  height,
66
  num_inference_steps,
67
- model='2k',
68
  progress=gr.Progress(track_tqdm=True),
 
69
  ):
70
  print("Using model:", model)
71
- # if model == "2k":
72
- # pipe.vae.enable_tiling(True)
73
- # pipe.set_adapters("2k")
74
- # # elif model == "4k":
75
- # pipe.vae.enable_tiling(True)
76
- # pipe.set_adapters(f"4k_{flux_model}")
77
- pipe = pipe if model == "schnell" else pipe2
78
 
79
  if randomize_seed:
80
  seed = random.randint(0, MAX_SEED)
@@ -138,14 +125,14 @@ with gr.Blocks(css=css) as demo:
138
 
139
  gr.Markdown("### Setting:")
140
 
141
- model = gr.Radio(
142
- label="Model",
143
- choices=[
144
- ("FLUX.1 dev", "dev"),
145
- ("FLUX.1 schnell", "schnell"),
146
- ],
147
- value="2k",
148
- )
149
 
150
  with gr.Row():
151
  width = gr.Slider(
@@ -179,7 +166,7 @@ with gr.Blocks(css=css) as demo:
179
  minimum=1,
180
  maximum=50,
181
  step=1,
182
- value=4, # Replace with defaults that work for your model
183
  )
184
 
185
  with gr.Column(elem_id="col2"):
@@ -190,12 +177,12 @@ with gr.Blocks(css=css) as demo:
190
  fn=infer,
191
  inputs=[
192
  prompt,
 
193
  seed,
194
  randomize_seed,
195
  width,
196
  height,
197
  num_inference_steps,
198
- model,
199
  ],
200
  outputs=[result, seed],
201
  )
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
4
 
5
  import spaces
6
  from pipeline_flux import FluxPipeline
7
  from transformer_flux import FluxTransformer2DModel
8
  import torch
9
 
10
+ flux_model = "dev"
11
  bfl_repo = f"black-forest-labs/FLUX.1-{flux_model}"
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  dtype = torch.bfloat16
15
 
16
  transformer = FluxTransformer2DModel.from_pretrained(
17
+ bfl_repo, subfolder="transformer", torch_dtype=dtype,
18
+ use_auth_token=os.getenv("HF_TOKEN"),
19
  )
20
  pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
21
  pipe.transformer = transformer
22
  pipe.scheduler.config.use_dynamic_shifting = False
23
  pipe.scheduler.config.time_shift = 10
24
+ # pipe.enable_model_cpu_offload()
25
  pipe = pipe.to(device)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  pipe.load_lora_weights(
28
  "Huage001/URAE",
29
  weight_name="urae_2k_adapter.safetensors",
30
  adapter_name="2k",
31
  )
32
+ pipe.load_lora_weights(
33
+ "Huage001/URAE",
34
+ weight_name="urae_4k_adapter_lora_conversion_dev.safetensors",
35
+ adapter_name="4k_dev",
36
+ )
37
+ pipe.load_lora_weights(
38
+ "Huage001/URAE",
39
+ weight_name="urae_4k_adapter_lora_conversion_schnell.safetensors",
40
+ adapter_name="4k_schnell",
41
+ )
42
  MAX_SEED = np.iinfo(np.int32).max
43
  MAX_IMAGE_SIZE = 4096
44
  USE_ZERO_GPU = True
 
52
  width,
53
  height,
54
  num_inference_steps,
 
55
  progress=gr.Progress(track_tqdm=True),
56
+ model='2k',
57
  ):
58
  print("Using model:", model)
59
+ if model == "2k":
60
+ pipe.vae.enable_tiling(True)
61
+ pipe.set_adapters("2k")
62
+ elif model == "4k":
63
+ pipe.vae.enable_tiling(True)
64
+ pipe.set_adapters(f"4k_{flux_model}")
 
65
 
66
  if randomize_seed:
67
  seed = random.randint(0, MAX_SEED)
 
125
 
126
  gr.Markdown("### Setting:")
127
 
128
+ # model = gr.Radio(
129
+ # label="Model",
130
+ # choices=[
131
+ # ("2K model", "2k"),
132
+ # ("4K model (beta)", "4k"),
133
+ # ],
134
+ # value="2k",
135
+ # )
136
 
137
  with gr.Row():
138
  width = gr.Slider(
 
166
  minimum=1,
167
  maximum=50,
168
  step=1,
169
+ value=20, # Replace with defaults that work for your model
170
  )
171
 
172
  with gr.Column(elem_id="col2"):
 
177
  fn=infer,
178
  inputs=[
179
  prompt,
180
+ # model,
181
  seed,
182
  randomize_seed,
183
  width,
184
  height,
185
  num_inference_steps,
 
186
  ],
187
  outputs=[result, seed],
188
  )