# Redux / app.py
# nftnik's picture
# Update app.py
# 6eb4d49 verified
# raw
# history blame
# 9.89 kB
import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
# Make the bundled ComfyUI checkout importable: its modules (folder_paths,
# nodes, server, execution, ...) are imported as top-level names below.
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_path = os.path.join(current_dir, "ComfyUI")
# Append only once so re-running this setup never leaves duplicate entries.
if comfyui_path not in sys.path:
    sys.path.append(comfyui_path)

# Only importable now that the ComfyUI path is on sys.path.
import folder_paths

# Startup diagnostics: interpreter/torch versions and CUDA availability.
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA disponível:", torch.cuda.is_available())
print("Quantidade de GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("GPU atual:", torch.cuda.get_device_name(0))
def init_comfyui():
    """Bring up the in-process ComfyUI runtime and return its node registry.

    Steps, in order:
      1. create a fresh asyncio event loop and install it as the current loop;
      2. construct ``server.PromptServer`` bound to that loop;
      3. construct ``execution.PromptQueue`` for that server (the queue object
         itself is not kept here);
      4. call ``nodes.init_custom_nodes()``.

    Returns:
        ``NODE_CLASS_MAPPINGS`` from ComfyUI's ``nodes`` module. This file
        instantiates node classes by looking them up by name in it.
    """
    # These ComfyUI modules only resolve once the ComfyUI checkout is on
    # sys.path, hence the function-local imports. The import order is kept
    # unchanged because these modules may have import-time side effects.
    import execution
    from nodes import NODE_CLASS_MAPPINGS, init_custom_nodes
    import server
    import asyncio

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    # Server first, then its queue.
    # NOTE(review): the queue is not stored here; presumably it registers
    # itself on the server — confirm.
    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)

    # NOTE(review): presumably loads custom node packages (e.g. ReduxAdvanced)
    # into NODE_CLASS_MAPPINGS — confirm.
    init_custom_nodes()
    return NODE_CLASS_MAPPINGS
print("Inicializando ComfyUI...")
# Node-class registry; every pipeline step below looks its node up here.
NODE_CLASS_MAPPINGS = init_comfyui()

# Directory setup: generated images go to an "output" folder next to this
# script (symlinks resolved). The folder is created if missing and then
# registered as ComfyUI's output directory.
BASE_DIR = os.path.split(os.path.realpath(__file__))[0]
output_dir = os.path.join(BASE_DIR, "output")
os.makedirs(output_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)
# Helper: indexed access with a {"result": ...} wrapper fallback.
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    Direct indexing is attempted first. Only a ``KeyError`` (a mapping without
    that key) triggers the fallback. Any other exception, such as
    ``IndexError`` on a short sequence, propagates unchanged, as does a
    ``KeyError`` from a mapping that has no ``"result"`` entry.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
# --- Model files -----------------------------------------------------------
# Checkpoint filenames shared by download_models() and the loader nodes below,
# so the download target and the name each loader asks for cannot drift apart.
REDUX_MODEL = "flux1-redux-dev.safetensors"
T5_ENCODER = "t5xxl_fp16.safetensors"
CLIP_L_ENCODER = "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors"
VAE_MODEL = "ae.safetensors"
FLUX_MODEL = "flux1-dev.safetensors"
CLIP_VISION_MODEL = "model.safetensors"
# NOTE(review): this LoRA is not fetched by download_models(); presumably
# ComfyUI resolves the name relative to its LoRA folder — confirm the file
# really exists at that location.
LORA_MODEL = "models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors"

# (Hugging Face repo, filename, sub-folder relative to the ComfyUI checkout)
MODEL_FILES = [
    ("black-forest-labs/FLUX.1-Redux-dev", REDUX_MODEL, "models/style_models"),
    ("comfyanonymous/flux_text_encoders", T5_ENCODER, "models/text_encoders"),
    ("zer0int/CLIP-GmP-ViT-L-14", CLIP_L_ENCODER, "models/text_encoders"),
    ("black-forest-labs/FLUX.1-dev", VAE_MODEL, "models/vae"),
    ("black-forest-labs/FLUX.1-dev", FLUX_MODEL, "models/diffusion_models"),
    ("google/siglip-so400m-patch14-384", CLIP_VISION_MODEL, "models/clip_vision"),
]


def download_models():
    """Download every file listed in MODEL_FILES from the Hugging Face Hub.

    Each file is written to ``<comfyui_path>/<sub-folder>``, inside the
    bundled ComfyUI checkout's own ``models/`` tree, rather than relative to
    the process's working directory. The loaders below then request files by
    bare filename (assumes ComfyUI's default folder layout — confirm if an
    extra_model_paths config is in use).

    Best effort: a failed download is logged and the remaining files are
    still attempted. A missing file then fails when its loader runs.
    """
    print("Baixando modelos...")
    for repo_id, filename, subfolder in MODEL_FILES:
        local_dir = os.path.join(comfyui_path, subfolder)
        try:
            os.makedirs(local_dir, exist_ok=True)
            print(f"Baixando {filename} de {repo_id}...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        except Exception as e:
            # Keep going so one unavailable repo does not block the others.
            print(f"Erro ao baixar {filename} de {repo_id}: {str(e)}")


# Download models at startup.
download_models()

# --- Load models once at startup -------------------------------------------
# Every loader returns a tuple; element [0] is the loaded object.
print("Inicializando modelos...")
with torch.inference_mode():
    # Text encoders (T5-XXL + CLIP-L) loaded as a "flux" dual CLIP.
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    dualcliploader_357 = dualcliploader.load_clip(
        clip_name1=T5_ENCODER,
        clip_name2=CLIP_L_ENCODER,
        type="flux",
    )
    # Redux style model.
    stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
    stylemodelloader_441 = stylemodelloader.load_style_model(
        style_model_name=REDUX_MODEL
    )
    # VAE
    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vaeloader_359 = vaeloader.load_vae(vae_name=VAE_MODEL)
    # Flux diffusion model: the model the LoRA is applied to and KSampler runs.
    unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
    flux_unet = unetloader.load_unet(
        unet_name=FLUX_MODEL,
        weight_dtype="default",
    )
    # Vision encoder for the reference image consumed by ReduxAdvanced.
    clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
    clip_vision_model = clipvisionloader.load_clip(clip_name=CLIP_VISION_MODEL)


def _first_image_to_pil(images):
    """Return the first image of a decoded image batch as a PIL image.

    Assumes ``images`` is the float tensor VAEDecode returns, shaped
    (batch, height, width, channels) with values nominally in [0, 1] — TODO
    confirm. Values are clamped before scaling so out-of-range pixels cannot
    wrap around when cast to uint8.
    """
    first = images[0].cpu().clamp(0, 1) * 255
    return Image.fromarray(first.to(torch.uint8).numpy())


@spaces.GPU
def generate_image(
    prompt,
    input_image,
    lora_weight,
    guidance,
    downsampling_factor,
    weight,
    seed,
    width,
    height,
    batch_size,
    steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the Flux + Redux pipeline on one reference image and save the result.

    Args:
        prompt: text prompt, encoded with the dual CLIP text encoders.
        input_image: filesystem path of the reference image (the UI's
            ``gr.Image`` uses ``type="filepath"``).
        lora_weight: LoRA strength applied to the Flux diffusion model.
        guidance: value passed to FluxGuidance on the text conditioning.
        downsampling_factor: passed to ReduxAdvanced as ``downsampling_factor``.
        weight: passed to ReduxAdvanced as ``weight``.
        seed, steps: KSampler seed and step count.
        width, height, batch_size: size of the empty latent that is sampled.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

    Returns:
        Path of the saved PNG inside ``output_dir``. Only the first image of
        the batch is saved.

    Raises:
        gr.Error: if no image was given or any pipeline step fails, so the
            message appears in the UI instead of an empty output.
    """
    import uuid

    if input_image is None:
        raise gr.Error("Please upload an input image.")
    try:
        with torch.inference_mode():
            # Encode the prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=dualcliploader_357[0],
            )
            # Load the reference image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            # Flux guidance on the text conditioning
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0],
            )
            # LoRA on the Flux diffusion model (not on the Redux style model)
            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            lora_model = loraloadermodelonly.load_lora_model_only(
                lora_name=LORA_MODEL,
                strength_model=lora_weight,
                model=flux_unet[0],
            )
            # Redux Advanced: condition on the reference image via the style model
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=stylemodelloader_441[0],
                # NOTE(review): assumes ReduxAdvanced takes a ``clip_vision``
                # input to encode ``image`` — confirm against the node pack.
                clip_vision=clip_vision_model[0],
                image=loaded_image[0],
            )
            # Empty latent of the requested size
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size,
            )
            # Sampling. cfg=1 presumably because guidance is applied via
            # FluxGuidance above.
            # NOTE(review): the guided prompt conditioning doubles as the
            # negative; confirm this is intended.
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=lora_model[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0],
            )
            # VAE decode, then convert the tensor batch to a PIL image.
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=vaeloader_359[0],
            )
            image = _first_image_to_pil(decoded[0])

        # Unique filename so concurrent requests never overwrite each other.
        temp_path = os.path.join(output_dir, f"Flux_{uuid.uuid4().hex}.png")
        image.save(temp_path)
        return temp_path
    except Exception as e:
        print(f"Erro ao gerar imagem: {str(e)}")
        raise gr.Error(f"Image generation failed: {e}") from e
# Gradio interface
# Upper bound for the random default seed. gr.Number values pass through a
# JavaScript number in the browser, which holds integers exactly only up to
# 2**53 - 1. Larger seeds get rounded, possibly past a 64-bit unsigned range.
MAX_SEED = 2**53 - 1

with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            # type="filepath": generate_image receives a path string.
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )
            with gr.Row():
                # Model / conditioning controls.
                with gr.Column():
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight",
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance",
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor",
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight",
                    )
                # Sampling / size controls (precision=0 yields integers).
                with gr.Column():
                    # The default seed is drawn once, when the UI is built.
                    seed = gr.Number(
                        value=random.randint(1, MAX_SEED),
                        label="Seed",
                        precision=0,
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0,
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0,
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0,
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0,
                    )
            generate_btn = gr.Button("Generate Image")
        # Right column: result.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps,
        ],
        outputs=[output_image],
    )

if __name__ == "__main__":
    app.launch()