# Redux / app.py — Hugging Face Space by nftnik
# Last commit f8708de ("Update app.py", verified) · 12.7 kB
import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import logging
# Configure root logging once at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# 1. Path setup: append this file's own directory to sys.path so that the
#    ComfyUI modules imported right after (folder_paths, nodes, comfy) can be
#    resolved from it — presumably ComfyUI is vendored next to this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
# 2. Imports do ComfyUI
import folder_paths
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
from comfy import model_management
# 3. Directory layout: generated images go to ./output and model weights to
#    ./models, both resolved relative to this file's real (symlink-resolved)
#    location. Both directories are created if missing.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
for _needed_dir in (output_dir, models_dir):
    os.makedirs(_needed_dir, exist_ok=True)
# Tell ComfyUI where its output directory is.
folder_paths.set_output_directory(output_dir)

# Model sub-folders used by the loaders further down: each is created under
# models_dir if missing and registered with ComfyUI via add_model_folder_path.
MODEL_FOLDERS = ["style_models", "text_encoders", "vae", "unet", "clip_vision"]
for model_folder in MODEL_FOLDERS:
    model_folder_path = os.path.join(models_dir, model_folder)
    os.makedirs(model_folder_path, exist_ok=True)
    folder_paths.add_model_folder_path(model_folder, model_folder_path)
    logger.info(f"Pasta de modelo configurada: {model_folder}")
# 4. CUDA diagnostics: purely informational logging of the Python/torch
#    versions and of which GPUs torch can see.
logger.info(f"Python version: {sys.version}")
logger.info(f"Torch version: {torch.__version__}")
_cuda_available = torch.cuda.is_available()
logger.info(f"CUDA disponível: {_cuda_available}")
logger.info(f"Quantidade de GPUs: {torch.cuda.device_count()}")
if _cuda_available:
    # Only query a device name when CUDA is usable; device 0 is reported.
    logger.info(f"GPU atual: {torch.cuda.get_device_name(0)}")
# 5. ComfyUI initialization: register the extra (non-core) nodes in
#    NODE_CLASS_MAPPINGS — presumably where nodes such as "ReduxAdvanced",
#    looked up later in generate_image, come from (confirm).
import asyncio  # needed only to drive init_extra_nodes when it is a coroutine

logger.info("Inicializando ComfyUI...")
try:
    _init_result = init_extra_nodes()
    # Some ComfyUI versions declare init_extra_nodes as `async def`; calling it
    # without running the returned coroutine would register no extra nodes at
    # all (and only emit a "never awaited" warning). For the synchronous form
    # the return value is not a coroutine and nothing more happens.
    # NOTE(review): confirm which form the vendored ComfyUI copy uses.
    if asyncio.iscoroutine(_init_result):
        asyncio.run(_init_result)
except Exception as e:
    # Deliberately best-effort: a failing extra node must not stop start-up.
    logger.warning(f"Aviso na inicialização de nós extras: {str(e)}")
    logger.info("Continuando mesmo com avisos nos nós extras...")
# 6. Helper Functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
try:
return obj[index]
except KeyError:
return obj["result"][index]
def verify_file_exists(folder: str, filename: str) -> bool:
    """Check whether ``<models_dir>/<folder>/<filename>`` exists on disk.

    When the path is missing, an error with the full path is logged.

    Args:
        folder: Sub-folder of ``models_dir`` (e.g. ``"vae"``).
        filename: File name inside that sub-folder.

    Returns:
        ``True`` if the path exists, ``False`` otherwise.
    """
    candidate = os.path.join(models_dir, folder, filename)
    if os.path.exists(candidate):
        return True
    logger.error(f"Arquivo não encontrado: {candidate}")
    return False


# 7. Model downloads: fetch every weight file the loaders below expect into
#    the matching models/<sub-folder>. Each entry is
#    (Hub repo id, file name, sub-folder of models_dir), downloaded in order.
#    Any failure is logged and re-raised, aborting start-up.
logger.info("Baixando modelos necessários...")
try:
    for _repo_id, _filename, _subfolder in (
        ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
        (
            "zer0int/CLIP-GmP-ViT-L-14",
            "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
            "text_encoders",
        ),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "unet"),
        ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision"),
    ):
        hf_hub_download(
            repo_id=_repo_id,
            filename=_filename,
            local_dir=os.path.join(models_dir, _subfolder),
        )
except Exception as e:
    logger.error(f"Erro ao baixar modelos: {str(e)}")
    raise
# 8. Model initialization: load every model once at start-up into module-level
#    globals (CLIP_MODEL, CLIP_VISION, STYLE_MODEL, VAE_MODEL, UNET_MODEL) that
#    generate_image() reuses for every request.


def _require_loaded(result, label):
    """Return a loader node's result after checking it actually holds a model.

    Loader node methods return a tuple of outputs (callers index it with
    ``[0]``), so the tuple itself is never ``None``; the first element is
    checked as well.

    Args:
        result: Value returned by a ComfyUI loader node method.
        label: Model name used in the error message.

    Returns:
        ``result`` unchanged.

    Raises:
        ValueError: If ``result`` is ``None``/empty or its first item is ``None``.
    """
    if not result or result[0] is None:
        raise ValueError(f"Falha ao carregar {label}")
    return result


logger.info("Inicializando modelos...")
try:
    # torch.no_grad() is used instead of torch.inference_mode(): per the
    # original author, inference_mode caused a "version counter" error here.
    with torch.no_grad():
        # CLIP text encoders (T5-XXL + CLIP-L variant) in Flux mode
        logger.info("Carregando CLIP...")
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        CLIP_MODEL = _require_loaded(
            dualcliploader.load_clip(
                clip_name1="t5xxl_fp16.safetensors",
                clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
                type="flux",
            ),
            "CLIP model",
        )

        # CLIP Vision (encodes the reference image for the Redux style model)
        logger.info("Carregando CLIP Vision...")
        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        CLIP_VISION = _require_loaded(
            clipvisionloader.load_clip(
                clip_name="sigclip_vision_patch14_384.safetensors"
            ),
            "CLIP Vision model",
        )

        # Redux style model
        logger.info("Carregando Style Model...")
        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        STYLE_MODEL = _require_loaded(
            stylemodelloader.load_style_model(
                style_model_name="flux1-redux-dev.safetensors"
            ),
            "Style Model",
        )

        # VAE
        logger.info("Carregando VAE...")
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        VAE_MODEL = _require_loaded(
            vaeloader.load_vae(
                vae_name="ae.safetensors"
            ),
            "VAE model",
        )

        # UNET (FLUX.1-dev diffusion model)
        logger.info("Carregando UNET...")
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        UNET_MODEL = _require_loaded(
            unetloader.load_unet(
                unet_name="flux1-dev.safetensors",
                weight_dtype="fp8_e4m3fn",  # adjust to your hardware if needed
            ),
            "UNET model",
        )

        logger.info("Carregando modelos na GPU...")
        model_loaders = [CLIP_MODEL, VAE_MODEL, CLIP_VISION, UNET_MODEL]
        # Objects exposing a `.patcher` attribute are unwrapped to it; anything
        # else is handed to load_models_gpu as-is. The style model is not
        # included in this GPU pre-load.
        model_management.load_models_gpu(
            [getattr(loader[0], "patcher", loader[0]) for loader in model_loaders]
        )
        logger.info("Modelos carregados com sucesso")
except Exception as e:
    # Start-up cannot continue without the models: log with traceback, re-raise.
    logger.exception(f"Erro ao inicializar modelos: {str(e)}")
    raise
# 9. Generation function
@spaces.GPU
def generate_image(
    prompt, input_image, lora_weight, guidance, downsampling_factor,
    weight, seed, width, height, batch_size, steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the FLUX Redux pipeline for one request and return the saved image path.

    Pipeline: encode the prompt -> apply FluxGuidance -> condition on the
    reference image via ReduxAdvanced -> sample an empty latent with KSampler
    -> VAE-decode -> save the first image of the batch as PNG in output_dir.

    Args:
        prompt: Text prompt, encoded with the pre-loaded dual CLIP model.
        input_image: File path of the reference image (the UI component uses
            ``type="filepath"``); ``None`` when nothing was uploaded.
        lora_weight: Accepted for UI compatibility but currently unused — no
            LoRA is loaded anywhere in this pipeline.
        guidance: Value passed to the FluxGuidance node.
        downsampling_factor: ReduxAdvanced downsampling factor (cast to int).
        weight: ReduxAdvanced style weight.
        seed, width, height, batch_size, steps: Sampler / latent settings,
            defensively cast to int because UI widgets may deliver floats.
        progress: Gradio progress tracker; tqdm bars are mirrored to the UI.

    Returns:
        Path of the written PNG, or ``None`` if any step failed (the error is
        logged with its traceback).
    """
    import uuid  # local import: only used for collision-free output file names

    try:
        if input_image is None:
            raise ValueError("Nenhuma imagem de entrada fornecida")

        # Inference only: no gradient bookkeeping needed.
        with torch.no_grad():
            logger.info(f"Iniciando geração com prompt: {prompt}")

            # Encode the text prompt.
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0]
            )

            # Load the reference image from the path provided by Gradio.
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            if loaded_image is None:
                raise ValueError("Erro ao carregar a imagem de entrada")
            logger.info("Imagem carregada com sucesso")

            # Attach the guidance value to the text conditioning.
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )

            # Condition on the reference image through the Redux style model.
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=int(downsampling_factor),
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0]
            )

            # Empty starting latent of the requested size and batch.
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=int(width),
                height=int(height),
                batch_size=int(batch_size)
            )

            logger.info("Iniciando sampling...")
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=int(seed),
                steps=int(steps),
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=UNET_MODEL[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )

            logger.info("Decodificando imagem...")
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0]
            )

        # A uuid avoids the collisions a small random integer allowed when
        # several requests write into output_dir concurrently.
        temp_filename = f"Flux_{uuid.uuid4().hex}.png"
        temp_path = os.path.join(output_dir, temp_filename)
        try:
            # Assumes the ComfyUI IMAGE convention for decoded[0]: a torch float
            # tensor shaped (batch, H, W, C) with values in [0, 1]. Torch tensors
            # have no NumPy-style .astype(), and PIL cannot take the batch axis,
            # so convert the first image explicitly (the UI shows one image).
            first_image = decoded[0][0]
            pixels = (first_image.clamp(0, 1) * 255).round().to(torch.uint8).cpu().numpy()
            Image.fromarray(pixels).save(temp_path)
            logger.info(f"Imagem salva em: {temp_path}")
            return temp_path
        except Exception as e:
            logger.exception(f"Erro ao salvar imagem: {str(e)}")
            return None
    except Exception as e:
        logger.exception(f"Erro ao gerar imagem: {str(e)}")
        return None
# 10. Gradio interface
# Upper bound for the random default seed. random.randint is inclusive, so the
# previous bound 2**64 could produce a value one past the unsigned 64-bit seed
# range (NOTE(review): confirm the sampler's max seed is 2**64 - 1 in the
# vendored ComfyUI). 2**53 - 1 is also the largest integer a browser's
# JavaScript Number round-trips exactly, so the seed shown is the seed used.
MAX_DEFAULT_SEED = 2**53 - 1

with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )
            with gr.Row():
                with gr.Column():
                    # NOTE: generate_image() receives this value but never uses
                    # it — no LoRA is loaded by this app.
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight"
                    )
                with gr.Column():
                    # A callable value is re-evaluated on each page load, so
                    # every visitor gets a fresh default seed instead of one
                    # fixed at start-up.
                    seed = gr.Number(
                        value=lambda: random.randint(1, MAX_DEFAULT_SEED),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    # Only the first image of a batch is shown in the output.
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps
        ],
        outputs=[output_image]
    )

if __name__ == "__main__":
    # To share publicly, e.g.: app.launch(server_name="0.0.0.0", share=True)
    app.launch()