Redux / app.py
nftnik's picture
Update app.py
02c8fdb verified
raw
history blame
13.2 kB
import os
import sys
import random
import torch
from pathlib import Path
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import logging
# Adicione se ainda não tiver
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes, SaveImage # <-- Node SaveImage
from comfy import model_management
import folder_paths
# Logging: root logger at INFO with timestamps (basicConfig only configures the
# root logger if it has no handlers yet). `logger` is used throughout this file.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# 1. Path configuration.
# NOTE(review): this runs after `nodes`, `comfy` and `folder_paths` have already
# been imported at the top of the file, so it cannot influence how those imports
# resolve. They resolve because, when launched as `python app.py`, the script's
# directory is already sys.path[0]. This line only affects later imports.
_script_file = os.path.abspath(__file__)
current_dir = os.path.dirname(_script_file)
sys.path.append(current_dir)
# 3. Directory configuration: output/ and models/ live next to this script
# (symlinks resolved via realpath) and are created if missing.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir, models_dir = (os.path.join(BASE_DIR, name) for name in ("output", "models"))
for _directory in (output_dir, models_dir):
    os.makedirs(_directory, exist_ok=True)
# Make output/ ComfyUI's output directory; generate_image later resolves the
# files written by the SaveImage node relative to `output_dir`.
folder_paths.set_output_directory(output_dir)

# One sub-folder per model type: create it and register it with folder_paths.
# The loader nodes in the model-initialisation section are given bare file
# names that are downloaded into these folders.
MODEL_FOLDERS = ["style_models", "text_encoders", "vae", "unet", "clip_vision"]
for folder_name in MODEL_FOLDERS:
    folder_dir = os.path.join(models_dir, folder_name)
    os.makedirs(folder_dir, exist_ok=True)
    folder_paths.add_model_folder_path(folder_name, folder_dir)
    logger.info(f"Pasta de modelo configurada: {folder_name}")
# 4. CUDA diagnostics and 5. ComfyUI initialisation, run once at import time.


def _log_startup_diagnostics():
    """Log the Python and torch versions and the CUDA devices torch reports."""
    logger.info(f"Python version: {sys.version}")
    logger.info(f"Torch version: {torch.__version__}")
    logger.info(f"CUDA disponível: {torch.cuda.is_available()}")
    logger.info(f"Quantidade de GPUs: {torch.cuda.device_count()}")
    if not torch.cuda.is_available():
        return
    logger.info(f"GPU atual: {torch.cuda.get_device_name(0)}")


def _init_comfyui_nodes():
    """Call ComfyUI's init_extra_nodes, downgrading any exception to a warning.

    Startup deliberately continues after a failure here.
    NOTE(review): recent ComfyUI releases define `init_extra_nodes` as a
    coroutine function; if the vendored version does, this plain call only
    creates a coroutine and registers no extra nodes (e.g. ReduxAdvanced would
    then be missing from NODE_CLASS_MAPPINGS) — confirm against the checkout.
    """
    logger.info("Inicializando ComfyUI...")
    try:
        init_extra_nodes()
    except Exception as exc:
        logger.warning(f"Aviso na inicialização de nós extras: {str(exc)}")
        logger.info("Continuando mesmo com avisos nos nós extras...")


_log_startup_diagnostics()
_init_comfyui_nodes()
# 6. Helper functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is taken only when direct indexing raises ``KeyError`` (e.g. a
    mapping that wraps its outputs under a ``"result"`` key). Any other error,
    including one raised by the fallback lookup, propagates to the caller.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
def verify_file_exists(folder: str, filename: str) -> bool:
    """Check that ``<models_dir>/<folder>/<filename>`` is an existing regular file.

    ``os.path.isfile`` is used instead of ``os.path.exists`` so that a
    directory with the expected name is not mistaken for a model file
    (symlinks to files still count). A missing file is logged at ERROR level
    with its full path.

    Args:
        folder: Model sub-folder under ``models_dir`` (e.g. ``"vae"``).
        filename: File name inside that sub-folder.

    Returns:
        True if the file is present, False otherwise.
    """
    file_path = os.path.join(models_dir, folder, filename)
    is_present = os.path.isfile(file_path)
    if not is_present:
        logger.error(f"Arquivo não encontrado: {file_path}")
    return is_present
# 7. Model downloads.
# (repo_id, filename, sub-folder of models_dir) for every checkpoint loaded in
# the model-initialisation section; fetched in this order.
# NOTE(review): the black-forest-labs repositories are gated on the Hugging Face
# Hub; presumably an access token is supplied via the environment — confirm.
_MODEL_DOWNLOADS = (
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
    (
        "zer0int/CLIP-GmP-ViT-L-14",
        "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
        "text_encoders",
    ),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "unet"),
    ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision"),
)

logger.info("Baixando modelos necessários...")
try:
    for _repo_id, _filename, _subfolder in _MODEL_DOWNLOADS:
        hf_hub_download(
            repo_id=_repo_id,
            filename=_filename,
            local_dir=os.path.join(models_dir, _subfolder),
        )
except Exception as e:
    # Any download failure is fatal: log it and abort startup.
    logger.error(f"Erro ao baixar modelos: {str(e)}")
    raise
# 8. Model initialisation, done once at startup.
# Produces the module globals CLIP_MODEL, CLIP_VISION, STYLE_MODEL, VAE_MODEL
# and UNET_MODEL (tuples returned by ComfyUI loader nodes) that generate_image
# reads; element [0] of each is the loaded object.


def _load_or_fail(log_message: str, node_name: str, method_name: str,
                  error_message: str, **loader_kwargs: Any) -> Any:
    """Log ``log_message``, instantiate ``NODE_CLASS_MAPPINGS[node_name]`` and
    call its ``method_name`` with ``loader_kwargs``.

    Returns:
        Whatever the loader method returns.

    Raises:
        ValueError: with ``error_message`` when the loader returns None.
    """
    logger.info(log_message)
    loader_node = NODE_CLASS_MAPPINGS[node_name]()
    loaded = getattr(loader_node, method_name)(**loader_kwargs)
    if loaded is None:
        raise ValueError(error_message)
    return loaded


logger.info("Inicializando modelos...")
try:
    with torch.no_grad():
        CLIP_MODEL = _load_or_fail(
            "Carregando CLIP...", "DualCLIPLoader", "load_clip",
            "Falha ao carregar CLIP model",
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
            type="flux",
        )
        CLIP_VISION = _load_or_fail(
            "Carregando CLIP Vision...", "CLIPVisionLoader", "load_clip",
            "Falha ao carregar CLIP Vision model",
            clip_name="sigclip_vision_patch14_384.safetensors",
        )
        STYLE_MODEL = _load_or_fail(
            "Carregando Style Model...", "StyleModelLoader", "load_style_model",
            "Falha ao carregar Style Model",
            style_model_name="flux1-redux-dev.safetensors",
        )
        VAE_MODEL = _load_or_fail(
            "Carregando VAE...", "VAELoader", "load_vae",
            "Falha ao carregar VAE model",
            vae_name="ae.safetensors",
        )
        # UNet weights are loaded as fp8_e4m3fn; adjust weight_dtype if needed.
        UNET_MODEL = _load_or_fail(
            "Carregando UNET...", "UNETLoader", "load_unet",
            "Falha ao carregar UNET model",
            unet_name="flux1-dev.safetensors",
            weight_dtype="fp8_e4m3fn",
        )

        logger.info("Carregando modelos na GPU...")
        # For each loaded object, hand load_models_gpu its `.patcher` when it
        # has one, otherwise the object itself. STYLE_MODEL is not included.
        model_management.load_models_gpu([
            getattr(loaded[0], "patcher", loaded[0])
            for loaded in (CLIP_MODEL, VAE_MODEL, CLIP_VISION, UNET_MODEL)
        ])
        logger.info("Modelos carregados com sucesso")
except Exception as e:
    # Model loading failures are fatal: log and abort startup.
    logger.error(f"Erro ao inicializar modelos: {str(e)}")
    raise
# 9. Generation function
@spaces.GPU
def generate_image(
    prompt, input_image, lora_weight, guidance, downsampling_factor,
    weight, seed, width, height, batch_size, steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the FLUX Redux pipeline once and return the path of the saved image.

    Steps:
      1. Encode ``prompt`` with CLIPTextEncode using the dual CLIP model.
      2. Attach ``guidance`` to that conditioning with FluxGuidance.
      3. Condition on ``input_image`` with ReduxAdvanced ("area" downsampling,
         "keep aspect ratio"), using the CLIP Vision and Redux style models.
      4. Sample an empty latent with KSampler (euler / simple, cfg=1, denoise=1).
         The negative input is the FluxGuidance conditioning; at cfg=1
         classifier-free guidance reduces to the positive prediction, so it
         should not affect the result.
      5. Decode with the VAE and write the images via ComfyUI's SaveImage node.

    Args:
        prompt: Text prompt.
        input_image: File path of the reference image (the UI uses
            ``gr.Image(type="filepath")``); None if nothing was uploaded.
        lora_weight: Accepted for interface compatibility but currently unused;
            no LoRA is loaded.
        guidance: FluxGuidance value.
        downsampling_factor: ReduxAdvanced downsampling factor (coerced to int).
        weight: ReduxAdvanced weight.
        seed, width, height, batch_size, steps: Coerced to int before use.
        progress: Gradio progress tracker mirroring tqdm bars.

    Returns:
        Path of the first saved image (only one is returned even when
        ``batch_size`` > 1), or None if any step fails; the error is logged
        together with its traceback.
    """
    try:
        with torch.no_grad():
            logger.info(f"Iniciando geração com prompt: {prompt}")

            # Numeric widgets may hand back whole numbers as floats (e.g. 3.0).
            # These values size tensors / count steps downstream, where floats
            # fail, so normalise them to int up front.
            downsampling_factor = int(downsampling_factor)
            seed = int(seed)
            width = int(width)
            height = int(height)
            batch_size = int(batch_size)
            steps = int(steps)

            # Fail early with a clear message instead of an opaque LoadImage error.
            if input_image is None:
                raise ValueError("Nenhuma imagem de entrada fornecida")

            # Encode the text prompt.
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0]
            )

            # Load the reference image.
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            if loaded_image is None:
                raise ValueError("Erro ao carregar a imagem de entrada")
            logger.info("Imagem carregada com sucesso")

            # Flux guidance on the text conditioning.
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )

            # Redux (image-prompt) conditioning on top of the guided text.
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0]
            )

            # Empty latent of the requested size.
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size
            )

            # Sampling.
            logger.info("Iniciando sampling...")
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=UNET_MODEL[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )

            # VAE decode.
            logger.info("Decodificando imagem...")
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0]
            )

            # Save with ComfyUI's SaveImage node.
            logger.info("Salvando imagem via node SaveImage...")
            saveimage_node = NODE_CLASS_MAPPINGS["SaveImage"]()
            result_dict = saveimage_node.save_images(
                filename_prefix="FluxRedux",
                images=decoded[0]
            )
            # SaveImage reports what it wrote as
            # {"ui": {"images": [{"filename": ..., "subfolder": ...}, ...]}}.
            # Include the subfolder so prefixes containing a directory still
            # resolve. Only the first image is returned because the UI output
            # is a single gr.Image.
            first_image = result_dict["ui"]["images"][0]
            saved_path = os.path.join(
                output_dir,
                first_image.get("subfolder", ""),
                first_image["filename"],
            )
            logger.info(f"Imagem salva em: {saved_path}")
            return saved_path
    except Exception as e:
        # logger.exception keeps the traceback in the logs. The UI then shows
        # an empty output for this request.
        logger.exception(f"Erro ao gerar imagem: {str(e)}")
        return None
# 10. Gradio interface

# Upper bound for the random default seed. random.randint is inclusive, so the
# bound must not exceed 2**64 - 1, the largest seed torch.manual_seed accepts.
# 2**53 - 1 is used because gr.Number values round-trip through the browser as
# JavaScript numbers, which hold integers exactly only up to 2**53 - 1. Larger
# defaults would be silently rounded in the UI.
MAX_SEED = 2**53 - 1

with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )
            with gr.Row():
                with gr.Column():
                    # NOTE: generate_image receives this value but does not use
                    # it (no LoRA is loaded).
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight"
                    )
                with gr.Column():
                    # The default seed is drawn once, when the UI is built.
                    seed = gr.Number(
                        value=random.randint(1, MAX_SEED),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    # Only the first image of a batch is shown.
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps
        ],
        outputs=[output_image]
    )

if __name__ == "__main__":
    app.launch()