prithivMLmods commited on
Commit
acd3e96
·
verified ·
1 Parent(s): 0e6c9cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -35
app.py CHANGED
@@ -8,10 +8,9 @@ from PIL import Image
8
  import spaces
9
  import torch
10
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
11
 
12
  DESCRIPTIONx = """
13
-
14
-
15
  """
16
 
17
  css = '''
@@ -22,17 +21,10 @@ footer {
22
  }
23
  '''
24
 
25
- #examples = [
26
- # "3d image, cute girl, in the style of Pixar --ar 1:2 --stylize 750, 4K resolution highlights, Sharp focus, octane render, ray tracing, Ultra-High-Definition, 8k, UHD, HDR, (Masterpiece:1.5), (best quality:1.5)",
27
- # "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic oil --ar 2:3 --q 2 --s 750 --v 5 --ar 2:3 --q 2 --s 750 --v 5",
28
- # "Illustration of A starry night camp in the mountains. Low-angle view, Minimal background, Geometric shapes theme, Pottery, Split-complementary colors, Bicolored light, UHD",
29
- # "Man in brown leather jacket posing for camera, in the style of sleek and stylized, clockpunk, subtle shades, exacting precision, ferrania p30 --ar 67:101 --v 5",
30
- # "Commercial photography, giant burger, white lighting, studio light, 8k octane rendering, high resolution photography, insanely detailed, fine details, on white isolated plain, 8k, commercial photography, stock photo, professional color grading, --v 4 --ar 9:16 "
31
- #]
32
-
33
  MODEL_OPTIONS = {
34
  "Lightning": "SG161222/RealVisXL_V4.0_Lightning",
35
  "Realvision": "SG161222/RealVisXL_V4.0",
 
36
  }
37
 
38
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -43,23 +35,29 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
43
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
 
45
  def load_and_prepare_model(model_id):
46
- pipe = StableDiffusionXLPipeline.from_pretrained(
47
- model_id,
48
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
- use_safetensors=True,
50
- add_watermarker=False,
51
- ).to(device)
52
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
53
-
 
 
 
 
 
 
54
  if USE_TORCH_COMPILE:
55
  pipe.compile()
56
-
57
  if ENABLE_CPU_OFFLOAD:
58
  pipe.enable_model_cpu_offload()
59
-
60
  return pipe
61
 
62
- # Preload and compile both models
63
  models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
64
 
65
  MAX_SEED = np.iinfo(np.int32).max
@@ -92,7 +90,7 @@ def generate(
92
  ):
93
  global models
94
  pipe = models[model_choice]
95
-
96
  seed = int(randomize_seed_fn(seed, randomize_seed))
97
  generator = torch.Generator(device=device).manual_seed(seed)
98
 
@@ -139,7 +137,7 @@ def load_predefined_images():
139
  return predefined_images
140
 
141
  with gr.Blocks(css=css) as demo:
142
- gr.Markdown(DESCRIPTIONx)
143
  with gr.Row():
144
  prompt = gr.Text(
145
  label="Prompt",
@@ -150,7 +148,7 @@ with gr.Blocks(css=css) as demo:
150
  container=False,
151
  )
152
  run_button = gr.Button("Run⚡", scale=0)
153
- result = gr.Gallery(label="Result", columns=1, show_label=False)
154
 
155
  with gr.Row():
156
  model_choice = gr.Dropdown(
@@ -217,19 +215,13 @@ with gr.Blocks(css=css) as demo:
217
  value=20,
218
  )
219
 
220
- # gr.Examples(
221
- # examples=examples,
222
- # inputs=prompt,
223
- # cache_examples=False
224
- #)
225
-
226
  use_negative_prompt.change(
227
  fn=lambda x: gr.update(visible=x),
228
  inputs=use_negative_prompt,
229
  outputs=negative_prompt,
230
  api_name=False,
231
  )
232
-
233
  gr.on(
234
  triggers=[
235
  prompt.submit,
@@ -254,9 +246,5 @@ with gr.Blocks(css=css) as demo:
254
  api_name="run",
255
  )
256
 
257
- # with gr.Column(scale=3):
258
- # gr.Markdown("### Image Gallery")
259
- # predefined_gallery = gr.Gallery(label="Image Gallery", columns=4, show_label=False, value=load_predefined_images())
260
-
261
  if __name__ == "__main__":
262
  demo.queue(max_size=40).launch(show_api=False)
 
8
  import spaces
9
  import torch
10
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
11
+ from fal import AuraFlowPipeline
12
 
13
  DESCRIPTIONx = """
 
 
14
  """
15
 
16
  css = '''
 
21
  }
22
  '''
23
 
 
 
 
 
 
 
 
 
24
  MODEL_OPTIONS = {
25
  "Lightning": "SG161222/RealVisXL_V4.0_Lightning",
26
  "Realvision": "SG161222/RealVisXL_V4.0",
27
+ "AuraFlow": "fal/AuraFlow",
28
  }
29
 
30
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
 
37
  def load_and_prepare_model(model_id):
38
+ if model_id == "fal/AuraFlow":
39
+ pipe = AuraFlowPipeline.from_pretrained(
40
+ model_id,
41
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
42
+ ).to(device)
43
+ else:
44
+ pipe = StableDiffusionXLPipeline.from_pretrained(
45
+ model_id,
46
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
47
+ use_safetensors=True,
48
+ add_watermarker=False,
49
+ ).to(device)
50
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
51
+
52
  if USE_TORCH_COMPILE:
53
  pipe.compile()
54
+
55
  if ENABLE_CPU_OFFLOAD:
56
  pipe.enable_model_cpu_offload()
57
+
58
  return pipe
59
 
60
+ # Preload and compile all models
61
  models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
62
 
63
  MAX_SEED = np.iinfo(np.int32).max
 
90
  ):
91
  global models
92
  pipe = models[model_choice]
93
+
94
  seed = int(randomize_seed_fn(seed, randomize_seed))
95
  generator = torch.Generator(device=device).manual_seed(seed)
96
 
 
137
  return predefined_images
138
 
139
  with gr.Blocks(css=css) as demo:
140
+ gr.Markdown(DESCRIPTIONx)
141
  with gr.Row():
142
  prompt = gr.Text(
143
  label="Prompt",
 
148
  container=False,
149
  )
150
  run_button = gr.Button("Run⚡", scale=0)
151
+ result = gr.Gallery(label="Result", columns=1, show_label=False)
152
 
153
  with gr.Row():
154
  model_choice = gr.Dropdown(
 
215
  value=20,
216
  )
217
 
 
 
 
 
 
 
218
  use_negative_prompt.change(
219
  fn=lambda x: gr.update(visible=x),
220
  inputs=use_negative_prompt,
221
  outputs=negative_prompt,
222
  api_name=False,
223
  )
224
+
225
  gr.on(
226
  triggers=[
227
  prompt.submit,
 
246
  api_name="run",
247
  )
248
 
 
 
 
 
249
  if __name__ == "__main__":
250
  demo.queue(max_size=40).launch(show_api=False)