LPX committed on
Commit 15c3cb0 · 1 Parent(s): 0749e05

temp: attempt to upload half-precision .safetensors file

Files changed (2)
  1. README.md +1 -1
  2. float16.py +330 -0
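In outline, float16.py (added below) reloads the Kontext Lightning pipeline in half precision, serializes it as .safetensors, and pushes the result to the Hub. A minimal sketch of those same steps, using the repo ID and local path from this commit (the Gradio demo portion of the file is omitted here):

import torch
from diffusers import FluxKontextPipeline
from huggingface_hub import upload_folder

# Load the published Lightning pipeline in float16 (half precision).
pipe = FluxKontextPipeline.from_pretrained(
    "LPX55/FLUX.1_Kontext-Lightning", torch_dtype=torch.float16
)

# Serialize every component as .safetensors; a large max_shard_size keeps
# each component's weights in a single file rather than split shards.
pipe.save_pretrained("./flux_16bit", safe_serialization=True, max_shard_size="100GB")

# Upload the exported folder into a "float16" path of the same model repo.
upload_folder(
    folder_path="./flux_16bit",
    path_in_repo="float16",
    repo_id="LPX55/FLUX.1_Kontext-Lightning",
    repo_type="model",
    commit_message="Upload half-precision FLUX.1 Kontext Lightning model",
)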
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
  colorTo: yellow
  sdk: gradio
  sdk_version: 5.35.0
- app_file: app_kontext.py
+ app_file: float16.py
  pinned: true
  short_description: Inspired by our 8-Step FLUX Merged/Fusion Models
  ---
float16.py ADDED
@@ -0,0 +1,330 @@
import gradio as gr
import numpy as np
import spaces
import torch
import random
import os
import subprocess
import logging
import safetensors
#####################################################
# Forced Diffusers upgrade when cache was being stubborn; probably not needed now
# force = subprocess.run("pip install -U diffusers", shell=True)
# force = subprocess.run("pip install git+https://github.com/huggingface/diffusers.git", shell=True)
# force = subprocess.run("pip install git+https://github.com/huggingface/transformers.git", shell=True)
force = subprocess.run("git lfs install", shell=True)

#####################################################
import transformers
import diffusers
from diffusers import DiffusionPipeline
import bitsandbytes
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import load_image
from diffusers import FluxKontextPipeline
from PIL import Image
from huggingface_hub import hf_hub_download
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils._runtime import dump_environment_info
from safetensors import safe_open

#####################################################

MAX_SEED = np.iinfo(np.int32).max
API_TOKEN = os.environ['HF_TOKEN']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

dump_environment_info()
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

#####################################################

# TESTING TWO QUANTIZATION METHODS
# 1) If FP8 is supported: `torchao` for quantization
# quant_config = PipelineQuantizationConfig(
#     quant_backend="torchao",
#     quant_kwargs={"quant_type": "float8dq_e4m3_row"},
#     components_to_quantize=["transformer"]
# )
# 2) Otherwise, standard 4-bit quantization with bitsandbytes
# quant_config = PipelineQuantizationConfig(
#     quant_backend="bitsandbytes_4bit",
#     quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_quant_type": "nf4"},
#     components_to_quantize=["transformer"]
# )

try:
    # Set max memory usage for ZeroGPU
    torch.cuda.set_per_process_memory_fraction(1.0)
    torch.set_float32_matmul_precision("high")
except Exception as e:
    print(f"Error setting memory usage: {e}")

#####################################################
# Load the pipeline; the quantization configs above are currently unused.
# float16 is used as the dtype for this half-precision export.
# HF Spaces VRAM (50 GB) is sufficient to hold the entire pipeline (31.424 GB),
# so leave the entire pipeline on the GPU for the best performance.

# FLUX.1 Dev Kontext Lightning model (8 steps)
kontext_model = "LPX55/FLUX.1_Kontext-Lightning"
pipe = FluxKontextPipeline.from_pretrained(
    "LPX55/FLUX.1_Kontext-Lightning",
    torch_dtype=torch.float16
).to("cuda")
# Save as a single `.safetensors` file
pipe.save_pretrained(
    "./flux_16bit",
    safe_serialization=True,
    max_shard_size="100GB"  # Forces all shards into one file (no split files)
)

local_folder = "./flux_16bit"
hub_repo_name = "LPX55/FLUX.1_Kontext-Lightning"

# create_repo(hub_repo_name, exist_ok=True, private=False)

with safe_open("./flux_16bit/model.safetensors", framework="pt", device="cuda") as f:
    for k in f.keys():
        print(k, f.get_slice(k).get_shape())

upload_folder(
    folder_path=local_folder,
    path_in_repo="float16",
    repo_id=hub_repo_name,
    repo_type="model",
    commit_message="Upload half-precision FLUX.1 Kontext Lightning model",
    token=API_TOKEN
)
###################################################
# SECTION FOR LORA(S); SKIP FOR NOW

# try:
#     repo_name = ""
#     ckpt_name = ""
#     pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name), adapter_name="A1")
#     pipe.set_adapters(["A1"], adapter_weights=[0.5])
#     pipe.fuse_lora(adapter_names=["A1"], lora_scale=1.0)
#     pipe.unload_lora_weights()

# except Exception as e:
#     print(f"Error while loading LoRA: {e}")

#####################################################
def concatenate_images(images, direction="horizontal"):
    """
    Concatenate multiple PIL images either horizontally or vertically.

    Args:
        images: List of PIL Images
        direction: "horizontal" or "vertical"

    Returns:
        PIL Image: Concatenated image
    """
    if not images:
        return None

    # Filter out None images
    valid_images = [img for img in images if img is not None]

    if not valid_images:
        return None

    if len(valid_images) == 1:
        return valid_images[0].convert("RGB")

    # Convert all images to RGB
    valid_images = [img.convert("RGB") for img in valid_images]

    if direction == "horizontal":
        # Calculate total width and max height
        total_width = sum(img.width for img in valid_images)
        max_height = max(img.height for img in valid_images)

        # Create new image
        concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))

        # Paste images
        x_offset = 0
        for img in valid_images:
            # Center image vertically if heights differ
            y_offset = (max_height - img.height) // 2
            concatenated.paste(img, (x_offset, y_offset))
            x_offset += img.width

    else:  # vertical
        # Calculate max width and total height
        max_width = max(img.width for img in valid_images)
        total_height = sum(img.height for img in valid_images)

        # Create new image
        concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))

        # Paste images
        y_offset = 0
        for img in valid_images:
            # Center image horizontally if widths differ
            x_offset = (max_width - img.width) // 2
            concatenated.paste(img, (x_offset, y_offset))
            y_offset += img.height

    return concatenated

@spaces.GPU
@torch.no_grad()
def infer(input_images, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=8, width=1024, height=1024, progress=gr.Progress(track_tqdm=True)):

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Handle input_images - it could be a single image or a list of images
    if input_images is None:
        raise gr.Error("Please upload at least one image.")

    # If it's a single image (not a list), convert to list
    if not isinstance(input_images, list):
        input_images = [input_images]

    # Filter out None images; each gallery item is an (image, caption) pair
    valid_images = [img[0] for img in input_images if img is not None]

    if not valid_images:
        raise gr.Error("Please upload at least one valid image.")

    # Concatenate images horizontally
    concatenated_image = concatenate_images(valid_images, "horizontal")

    if concatenated_image is None:
        raise gr.Error("Failed to process the input images.")

    # original_width, original_height = concatenated_image.size

    # if original_width >= original_height:
    #     new_width = 1024
    #     new_height = int(original_height * (new_width / original_width))
    #     new_height = round(new_height / 64) * 64
    # else:
    #     new_height = 1024
    #     new_width = int(original_width * (new_height / original_height))
    #     new_width = round(new_width / 64) * 64

    # concatenated_image_resized = concatenated_image.resize((new_width, new_height), Image.LANCZOS)

    final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style. The final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources."

    image = pipe(
        image=concatenated_image,
        prompt=final_prompt,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        max_area=width * height,
        num_inference_steps=steps,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]

    return image, seed, gr.update(visible=True)

css = """
#col-container {
    margin: 0 auto;
    max-width: 86vw;
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext | Lightning 8-Step Model ⚡
        """)
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Upload image(s) for editing",
                    show_label=True,
                    elem_id="gallery_input",
                    columns=3,
                    rows=2,
                    object_fit="contain",
                    height="auto",
                    file_types=['image'],
                    type='pil'
                )

                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)

                with gr.Accordion("Advanced Settings", open=True):

                    with gr.Group():
                        width = gr.Slider(
                            label="W",
                            minimum=512,
                            maximum=2560,
                            step=64,
                            value=1024,
                        )

                        height = gr.Slider(
                            label="H",
                            minimum=512,
                            maximum=2560,
                            step=64,
                            value=1024,
                        )

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    input_steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        step=1,
                        value=16,
                    )

            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[input_images, prompt, seed, randomize_seed, guidance_scale, input_steps, width, height],
        outputs=[result, seed, reuse_button]
    )

    reuse_button.click(
        fn=lambda image: [image] if image is not None else [],  # Convert single image to list for gallery
        inputs=[result],
        outputs=[input_images]
    )

demo.queue().launch()