gokaygokay commited on
Commit
f064a5b
·
verified ·
1 Parent(s): 0882268

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -183
app.py CHANGED
@@ -1,183 +1,197 @@
1
- import os
2
- import random
3
- import sys
4
- from typing import Sequence, Mapping, Any, Union
5
- import torch
6
- import gradio as gr
7
- from huggingface_hub import hf_hub_download
8
-
9
- # Download required models
10
- t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp8_e4m3fn.safetensors", local_dir="models/text_encoders/")
11
- vae_path = hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae")
12
- unet_path = hf_hub_download(repo_id="lodestones/Chroma", filename="chroma-unlocked-v31.safetensors", local_dir="models/unet")
13
-
14
- # Import the workflow functions
15
- from my_workflow import (
16
- get_value_at_index,
17
- add_comfyui_directory_to_sys_path,
18
- add_extra_model_paths,
19
- import_custom_nodes,
20
- NODE_CLASS_MAPPINGS,
21
- CLIPTextEncode,
22
- CLIPLoader,
23
- VAEDecode,
24
- UNETLoader,
25
- VAELoader,
26
- SaveImage,
27
- )
28
-
29
- # Initialize ComfyUI
30
- add_comfyui_directory_to_sys_path()
31
- add_extra_model_paths()
32
- import_custom_nodes()
33
-
34
- def generate_image(prompt, negative_prompt, width, height, steps, cfg, seed):
35
- with torch.inference_mode():
36
- # Set random seed if provided
37
- if seed == -1:
38
- seed = random.randint(1, 2**64)
39
- random.seed(seed)
40
-
41
- randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
42
- randomnoise_68 = randomnoise.get_noise(noise_seed=seed)
43
-
44
- emptysd3latentimage = NODE_CLASS_MAPPINGS["EmptySD3LatentImage"]()
45
- emptysd3latentimage_69 = emptysd3latentimage.generate(
46
- width=width, height=height, batch_size=1
47
- )
48
-
49
- ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
50
- ksamplerselect_72 = ksamplerselect.get_sampler(sampler_name="euler")
51
-
52
- cliploader = CLIPLoader()
53
- cliploader_78 = cliploader.load_clip(
54
- clip_name="t5xxl_fp8_e4m3fn.safetensors", type="chroma", device="default"
55
- )
56
-
57
- t5tokenizeroptions = NODE_CLASS_MAPPINGS["T5TokenizerOptions"]()
58
- t5tokenizeroptions_82 = t5tokenizeroptions.set_options(
59
- min_padding=1, min_length=0, clip=get_value_at_index(cliploader_78, 0)
60
- )
61
-
62
- cliptextencode = CLIPTextEncode()
63
- cliptextencode_74 = cliptextencode.encode(
64
- text=prompt,
65
- clip=get_value_at_index(t5tokenizeroptions_82, 0),
66
- )
67
-
68
- cliptextencode_75 = cliptextencode.encode(
69
- text=negative_prompt,
70
- clip=get_value_at_index(t5tokenizeroptions_82, 0),
71
- )
72
-
73
- unetloader = UNETLoader()
74
- unetloader_76 = unetloader.load_unet(
75
- unet_name="chroma-unlocked-v31.safetensors", weight_dtype="fp8_e4m3fn"
76
- )
77
-
78
- vaeloader = VAELoader()
79
- vaeloader_80 = vaeloader.load_vae(vae_name="ae.safetensors")
80
-
81
- cfgguider = NODE_CLASS_MAPPINGS["CFGGuider"]()
82
- basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
83
- samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
84
- vaedecode = VAEDecode()
85
- saveimage = SaveImage()
86
-
87
- cfgguider_73 = cfgguider.get_guider(
88
- cfg=cfg,
89
- model=get_value_at_index(unetloader_76, 0),
90
- positive=get_value_at_index(cliptextencode_74, 0),
91
- negative=get_value_at_index(cliptextencode_75, 0),
92
- )
93
-
94
- basicscheduler_84 = basicscheduler.get_sigmas(
95
- scheduler="beta",
96
- steps=steps,
97
- denoise=1,
98
- model=get_value_at_index(unetloader_76, 0),
99
- )
100
-
101
- samplercustomadvanced_67 = samplercustomadvanced.sample(
102
- noise=get_value_at_index(randomnoise_68, 0),
103
- guider=get_value_at_index(cfgguider_73, 0),
104
- sampler=get_value_at_index(ksamplerselect_72, 0),
105
- sigmas=get_value_at_index(basicscheduler_84, 0),
106
- latent_image=get_value_at_index(emptysd3latentimage_69, 0),
107
- )
108
-
109
- vaedecode_79 = vaedecode.decode(
110
- samples=get_value_at_index(samplercustomadvanced_67, 0),
111
- vae=get_value_at_index(vaeloader_80, 0),
112
- )
113
-
114
- # Instead of saving to file, return the image directly
115
- return get_value_at_index(vaedecode_79, 0)
116
-
117
- # Create Gradio interface
118
- with gr.Blocks() as app:
119
- gr.Markdown("# Chroma Image Generator")
120
-
121
- with gr.Row():
122
- with gr.Column():
123
- prompt = gr.Textbox(
124
- label="Prompt",
125
- placeholder="Enter your prompt here...",
126
- lines=3
127
- )
128
- negative_prompt = gr.Textbox(
129
- label="Negative Prompt",
130
- placeholder="Enter negative prompt here...",
131
- value="low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors",
132
- lines=2
133
- )
134
-
135
- with gr.Row():
136
- width = gr.Slider(
137
- minimum=512,
138
- maximum=2048,
139
- value=1024,
140
- step=64,
141
- label="Width"
142
- )
143
- height = gr.Slider(
144
- minimum=512,
145
- maximum=2048,
146
- value=1024,
147
- step=64,
148
- label="Height"
149
- )
150
-
151
- with gr.Row():
152
- steps = gr.Slider(
153
- minimum=1,
154
- maximum=50,
155
- value=26,
156
- step=1,
157
- label="Steps"
158
- )
159
- cfg = gr.Slider(
160
- minimum=1,
161
- maximum=20,
162
- value=4,
163
- step=0.5,
164
- label="CFG Scale"
165
- )
166
- seed = gr.Number(
167
- value=-1,
168
- label="Seed (-1 for random)"
169
- )
170
-
171
- generate_btn = gr.Button("Generate")
172
-
173
- with gr.Column():
174
- output_image = gr.Image(label="Generated Image")
175
-
176
- generate_btn.click(
177
- fn=generate_image,
178
- inputs=[prompt, negative_prompt, width, height, steps, cfg, seed],
179
- outputs=[output_image]
180
- )
181
-
182
- if __name__ == "__main__":
183
- app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import sys
4
+ from typing import Sequence, Mapping, Any, Union
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+ import spaces
10
+ from comfy import model_management
11
+
12
+ # Download required models
13
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp8_e4m3fn.safetensors", local_dir="models/text_encoders/")
14
+ vae_path = hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae")
15
+ unet_path = hf_hub_download(repo_id="lodestones/Chroma", filename="chroma-unlocked-v31.safetensors", local_dir="models/unet")
16
+
17
+ # Import the workflow functions
18
+ from my_workflow import (
19
+ get_value_at_index,
20
+ add_comfyui_directory_to_sys_path,
21
+ add_extra_model_paths,
22
+ import_custom_nodes,
23
+ NODE_CLASS_MAPPINGS,
24
+ CLIPTextEncode,
25
+ CLIPLoader,
26
+ VAEDecode,
27
+ UNETLoader,
28
+ VAELoader,
29
+ SaveImage,
30
+ )
31
+
32
+ # Initialize ComfyUI
33
+ add_comfyui_directory_to_sys_path()
34
+ add_extra_model_paths()
35
+ import_custom_nodes()
36
+
37
+ # Initialize all model loaders outside the function
38
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
39
+ emptysd3latentimage = NODE_CLASS_MAPPINGS["EmptySD3LatentImage"]()
40
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
41
+ cliploader = CLIPLoader()
42
+ t5tokenizeroptions = NODE_CLASS_MAPPINGS["T5TokenizerOptions"]()
43
+ cliptextencode = CLIPTextEncode()
44
+ unetloader = UNETLoader()
45
+ vaeloader = VAELoader()
46
+ cfgguider = NODE_CLASS_MAPPINGS["CFGGuider"]()
47
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
48
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
49
+ vaedecode = VAEDecode()
50
+ saveimage = SaveImage()
51
+
52
+ # Load models
53
+ cliploader_78 = cliploader.load_clip(
54
+ clip_name="t5xxl_fp8_e4m3fn.safetensors", type="chroma", device="default"
55
+ )
56
+ t5tokenizeroptions_82 = t5tokenizeroptions.set_options(
57
+ min_padding=1, min_length=0, clip=get_value_at_index(cliploader_78, 0)
58
+ )
59
+ unetloader_76 = unetloader.load_unet(
60
+ unet_name="chroma-unlocked-v31.safetensors", weight_dtype="fp8_e4m3fn"
61
+ )
62
+ vaeloader_80 = vaeloader.load_vae(vae_name="ae.safetensors")
63
+
64
+ # Add all the models that load a safetensors file
65
+ model_loaders = [cliploader_78, unetloader_76, vaeloader_80]
66
+
67
+ # Check which models are valid and how to best load them
68
+ valid_models = [
69
+ getattr(loader[0], 'patcher', loader[0])
70
+ for loader in model_loaders
71
+ if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
72
+ ]
73
+
74
+ # Finally loads the models
75
+ model_management.load_models_gpu(valid_models)
76
+
77
+ @spaces.GPU
78
+ def generate_image(prompt, negative_prompt, width, height, steps, cfg, seed):
79
+ with torch.inference_mode():
80
+ # Set random seed if provided
81
+ if seed == -1:
82
+ seed = random.randint(1, 2**64)
83
+ random.seed(seed)
84
+
85
+ randomnoise_68 = randomnoise.get_noise(noise_seed=seed)
86
+ emptysd3latentimage_69 = emptysd3latentimage.generate(
87
+ width=width, height=height, batch_size=1
88
+ )
89
+ ksamplerselect_72 = ksamplerselect.get_sampler(sampler_name="euler")
90
+
91
+ cliptextencode_74 = cliptextencode.encode(
92
+ text=prompt,
93
+ clip=get_value_at_index(t5tokenizeroptions_82, 0),
94
+ )
95
+
96
+ cliptextencode_75 = cliptextencode.encode(
97
+ text=negative_prompt,
98
+ clip=get_value_at_index(t5tokenizeroptions_82, 0),
99
+ )
100
+
101
+ cfgguider_73 = cfgguider.get_guider(
102
+ cfg=cfg,
103
+ model=get_value_at_index(unetloader_76, 0),
104
+ positive=get_value_at_index(cliptextencode_74, 0),
105
+ negative=get_value_at_index(cliptextencode_75, 0),
106
+ )
107
+
108
+ basicscheduler_84 = basicscheduler.get_sigmas(
109
+ scheduler="beta",
110
+ steps=steps,
111
+ denoise=1,
112
+ model=get_value_at_index(unetloader_76, 0),
113
+ )
114
+
115
+ samplercustomadvanced_67 = samplercustomadvanced.sample(
116
+ noise=get_value_at_index(randomnoise_68, 0),
117
+ guider=get_value_at_index(cfgguider_73, 0),
118
+ sampler=get_value_at_index(ksamplerselect_72, 0),
119
+ sigmas=get_value_at_index(basicscheduler_84, 0),
120
+ latent_image=get_value_at_index(emptysd3latentimage_69, 0),
121
+ )
122
+
123
+ vaedecode_79 = vaedecode.decode(
124
+ samples=get_value_at_index(samplercustomadvanced_67, 0),
125
+ vae=get_value_at_index(vaeloader_80, 0),
126
+ )
127
+
128
+ # Instead of saving to file, return the image directly
129
+ return get_value_at_index(vaedecode_79, 0)
130
+
131
+ # Create Gradio interface
132
+ with gr.Blocks() as app:
133
+ gr.Markdown("# Chroma Image Generator")
134
+
135
+ with gr.Row():
136
+ with gr.Column():
137
+ prompt = gr.Textbox(
138
+ label="Prompt",
139
+ placeholder="Enter your prompt here...",
140
+ lines=3
141
+ )
142
+ negative_prompt = gr.Textbox(
143
+ label="Negative Prompt",
144
+ placeholder="Enter negative prompt here...",
145
+ value="low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors",
146
+ lines=2
147
+ )
148
+
149
+ with gr.Row():
150
+ width = gr.Slider(
151
+ minimum=512,
152
+ maximum=2048,
153
+ value=1024,
154
+ step=64,
155
+ label="Width"
156
+ )
157
+ height = gr.Slider(
158
+ minimum=512,
159
+ maximum=2048,
160
+ value=1024,
161
+ step=64,
162
+ label="Height"
163
+ )
164
+
165
+ with gr.Row():
166
+ steps = gr.Slider(
167
+ minimum=1,
168
+ maximum=50,
169
+ value=26,
170
+ step=1,
171
+ label="Steps"
172
+ )
173
+ cfg = gr.Slider(
174
+ minimum=1,
175
+ maximum=20,
176
+ value=4,
177
+ step=0.5,
178
+ label="CFG Scale"
179
+ )
180
+ seed = gr.Number(
181
+ value=-1,
182
+ label="Seed (-1 for random)"
183
+ )
184
+
185
+ generate_btn = gr.Button("Generate")
186
+
187
+ with gr.Column():
188
+ output_image = gr.Image(label="Generated Image")
189
+
190
+ generate_btn.click(
191
+ fn=generate_image,
192
+ inputs=[prompt, negative_prompt, width, height, steps, cfg, seed],
193
+ outputs=[output_image]
194
+ )
195
+
196
+ if __name__ == "__main__":
197
+ app.launch(share=True)