Redux / app.py
nftnik's picture
Update app.py
d7cfe15 verified
raw
history blame
11.5 kB
import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import logging
# Configure logging for debugging: timestamp, level and message on every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# 1. Path setup: append this file's directory to sys.path
# (presumably so the ComfyUI modules imported further down can be resolved).
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
# 2. Imports do ComfyUI
import folder_paths
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
from comfy import model_management
# 3. Directory setup: outputs and models live next to this file.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
for _required_dir in (output_dir, models_dir):
    os.makedirs(_required_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# Create one sub-folder of models_dir per model category and register it
# with ComfyUI's folder_paths under the same name.
MODEL_FOLDERS = ["style_models", "text_encoders", "vae", "unet", "clip_vision"]
for model_folder in MODEL_FOLDERS:
    folder_path = os.path.join(models_dir, model_folder)
    os.makedirs(folder_path, exist_ok=True)
    folder_paths.add_model_folder_path(model_folder, folder_path)
    logger.info(f"Pasta de modelo configurada: {model_folder}")
# 4. CUDA diagnostics: log interpreter, torch and GPU information at startup.
logger.info(f"Python version: {sys.version}")
logger.info(f"Torch version: {torch.__version__}")
logger.info(f"CUDA disponível: {torch.cuda.is_available()}")
logger.info(f"Quantidade de GPUs: {torch.cuda.device_count()}")
# The device name is only queried when CUDA reports itself as available.
if torch.cuda.is_available():
    logger.info(f"GPU atual: {torch.cuda.get_device_name(0)}")
# 5. ComfyUI initialization.
# Failures while initializing the extra nodes are downgraded to warnings so
# that startup continues (deliberate best-effort).
logger.info("Inicializando ComfyUI...")
try:
    init_extra_nodes()
except Exception as exc:
    logger.warning(f"Aviso na inicialização de nós extras: {str(exc)}")
    logger.info("Continuando mesmo com avisos nos nós extras...")
# 6. Helper Functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is taken only when the direct lookup raises ``KeyError``,
    i.e. when ``obj`` is a mapping that does not contain ``index`` as a key.
    Any other exception (for example ``IndexError`` on a sequence) propagates.

    Args:
        obj: Sequence or mapping to read from.
        index: Position (or key) to look up.

    Returns:
        The value found directly at ``index`` or inside ``obj["result"]``.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
def verify_file_exists(folder: str, filename: str) -> bool:
    """Report whether ``filename`` is present under ``models_dir/folder``.

    Args:
        folder: Sub-folder of the module-level ``models_dir``.
        filename: File name (or relative path) inside that sub-folder.

    Returns:
        True when the path exists; otherwise False, after logging an error
        that names the missing path.
    """
    file_path = os.path.join(models_dir, folder, filename)
    if os.path.exists(file_path):
        return True
    logger.error(f"Arquivo não encontrado: {file_path}")
    return False
# 7. Model downloads
logger.info("Baixando modelos necessários...")
# Each entry is (Hugging Face repo id, file name, sub-folder of models_dir).
MODELS_TO_DOWNLOAD = [
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
    ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors", "text_encoders"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "unet"),
    ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision"),
]
for repo_id, filename, folder in MODELS_TO_DOWNLOAD:
    try:
        logger.info(f"Baixando {filename} para {folder}...")
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=os.path.join(models_dir, folder),
        )
        # Double-check that the file really landed where the loaders expect it.
        if not verify_file_exists(folder, filename):
            raise FileNotFoundError(f"Arquivo não encontrado após download: {filename}")
    except Exception as e:
        # Log which file/repo failed, then abort startup by re-raising.
        logger.error(f"Erro ao baixar {filename} de {repo_id}: {str(e)}")
        raise
# 8. Model initialization
# Every loader result is kept in a module-level name; the generation function
# reads element [0] of each of them.
logger.info("Inicializando modelos...")
try:
    with torch.inference_mode():
        # Text encoders (two files loaded together in "flux" mode)
        logger.info("Carregando CLIP...")
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        CLIP_MODEL = dualcliploader.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
            type="flux",
        )

        # CLIP Vision
        logger.info("Carregando CLIP Vision...")
        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        CLIP_VISION = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors",
        )

        # Style model
        logger.info("Carregando Style Model...")
        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        STYLE_MODEL = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors",
        )

        # VAE
        logger.info("Carregando VAE...")
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        VAE_MODEL = vaeloader.load_vae(
            vae_name="ae.safetensors",
        )

        # UNET, loaded with the fp8_e4m3fn weight dtype
        logger.info("Carregando UNET...")
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        UNET_MODEL = unetloader.load_unet(
            unet_name="flux1-dev.safetensors",
            weight_dtype="fp8_e4m3fn",
        )

        logger.info("Carregando modelos na GPU...")
        model_loaders = [CLIP_MODEL, VAE_MODEL, STYLE_MODEL, CLIP_VISION, UNET_MODEL]
        # Hand over each model's `.patcher` when it has one, else the object itself.
        model_management.load_models_gpu(
            [getattr(loader[0], 'patcher', loader[0]) for loader in model_loaders]
        )
except Exception as e:
    # Startup cannot continue without the models: log and re-raise.
    logger.error(f"Erro ao inicializar modelos: {str(e)}")
    raise
# 9. Função de Geração
@spaces.GPU
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from a text prompt and a reference image.

    Steps, in order: encode the prompt, load the reference image, append the
    Flux guidance, apply the Redux style model, create an empty latent, run
    the KSampler, decode with the VAE and save the result as a PNG in
    ``output_dir``.

    Args:
        prompt: Text prompt to encode.
        input_image: Path of the reference image.
        lora_weight: Accepted for interface compatibility; not used by the
            pipeline below.
        guidance: Value forwarded to the FluxGuidance node.
        downsampling_factor: Forwarded to the ReduxAdvanced node.
        weight: Style-model weight forwarded to the ReduxAdvanced node.
        seed: Sampler seed.
        width: Latent width passed to EmptyLatentImage.
        height: Latent height passed to EmptyLatentImage.
        batch_size: Latent batch size; only the first decoded image is saved.
        steps: Number of sampling steps.
        progress: Gradio progress tracker (not referenced in the body).

    Returns:
        Path of the saved PNG, or None when any step fails (the error is
        logged with its traceback).
    """
    import uuid  # function-scope import: only needed for the output file name

    try:
        with torch.inference_mode():
            logger.info(f"Iniciando geração com prompt: {prompt}")

            # Fail fast, before any model work, when no reference image was given.
            if not input_image:
                raise ValueError("Erro ao carregar a imagem de entrada")

            # Encode the text prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0],
            )

            # Load the reference image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            if loaded_image is None:
                raise ValueError("Erro ao carregar a imagem de entrada")
            logger.info("Imagem carregada com sucesso")

            # Flux guidance
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0],
            )

            # Redux Advanced: condition on the reference image
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0],
            )

            # Empty latent
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size,
            )

            # KSampler
            logger.info("Iniciando sampling...")
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=UNET_MODEL[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0],
            )

            # VAE decode
            logger.info("Decodificando imagem...")
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0],
            )

            # Save the image. A UUID-based name keeps concurrent requests from
            # overwriting each other's output file.
            temp_filename = f"Flux_{uuid.uuid4().hex}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            try:
                # Assumes ComfyUI's IMAGE convention: a torch tensor shaped
                # [batch, height, width, channels] with floats in 0..1
                # (confirm against the installed ComfyUI version). Torch
                # tensors have no `.astype`, so convert explicitly.
                image_tensor = decoded[0]
                if image_tensor.dim() == 4:
                    # Only the first image of the batch is returned to the UI.
                    image_tensor = image_tensor[0]
                pixels = (
                    (image_tensor.clamp(0.0, 1.0) * 255.0)
                    .to(torch.uint8)
                    .cpu()
                    .numpy()
                )
                Image.fromarray(pixels).save(temp_path)
                logger.info(f"Imagem salva em: {temp_path}")
                return temp_path
            except Exception as e:
                logger.exception(f"Erro ao salvar imagem: {str(e)}")
                return None
    except Exception as e:
        # Best-effort contract: log (with traceback) and hand None to the UI.
        logger.exception(f"Erro ao gerar imagem: {str(e)}")
        return None
# 10. Gradio interface
# Left column: prompt, reference image, parameters and the generate button.
# Right column: the generated image.
with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )
            with gr.Row():
                with gr.Column():
                    # NOTE: this value is passed to generate_image, which
                    # currently does not use it.
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight",
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance",
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor",
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight",
                    )
                with gr.Column():
                    # random.randint includes its upper bound, so cap the
                    # default at 2**64 - 1, the largest unsigned 64-bit seed.
                    # The default is drawn once, when the UI is built.
                    seed = gr.Number(
                        value=random.randint(1, 2**64 - 1),
                        label="Seed",
                        precision=0,
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0,
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0,
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0,
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0,
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    # Inputs are listed in the same order as generate_image's parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps,
        ],
        outputs=[output_image],
    )

if __name__ == "__main__":
    app.launch()