Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,7 @@ if torch.cuda.is_available():
|
|
| 32 |
else:
|
| 33 |
print("GPU não disponível. Usando CPU.")
|
| 34 |
|
| 35 |
-
#
|
| 36 |
def download_models():
|
| 37 |
models = [
|
| 38 |
("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
|
|
@@ -47,11 +47,11 @@ def download_models():
|
|
| 47 |
for repo_id, filename, model_type in models:
|
| 48 |
model_dir = os.path.join(BASE_DIR, "models", model_type)
|
| 49 |
os.makedirs(model_dir, exist_ok=True)
|
| 50 |
-
print(f"
|
| 51 |
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir)
|
| 52 |
folder_paths.add_model_folder_path(model_type, model_dir)
|
| 53 |
|
| 54 |
-
#
|
| 55 |
def import_custom_nodes():
|
| 56 |
import asyncio
|
| 57 |
import execution
|
|
@@ -65,19 +65,22 @@ def import_custom_nodes():
|
|
| 65 |
execution.PromptQueue(server_instance)
|
| 66 |
init_extra_nodes()
|
| 67 |
|
| 68 |
-
#
|
| 69 |
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
|
| 70 |
import_custom_nodes()
|
| 71 |
|
| 72 |
try:
|
| 73 |
with torch.inference_mode():
|
|
|
|
|
|
|
|
|
|
| 74 |
# Load CLIP
|
| 75 |
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
|
| 76 |
dualcliploader_loaded = dualcliploader.load_clip(
|
| 77 |
clip_name1="t5xxl_fp16.safetensors",
|
| 78 |
clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
|
| 79 |
type="flux",
|
| 80 |
-
device=
|
| 81 |
)
|
| 82 |
|
| 83 |
# Text Encoding
|
|
@@ -202,7 +205,7 @@ def generate_image(prompt, input_image, lora_weight, guidance, downsampling_fact
|
|
| 202 |
print(f"Error during generation: {str(e)}")
|
| 203 |
return None
|
| 204 |
|
| 205 |
-
#
|
| 206 |
with gr.Blocks() as app:
|
| 207 |
gr.Markdown("# FLUX Redux Image Generator")
|
| 208 |
|
|
@@ -299,6 +302,5 @@ with gr.Blocks() as app:
|
|
| 299 |
)
|
| 300 |
|
| 301 |
if __name__ == "__main__":
|
| 302 |
-
#
|
| 303 |
-
download_models()
|
| 304 |
app.launch(share=True)
|
|
|
|
| 32 |
else:
|
| 33 |
print("GPU não disponível. Usando CPU.")
|
| 34 |
|
| 35 |
+
# 5. Download de Modelos
|
| 36 |
def download_models():
|
| 37 |
models = [
|
| 38 |
("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
|
|
|
|
| 47 |
for repo_id, filename, model_type in models:
|
| 48 |
model_dir = os.path.join(BASE_DIR, "models", model_type)
|
| 49 |
os.makedirs(model_dir, exist_ok=True)
|
| 50 |
+
print(f"Baixando {filename} de {repo_id}...")
|
| 51 |
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir)
|
| 52 |
folder_paths.add_model_folder_path(model_type, model_dir)
|
| 53 |
|
| 54 |
+
# 6. Load custom nodes
|
| 55 |
def import_custom_nodes():
|
| 56 |
import asyncio
|
| 57 |
import execution
|
|
|
|
| 65 |
execution.PromptQueue(server_instance)
|
| 66 |
init_extra_nodes()
|
| 67 |
|
| 68 |
+
# 7. Main function to execute the workflow and generate an image
|
| 69 |
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
|
| 70 |
import_custom_nodes()
|
| 71 |
|
| 72 |
try:
|
| 73 |
with torch.inference_mode():
|
| 74 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 75 |
+
print(f"Using device: {device}")
|
| 76 |
+
|
| 77 |
# Load CLIP
|
| 78 |
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
|
| 79 |
dualcliploader_loaded = dualcliploader.load_clip(
|
| 80 |
clip_name1="t5xxl_fp16.safetensors",
|
| 81 |
clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
|
| 82 |
type="flux",
|
| 83 |
+
device=device
|
| 84 |
)
|
| 85 |
|
| 86 |
# Text Encoding
|
|
|
|
| 205 |
print(f"Error during generation: {str(e)}")
|
| 206 |
return None
|
| 207 |
|
| 208 |
+
# 8. Gradio Interface
|
| 209 |
with gr.Blocks() as app:
|
| 210 |
gr.Markdown("# FLUX Redux Image Generator")
|
| 211 |
|
|
|
|
| 302 |
)
|
| 303 |
|
| 304 |
if __name__ == "__main__":
|
| 305 |
+
# Download_models()
|
|
|
|
| 306 |
app.launch(share=True)
|