nftnik committed
Commit fd40ffd · verified · 1 Parent(s): f85c1ef

Upload 2 files

Files changed (2)
  1. app.py +269 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,269 @@
+import os
+import random
+import torch
+from pathlib import Path
+from PIL import Image
+import gradio as gr
+from nodes import NODE_CLASS_MAPPINGS
+import folder_paths
+
+# Configure base and output directories
+BASE_DIR = os.path.dirname(os.path.realpath(__file__))
+output_dir = os.path.join(BASE_DIR, "output")
+os.makedirs(output_dir, exist_ok=True)
+folder_paths.set_output_directory(output_dir)
+
+def import_custom_nodes():
+    """Loads custom nodes required for the workflow."""
+    import asyncio
+    import execution
+    from nodes import init_extra_nodes
+    import server
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    # A headless PromptServer (and its queue) must exist before
+    # init_extra_nodes() runs, since custom nodes may register against it.
+    server_instance = server.PromptServer(loop)
+    execution.PromptQueue(server_instance)
+    init_extra_nodes()
+
+def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
+    """
+    Main function to execute the workflow and generate an image.
+    """
+    import_custom_nodes()
+
+    try:
+        with torch.inference_mode():
+            # Load CLIP
+            dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
+            dualcliploader_loaded = dualcliploader.load_clip(
+                clip_name1="t5xxl_fp16.safetensors",
+                clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
+                type="flux",
+                device="default"
+            )
+
+            # Text Encoding
+            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+            encoded_text = cliptextencode.encode(
+                text=prompt,
+                clip=dualcliploader_loaded[0]
+            )
+
+            # Load Style Model
+            stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
+            style_model = stylemodelloader.load_style_model(
+                style_model_name="flux1-redux-dev.safetensors"
+            )
+
+            # Load CLIP Vision
+            clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+            clip_vision = clipvisionloader.load_clip(
+                clip_name="sigclip_vision_patch14_384.safetensors"
+            )
+
+            # Load Input Image
+            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
+            loaded_image = loadimage.load_image(image=input_image)
+
+            # Load VAE
+            vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+            vae = vaeloader.load_vae(vae_name="ae.safetensors")
+
+            # Load UNET
+            unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+            unet = unetloader.load_unet(
+                unet_name="flux1-dev.sft",
+                weight_dtype="fp8_e4m3fn"
+            )
+
+            # Load LoRA
+            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+            lora_model = loraloadermodelonly.load_lora_model_only(
+                lora_name="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
+                strength_model=lora_weight,
+                model=unet[0]
+            )
+
+            # Flux Guidance
+            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+            flux_guidance = fluxguidance.append(
+                guidance=guidance,
+                conditioning=encoded_text[0]
+            )
+
+            # Redux Advanced
+            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
+            redux_result = reduxadvanced.apply_stylemodel(
+                downsampling_factor=downsampling_factor,
+                downsampling_function="area",
+                mode="keep aspect ratio",
+                weight=weight,
+                autocrop_margin=0.1,
+                conditioning=flux_guidance[0],
+                style_model=style_model[0],
+                clip_vision=clip_vision[0],
+                image=loaded_image[0]
+            )
+
+            # Empty Latent Image
+            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
+            empty_latent = emptylatentimage.generate(
+                width=width,
+                height=height,
+                batch_size=batch_size
+            )
+
+            # KSampler (cfg=1 effectively disables classifier-free guidance;
+            # FLUX.1-dev is steered by the distilled guidance value applied
+            # through the FluxGuidance node above)
+            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
+            sampled = ksampler.sample(
+                seed=seed,
+                steps=steps,
+                cfg=1,
+                sampler_name="euler",
+                scheduler="simple",
+                denoise=1,
+                model=lora_model[0],
+                positive=redux_result[0],
+                negative=flux_guidance[0],
+                latent_image=empty_latent[0]
+            )
+
+            # VAE Decode
+            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
+            decoded = vaedecode.decode(
+                samples=sampled[0],
+                vae=vae[0]
+            )
+
+            # Save the image in the output directory
+            saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
+            temp_filename = f"Flux_{random.randint(0, 99999)}"
+            saveimage.save_images(
+                filename_prefix=temp_filename,
+                images=decoded[0]
+            )
+
+            # Short delay to ensure the file system has flushed the write
+            import time
+            time.sleep(0.5)
+
+            # Retrieve the actual file name (SaveImage appends a counter suffix
+            # to the prefix, so the exact name is not known in advance)
+            saved_files = [f for f in os.listdir(output_dir) if f.startswith(temp_filename)]
+            if not saved_files:
+                raise FileNotFoundError(f"Output file not found: expected files starting with {temp_filename}")
+
+            # Get the full path of the saved file
+            temp_path = os.path.join(output_dir, saved_files[0])
+            print(f"Image saved at: {temp_path}")
+
+            # Return the saved image for Gradio display
+            output_image = Image.open(temp_path)
+            return output_image
+
+    except Exception as e:
+        print(f"Error during generation: {str(e)}")
+        return None
+
+# Gradio Interface
+with gr.Blocks() as app:
+    gr.Markdown("# FLUX Redux Image Generator")
+
+    with gr.Row():
+        with gr.Column():
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="Enter your prompt here...",
+                lines=5
+            )
+            input_image = gr.Image(
+                label="Input Image",
+                type="filepath"
+            )
+
+            with gr.Row():
+                with gr.Column():
+                    lora_weight = gr.Slider(
+                        minimum=0,
+                        maximum=2,
+                        step=0.1,
+                        value=0.6,
+                        label="LoRA Weight"
+                    )
+                    guidance = gr.Slider(
+                        minimum=0,
+                        maximum=20,
+                        step=0.1,
+                        value=3.5,
+                        label="Guidance"
+                    )
+                    downsampling_factor = gr.Slider(
+                        minimum=1,
+                        maximum=8,
+                        step=1,
+                        value=3,
+                        label="Downsampling Factor"
+                    )
+                    weight = gr.Slider(
+                        minimum=0,
+                        maximum=2,
+                        step=0.1,
+                        value=1.0,
+                        label="Model Weight"
+                    )
+                with gr.Column():
+                    seed = gr.Number(
+                        value=random.randint(1, 2**64),
+                        label="Seed",
+                        precision=0
+                    )
+                    width = gr.Number(
+                        value=1024,
+                        label="Width",
+                        precision=0
+                    )
+                    height = gr.Number(
+                        value=1024,
+                        label="Height",
+                        precision=0
+                    )
+                    batch_size = gr.Number(
+                        value=1,
+                        label="Batch Size",
+                        precision=0
+                    )
+                    steps = gr.Number(
+                        value=20,
+                        label="Steps",
+                        precision=0
+                    )
+
+            generate_btn = gr.Button("Generate Image")
+
+        with gr.Column():
+            output_image = gr.Image(label="Generated Image", type="pil")
+
+    generate_btn.click(
+        fn=generate_image,
+        inputs=[
+            prompt_input,
+            input_image,
+            lora_weight,
+            guidance,
+            downsampling_factor,
+            weight,
+            seed,
+            width,
+            height,
+            batch_size,
+            steps
+        ],
+        outputs=[output_image]
+    )
+
+if __name__ == "__main__":
+    app.launch()
+
+# Run with: python app.py
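An aside on the retrieval step above, not part of the commit: ComfyUI's SaveImage node returns the filenames it wrote, so the sleep-and-directory-scan could be replaced by reading the node's result dict. A minimal sketch, assuming SaveImage's standard {"ui": {"images": [...]}} return shape:

    result = saveimage.save_images(filename_prefix=temp_filename, images=decoded[0])
    # SaveImage reports each file it wrote, e.g.
    # {"ui": {"images": [{"filename": "Flux_12345_00001_.png", "subfolder": "", "type": "output"}]}}
    info = result["ui"]["images"][0]
    temp_path = os.path.join(output_dir, info["subfolder"], info["filename"])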
requirements.txt ADDED
File without changes
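Since app.py only launches Gradio under the __main__ guard, generate_image can also be driven headlessly for a quick smoke test. A minimal sketch, hypothetical and not part of this commit; the reference-image path is an assumption, and the remaining argument values mirror the UI defaults:

    # smoke_test.py — hypothetical helper, not part of this commit
    from app import generate_image

    image = generate_image(
        prompt="a neon-lit city street at night",
        input_image="reference.png",  # assumed path to a local reference image
        lora_weight=0.6,
        guidance=3.5,
        downsampling_factor=3,
        weight=1.0,
        seed=42,
        width=1024,
        height=1024,
        batch_size=1,
        steps=20,
    )
    if image is not None:
        image.save("smoke_test_output.png")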