r3gm committed on
Commit 66ab4b2 · verified · 1 Parent(s): f520a86

Update app.py

Files changed (1):
  app.py +142 -56
app.py CHANGED
@@ -1,47 +1,126 @@
 import gradio as gr
+import spaces
 import numpy as np
 import random
 from diffusers import DiffusionPipeline
 import torch
+import threading
+from PIL import Image
 
+MODEL_ID = "cagliostrolab/animagine-xl-3.1"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 if torch.cuda.is_available():
     torch.cuda.max_memory_allocated(device=device)
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-    pipe.enable_xformers_memory_efficient_attention()
-    pipe = pipe.to(device)
-else:
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-    pipe = pipe.to(device)
+    pipe = DiffusionPipeline.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+    )
+else:
+    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, use_safetensors=True)
+pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+MAX_IMAGE_SIZE = 1024 + 512
 
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
-
-    return image
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
+def latents_to_rgb(latents):
+    weights = (
+        (60, -60, 25, -70),
+        (60, -5, 15, -50),
+        (60, 10, -5, -35)
+    )
+
+    weights_tensor = torch.tensor(weights, dtype=latents.dtype, device=latents.device).T
+    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype, device=latents.device)
+    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.view(-1, 1, 1)
+    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
+    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions
+
+    pil_image = Image.fromarray(image_array)
+
+    resized_image = pil_image.resize((pil_image.size[0] * 2, pil_image.size[1] * 2), Image.LANCZOS)  # Resize 128x128 * ...
+    return resized_image
+
+class BaseGenerator:
+    def __init__(self, pipe):
+        self.pipe = pipe
+        self.image = None
+        self.new_image_event = threading.Event()
+        self.generation_finished = threading.Event()
+        self.intermediate_image_concurrency(5)
+
+    def intermediate_image_concurrency(self, concurrency):
+        self.concurrency = concurrency
+
+    def decode_tensors(self, pipe, step, timestep, callback_kwargs):
+        latents = callback_kwargs["latents"]
+        if step % self.concurrency == 0:  # every how many steps
+            print(step)
+            self.image = latents_to_rgb(latents)
+            self.new_image_event.set()  # Signal that a new image is available
+        return callback_kwargs
+
+    def show_images(self):
+        while not self.generation_finished.is_set() or self.new_image_event.is_set():
+            self.new_image_event.wait()  # Wait for a new image
+            self.new_image_event.clear()  # Clear the event flag
+
+            if self.image:
+                yield self.image  # Yield the new image
+
+    def generate_images(self, **kwargs):
+        if kwargs.get('randomize_seed', False):
+            kwargs['seed'] = random.randint(0, MAX_SEED)
+
+        generator = torch.Generator().manual_seed(kwargs['seed'])
+
+        self.image = None
+        self.image = self.pipe(
+            height=kwargs['height'],
+            width=kwargs['width'],
+            prompt=kwargs['prompt'],
+            negative_prompt=kwargs['negative_prompt'],
+            guidance_scale=kwargs['guidance_scale'],
+            num_inference_steps=kwargs['num_inference_steps'],
+            generator=generator,
+            callback_on_step_end=self.decode_tensors,
+            callback_on_step_end_tensor_inputs=["latents"],
+        ).images[0]
+        print("finish")
+        self.new_image_event.set()  # Result image
+        self.generation_finished.set()  # Signal that generation is finished
+
+    def stream(self, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+        self.generation_finished.clear()
+        threading.Thread(target=self.generate_images, args=(), kwargs=dict(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            seed=seed,
+            randomize_seed=randomize_seed,
+            width=width,
+            height=height,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps
+        )).start()
+        return self.show_images()
+
+image_generator = BaseGenerator(pipe)
+
+@spaces.GPU
+def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, concurrency):
+
+    image_generator.intermediate_image_concurrency(concurrency)
+
+    stream = image_generator.stream(
+        prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
+    )
+
+    yield None
+
+    for image in stream:
+        yield image
 
 css="""
 #col-container {
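
The heart of this commit is latents_to_rgb: instead of running the full VAE decoder on every scheduler step, the 4-channel SDXL latent is projected to RGB with a fixed 3x4 weight matrix plus per-channel biases, which is cheap enough to call from a step callback. Below is a minimal self-contained sketch of the same projection; preview_from_latents is an illustrative name and the random tensor stands in for real pipeline latents:

import torch
from PIL import Image

def preview_from_latents(latents: torch.Tensor) -> Image.Image:
    # One weight row per RGB channel, one column per latent channel;
    # the values match the commit and are a rough fit, not the true VAE.
    weights = torch.tensor(
        [(60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35)],
        dtype=latents.dtype, device=latents.device,
    )
    biases = torch.tensor((150, 140, 130), dtype=latents.dtype, device=latents.device)
    rgb = torch.einsum("...lxy,rl->...rxy", latents, weights) + biases.view(-1, 1, 1)
    array = rgb.clamp(0, 255)[0].byte().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    return Image.fromarray(array)

latents = torch.randn(1, 4, 128, 128)  # dummy batch: 4 latent channels, 128x128
preview_from_latents(latents).resize((256, 256), Image.LANCZOS).save("preview.png")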
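decode_tensors plugs into the diffusers callback_on_step_end hook: after each scheduler step the pipeline calls it with the pipeline object, the step index, the timestep, and whichever tensors were requested through callback_on_step_end_tensor_inputs, and it must return the callback_kwargs dict. A minimal sketch that only logs progress, assuming a diffusers release recent enough to support this callback (the model and settings mirror the old version of this app, not the new one):

import torch
from diffusers import DiffusionPipeline

def log_step(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]  # present because it is requested below
    print(f"step {step}, t={timestep}, latents {tuple(latents.shape)}")
    return callback_kwargs  # the hook must return the kwargs dict

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

image = pipe(
    "a photo of a cat",
    num_inference_steps=4,
    guidance_scale=0.0,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]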
@@ -59,33 +138,36 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
         gr.Markdown(f"""
-        # Text-to-Image Gradio Template
+        # Text-to-Image: Display each generation step
+        Gradio template for displaying preview images during generation steps
        Currently running on {power_device}.
        """)
+
+        prompt = gr.Text(
+            label="Prompt",
+            show_label=False,
+            max_lines=1,
+            placeholder="Enter your prompt",
+            container=False,
+            value="1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer, outdoors, night",
+        )
+
+        negative_prompt = gr.Text(
+            label="Negative prompt",
+            max_lines=1,
+            placeholder="Enter a negative prompt",
+            visible=True,
+            value="nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
+        )
 
         with gr.Row():
 
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
             run_button = gr.Button("Run", scale=0)
 
         result = gr.Image(label="Result", show_label=False)
 
         with gr.Accordion("Advanced Settings", open=False):
 
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -103,7 +185,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=512,
+                    value=832,
                 )
 
                 height = gr.Slider(
@@ -111,7 +193,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=512,
+                    value=1216,
                 )
 
             with gr.Row():
@@ -119,28 +201,32 @@ with gr.Blocks(css=css) as demo:
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
                     minimum=0.0,
-                    maximum=10.0,
+                    maximum=30.0,
                     step=0.1,
-                    value=0.0,
+                    value=7.0,
                 )
 
                 num_inference_steps = gr.Slider(
                     label="Number of inference steps",
                     minimum=1,
-                    maximum=12,
+                    maximum=100,
                     step=1,
-                    value=2,
+                    value=76,
                 )
-
-        gr.Examples(
-            examples = examples,
-            inputs = [prompt]
-        )
+
+                concurrency_gui = gr.Slider(
+                    label="Number of steps to show the next preview image",
+                    minimum=1,
+                    maximum=20,
+                    step=1,
+                    value=3,
+                )
 
     run_button.click(
         fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result]
+        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, concurrency_gui],
+        outputs = [result],
+        show_progress="minimal",
     )
 
 demo.queue().launch()
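
The hand-off between threads is the subtlest part of the change: generate_images runs the pipeline in a background thread and sets an event whenever a fresh preview exists, while show_images is a generator that waits on that event and yields each image to Gradio. Here is the same threading.Event pattern in isolation, with illustrative names; note it shares the app's benign race, where a very fast producer can overwrite an item before the consumer wakes:

import threading
import time

class Streamer:
    def __init__(self):
        self.item = None
        self.new_item = threading.Event()
        self.finished = threading.Event()

    def produce(self, n):
        for step in range(n):
            time.sleep(0.2)          # stands in for one denoising step
            self.item = f"step {step}"
            self.new_item.set()      # signal the consumer
        self.finished.set()
        self.new_item.set()          # wake the consumer one last time

    def consume(self):
        while not self.finished.is_set() or self.new_item.is_set():
            self.new_item.wait()     # sleep until the producer signals
            self.new_item.clear()
            if self.item is not None:
                yield self.item

s = Streamer()
threading.Thread(target=s.produce, args=(5,)).start()
for update in s.consume():
    print(update)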
 
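Finally, infer is written as a generator because, with demo.queue() enabled, Gradio streams every value it yields to the output component as it arrives: first a None, then each preview, then the finished image. A tiny standalone demo of that streaming behavior, hypothetical and not part of this Space:

import time
import gradio as gr

def slow_count(n):
    # Each yield immediately refreshes the Textbox in the browser.
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"update {i}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    out = gr.Textbox(label="Stream")
    gr.Button("Run").click(fn=slow_count, inputs=[steps], outputs=[out])

demo.queue().launch()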