import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import folder_paths
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
from comfy import model_management
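# This Space drives ComfyUI as a library: node classes are pulled from
# NODE_CLASS_MAPPINGS and called directly, instead of running the ComfyUI
# server and submitting a workflow graph.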
# Directory setup
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
# Configure model paths ("loras" is registered too, since generate_image()
# loads a LoRA; that file is not downloaded by this script and must already
# be present in models/loras)
for model_folder in ["style_models", "text_encoders", "vae", "unet", "clip_vision", "loras"]:
    folder_path = os.path.join(models_dir, model_folder)
    os.makedirs(folder_path, exist_ok=True)
    folder_paths.add_model_folder_path(model_folder, folder_path)

# Download the required models
print("Downloading required models...")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev",
filename="flux1-redux-dev.safetensors",
local_dir=os.path.join(models_dir, "style_models"))
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders",
filename="t5xxl_fp16.safetensors",
local_dir=os.path.join(models_dir, "text_encoders"))
hf_hub_download(repo_id="zer0int/CLIP-GmP-ViT-L-14",
filename="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
local_dir=os.path.join(models_dir, "text_encoders"))
hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
filename="ae.safetensors",
local_dir=os.path.join(models_dir, "vae"))
hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
filename="flux1-dev.safetensors",
local_dir=os.path.join(models_dir, "unet"))
hf_hub_download(repo_id="google/siglip-so400m-patch14-384",
filename="model.safetensors",
local_dir=os.path.join(models_dir, "clip_vision"))
# CUDA diagnostics
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current GPU:", torch.cuda.get_device_name(0))
# Initialize ComfyUI extra nodes
print("Inicializando ComfyUI...")
init_extra_nodes()
# Helper to read a node's output: most nodes return tuples, but some custom
# nodes return a dict with a "result" key, so handle both.
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
# Initialize models
print("Initializing models...")
with torch.inference_mode():
    # CLIP text encoders (T5-XXL + CLIP-L)
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    CLIP_MODEL = dualcliploader.load_clip(
        clip_name1="t5xxl_fp16.safetensors",
        clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
        type="flux"
    )
    # Style Model (Redux)
    stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
    STYLE_MODEL = stylemodelloader.load_style_model(
        style_model_name="flux1-redux-dev.safetensors"
    )
    # VAE
    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    VAE_MODEL = vaeloader.load_vae(
        vae_name="ae.safetensors"
    )
    # UNET
    unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
    UNET_MODEL = unetloader.load_unet(
        unet_name="flux1-dev.safetensors",
        weight_dtype="fp8_e4m3fn"
    )
    # CLIP Vision (the SigLIP checkpoint is downloaded above as "model.safetensors")
    clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
    CLIP_VISION = clipvisionloader.load_clip(
        clip_name="model.safetensors"
    )
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION]
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
        for loader in model_loaders
    ])
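# The models are loaded once at startup via load_models_gpu, so each call to
# generate_image() only pays for encoding, sampling, and decoding.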
@spaces.GPU
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps, progress=gr.Progress(track_tqdm=True)):
    try:
        with torch.inference_mode():
            # Text Encoding
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0]
            )
            # Load Input Image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            # Load LoRA (expected in models/loras; not downloaded by this script)
            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            lora_model = loraloadermodelonly.load_lora_model_only(
                lora_name="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
                strength_model=lora_weight,
                model=UNET_MODEL[0]
            )
            # Flux Guidance
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )
            # Redux Advanced: blend the reference image's style into the conditioning
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                autocrop_margin=0.1,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0]
            )
            # Empty Latent Image
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size
            )
            # KSampler (cfg=1: FLUX.1-dev is guidance-distilled, so classifier-free
            # guidance is off and FluxGuidance above steers the result instead)
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=lora_model[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )
            # VAE Decode
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0]
            )
            # Save the first image of the batch
            temp_filename = f"Flux_{random.randint(0, 99999)}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            image_tensor = decoded[0][0]  # [H, W, C] float tensor in [0, 1]
            image_np = (image_tensor.clamp(0, 1).cpu().numpy() * 255).astype("uint8")
            Image.fromarray(image_np).save(temp_path)
            return temp_path
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        return None
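# The controls below map one-to-one onto generate_image()'s parameters; the
# inputs list passed to generate_btn.click() must keep the same order as the
# function signature.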
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )
            with gr.Row():
                with gr.Column():
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Redux Weight"
                    )
                with gr.Column():
                    seed = gr.Number(
                        # stay below 2**53 so the seed survives the browser's
                        # float64 round-trip intact
                        value=random.randint(1, 2**32),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps
        ],
        outputs=[output_image]
    )
if __name__ == "__main__":
    app.launch()