ginipick committed · verified
Commit 815fecf · Parent: b365a93

Update app.py

Files changed (1):
  1. app.py  +146 -10
app.py CHANGED
@@ -1,3 +1,5 @@
+
+
 import os
 import uuid
 import gradio as gr
@@ -36,7 +38,125 @@ english_labels = {
     "Seed": "Seed"
 }
 
-# [Rest of the imports and pipeline setup remains the same...]
+# load pipelines
+base_model = "black-forest-labs/FLUX.1-schnell"
+
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
+pipe = FluxPipeline.from_pretrained(base_model,
+                                    vae=taef1,
+                                    torch_dtype=torch.bfloat16)
+
+pipe.transformer.to(memory_format=torch.channels_last)
+clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
+
+MAX_SEED = 2**32-1
+
+def save_images_with_unique_filenames(image_list, save_directory):
+    if not os.path.exists(save_directory):
+        os.makedirs(save_directory)
+
+    paths = []
+    for image in image_list:
+        unique_filename = f"{uuid.uuid4()}.png"
+        file_path = os.path.join(save_directory, unique_filename)
+
+        image.save(file_path)
+        paths.append(file_path)
+
+    return paths
+
+def convert_to_centered_scale(num):
+    if num % 2 == 0:  # even
+        start = -(num // 2 - 1)
+        end = num // 2
+    else:  # odd
+        start = -(num // 2)
+        end = num // 2
+    return tuple(range(start, end + 1))
+
+def translate_if_korean(text):
+    if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
+        return translator(text)[0]['translation_text']
+    return text
+
+@spaces.GPU(duration=85)
+def generate(prompt,
+             concept_1,
+             concept_2,
+             scale,
+             randomize_seed=True,
+             seed=42,
+             recalc_directions=True,
+             iterations=200,
+             steps=3,
+             interm_steps=33,
+             guidance_scale=3.5,
+             x_concept_1="", x_concept_2="",
+             avg_diff_x=None,
+             total_images=[],
+             progress=gr.Progress()
+             ):
+    # Translate prompt and concepts if Korean
+    prompt = translate_if_korean(prompt)
+    concept_1 = translate_if_korean(concept_1)
+    concept_2 = translate_if_korean(concept_2)
+
+    print(f"Prompt: {prompt}, ← {concept_2}, {concept_1} ➡️ . scale {scale}, interm steps {interm_steps}")
+    slider_x = [concept_2, concept_1]
+    # check if avg diff for directions need to be re-calculated
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
+        progress(0, desc="Calculating directions...")
+        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
+        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
+
+    images = []
+    high_scale = scale
+    low_scale = -1 * scale
+    for i in progress.tqdm(range(interm_steps), desc="Generating images"):
+        cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
+        image = clip_slider.generate(prompt,
+                                     width=768,
+                                     height=768,
+                                     guidance_scale=guidance_scale,
+                                     scale=cur_scale, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
+        images.append(image)
+    canvas = Image.new('RGB', (256*interm_steps, 256))
+    for i, im in enumerate(images):
+        canvas.paste(im.resize((256,256)), (256 * i, 0))
+
+    comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
+
+    scale_total = convert_to_centered_scale(interm_steps)
+    scale_min = scale_total[0]
+    scale_max = scale_total[-1]
+    scale_middle = scale_total.index(0)
+    post_generation_slider_update = gr.update(label=comma_concepts_x, value=0, minimum=scale_min, maximum=scale_max, interactive=True)
+    avg_diff_x = avg_diff.cpu()
+
+    video_path = f"{uuid.uuid4()}.mp4"
+    print(video_path)
+    return x_concept_1, x_concept_2, avg_diff_x, export_to_video(images, video_path, fps=5), canvas, images, images[scale_middle], post_generation_slider_update, seed
+
+def update_pre_generated_images(slider_value, total_images):
+    number_images = len(total_images)
+    if number_images > 0:
+        scale_tuple = convert_to_centered_scale(number_images)
+        return total_images[scale_tuple.index(slider_value)][0]
+    else:
+        return None
+
+def reset_recalc_directions():
+    return True
+
+# Updated examples with English text
+examples = [
+    ["flower in mountain", "spring", "winter", 1.5],
+    ["man", "baby", "elderly", 2.5],
+    ["a tomato", "super fresh", "rotten", 2.5]
+]
 
 css = """
 footer {
@@ -127,13 +247,6 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
                           interactive=True,
                           randomize=True)
 
-    # Updated examples with English text
-    examples = [
-        ["flower in mountain", "spring", "winter", 1.5],
-        ["man", "baby", "elderly", 2.5],
-        ["a tomato", "super fresh", "rotten", 2.5]
-    ]
-
     examples_gradio = gr.Examples(
         examples=examples,
         inputs=[prompt, concept_1, concept_2, x],
@@ -143,7 +256,30 @@
         cache_examples="lazy"
     )
 
-    # [Rest of the event handlers remain the same...]
-
+    submit.click(
+        fn=generate,
+        inputs=[prompt, concept_1, concept_2, x, randomize_seed, seed, recalc_directions,
+                iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2,
+                avg_diff_x, total_images],
+        outputs=[x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images,
+                 post_generation_image, post_generation_slider, seed]
+    )
+    iterations.change(
+        fn=reset_recalc_directions,
+        outputs=[recalc_directions]
+    )
+    seed.change(
+        fn=reset_recalc_directions,
+        outputs=[recalc_directions]
+    )
+    post_generation_slider.change(
+        fn=update_pre_generated_images,
+        inputs=[post_generation_slider, total_images],
+        outputs=[post_generation_image],
+        queue=False,
+        show_progress="hidden",
+        concurrency_limit=None
+    )
+
 if __name__ == "__main__":
     demo.launch()
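
A note on the helper convert_to_centered_scale added above: it returns a zero-centered integer range with num entries, which is what lets the post-generation slider address individual frames and lets generate pick the middle frame via scale_total.index(0). A quick sanity check, a sketch assuming the function above is in scope:

    # convert_to_centered_scale returns num consecutive integers centered on 0;
    # for even num the extra step falls on the positive side.
    assert convert_to_centered_scale(5) == (-2, -1, 0, 1, 2)
    assert convert_to_centered_scale(6) == (-2, -1, 0, 1, 2, 3)
    assert convert_to_centered_scale(33).index(0) == 16  # middle frame for interm_steps=33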
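
translate_if_korean gates on two Unicode ranges: Hangul compatibility jamo (U+3131 to U+3163) and precomposed syllables (U+AC00 to U+D7A3); text with no Hangul passes through unchanged. The translator it calls is defined in an unchanged part of the file that this diff does not show; a plausible setup would be a transformers translation pipeline (the model choice below is illustrative, not taken from the diff):

    # Hypothetical definition of the 'translator' used by translate_if_korean.
    from transformers import pipeline
    translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

    translate_if_korean("산속의 꽃")  # contains Hangul ("flower in the mountains") -> returns the English translation
    translate_if_korean("flower")    # no Hangul -> returned as-is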
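
One caveat in generate: avg_diff is only assigned inside the recalculation branch. As wired in the handlers above, iterations.change and seed.change reset recalc_directions to True, so the branch is effectively always taken; if it were ever skipped, the later clip_slider.generate(..., avg_diff=avg_diff) call would raise a NameError. A defensive fallback, assuming the cached avg_diff_x state tensor is meant to be reused, could look like:

    if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
        progress(0, desc="Calculating directions...")
        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
    else:
        # Hypothetical fallback: reuse the direction cached in Gradio state, moved back to GPU.
        avg_diff = avg_diff_x.cuda()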
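
The per-frame scale in the generation loop is a plain linear sweep from -scale to +scale with both endpoints included, so for odd interm_steps the middle frame sits exactly at scale 0. An equivalent standalone sketch (assumes interm_steps >= 2 to avoid dividing by zero):

    def scale_schedule(scale, interm_steps):
        # Linearly spaced scales from -scale to +scale, endpoints included.
        low, high = -scale, scale
        return [low + (high - low) * i / (interm_steps - 1) for i in range(interm_steps)]

    scale_schedule(1.5, 5)  # -> [-1.5, -0.75, 0.0, 0.75, 1.5]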