Redux / app.py
nftnik's picture
Update app.py
a7d1628 verified
raw
history blame
8.93 kB
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
from folder_paths import folder_paths
# Append ComfyUI path to sys.path: the "ComfyUI" directory located next to this
# file is added to the end of the module search path.
# NOTE(review): `from folder_paths import folder_paths` above executes before
# this append -- confirm that import resolves without the ComfyUI path.
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_path = os.path.join(current_dir, "ComfyUI")
sys.path.append(comfyui_path)
# Import Comfy modules (presumably resolved through the path appended above)
from comfy import model_management
from nodes import NODE_CLASS_MAPPINGS
# Helper function to get values from objects
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the element of ``obj`` stored at ``index``.

    ``obj[index]`` is tried first. When that lookup raises ``KeyError`` the
    value is read from ``obj["result"][index]`` instead.

    Args:
        obj: Sequence or mapping to read from.
        index: Position (or key) to look up.

    Returns:
        The value found at ``index``.
    """
    try:
        value = obj[index]
    except KeyError:
        # Direct lookup missed: fall back to the nested "result" entry.
        value = obj["result"][index]
    return value
# Download models from Hugging Face
def download_models():
    """Fetch every required model file from the Hugging Face Hub.

    Each entry is downloaded into ``<models_dir>/<model type>`` and that
    directory is then registered with ``folder_paths`` under the same model
    type. A failure for one file is printed and the remaining files are still
    attempted.
    """
    # (repo id, file name inside the repo, model type / target sub-directory)
    required_files = (
        ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
        ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "text_encoders"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "diffusion_models"),
        ("google/siglip-so400m-patch14-384", "model.safetensors", "clip_vision"),
    )

    for repo_id, filename, kind in required_files:
        try:
            target_dir = os.path.join(models_dir, kind)
            os.makedirs(target_dir, exist_ok=True)
            print(f"Baixando {filename} de {repo_id}...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
            folder_paths.add_model_folder_path(kind, target_dir)
        except Exception as e:
            # Best effort: report the problem and move on to the next file.
            print(f"Erro ao baixar {filename} de {repo_id}: {str(e)}")
# Directory configuration: outputs and downloaded models live next to this file.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
# Create both directories up front (no error if they already exist).
for _required_dir in (output_dir, models_dir):
    os.makedirs(_required_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# Download the model files; this must run after `models_dir` is defined because
# download_models() reads it.
download_models()
# Load models globally so generate_image() can reuse them on every call.
# NOTE(review): `intconstant` is not referenced anywhere else in this file.
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()

# Text encoders (T5-XXL + CLIP-L) loaded as a "flux" dual CLIP.
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
dualcliploader_357 = dualcliploader.load_clip(
    clip_name1="t5xxl_fp16.safetensors",
    clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
    type="flux",
)

# CLIP vision model.
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
clip_vision = clipvisionloader.load_clip(clip_name="model.safetensors")

# Redux style model.
stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
stylemodelloader_441 = stylemodelloader.load_style_model(style_model_name="flux1-redux-dev.safetensors")

# VAE.
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
vaeloader_359 = vaeloader.load_vae(vae_name="ae.safetensors")

# NOTE(review): download_models() also fetches "flux1-dev.safetensors" into
# "diffusion_models", but nothing in this section loads it -- confirm whether a
# diffusion-model loader is missing here.


def _preloadable(loader_outputs):
    """Collect the objects to pre-load from a list of loader results.

    For every result the first element is inspected: its ``patcher`` attribute
    is used when it has one, otherwise the element itself. Results whose first
    element, or whose ``patcher``, is a dict are skipped.
    """
    selected = []
    for output in loader_outputs:
        primary = output[0]
        if isinstance(primary, dict):
            continue
        if isinstance(getattr(primary, "patcher", None), dict):
            continue
        selected.append(getattr(primary, "patcher", primary))
    return selected


# Pre-load models
model_loaders = [dualcliploader_357, vaeloader_359, clip_vision, stylemodelloader_441]
valid_models = _preloadable(model_loaders)
model_management.load_models_gpu(valid_models)
# Function to generate images
@spaces.GPU(duration=60)  # Adjust duration as needed
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
    """Run the FLUX Redux node pipeline once and save the result as a PNG.

    The pipeline is: CLIPTextEncode -> LoadImage -> FluxGuidance ->
    ReduxAdvanced -> EmptyLatentImage -> KSampler -> VAEDecode -> save.

    Args:
        prompt: Text passed to the CLIP text encoder.
        input_image: Reference image handed to the ``LoadImage`` node (the UI
            supplies a file path).
        lora_weight: Accepted so the UI wiring keeps working; currently not
            used by the pipeline.
        guidance: Guidance value passed to ``FluxGuidance``.
        downsampling_factor: Passed to ``ReduxAdvanced``.
        weight: Style weight passed to ``ReduxAdvanced``.
        seed: Seed passed to ``KSampler``.
        width: Latent width passed to ``EmptyLatentImage``.
        height: Latent height passed to ``EmptyLatentImage``.
        batch_size: Batch size passed to ``EmptyLatentImage``. Only the first
            decoded image is saved.
        steps: Number of sampling steps passed to ``KSampler``.

    Returns:
        Path of the saved PNG inside ``output_dir``, or ``None`` when any step
        raised (the error is printed).
    """
    try:
        with torch.inference_mode():
            # Encode the text prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=dualcliploader_357[0]
            )

            # Load and process the reference image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)

            # Flux Guidance
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )

            # Redux Advanced
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=stylemodelloader_441[0],
                clip_vision=clip_vision[0],
                image=loaded_image[0]
            )

            # Empty Latent
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size
            )

            # KSampler
            # NOTE(review): `model` receives the Redux style model. KSampler
            # presumably needs the FLUX diffusion model ("flux1-dev.safetensors"
            # is downloaded but never loaded in this file) -- confirm and pass
            # a diffusion-model loader output here instead.
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=stylemodelloader_441[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )

            # Decode with the VAE
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=vaeloader_359[0]
            )

            # Save the image
            temp_filename = f"Flux_{random.randint(0, 99999)}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            # The decoded images are a torch tensor, which has no `.astype`
            # (that is a NumPy method), so the conversion to 8-bit pixels is
            # done with torch. A 4-D result is a batch; PIL needs one
            # height x width x channel array, so the first image is used.
            images = decoded[0]
            first_image = images[0] if images.dim() == 4 else images
            pixels = (first_image * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
            Image.fromarray(pixels).save(temp_path)
            return temp_path
    except Exception as e:
        # UI boundary: report the failure and show no image instead of crashing.
        print(f"Erro ao gerar imagem: {str(e)}")
        return None
# Gradio Interface
# Left column: prompt, reference image, generation settings and the button.
# Right column: the generated image.
# NOTE(review): the nesting of the `with` blocks below is a reconstruction --
# confirm it matches the intended layout.
with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )
            with gr.Row():
                with gr.Column():
                    # NOTE(review): this value is passed to generate_image(),
                    # which does not use its `lora_weight` argument.
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight"
                    )
                with gr.Column():
                    # random.randint() includes its upper bound, so the former
                    # limit of 2**64 could yield a seed one past the largest
                    # value torch.manual_seed() accepts (2**64 - 1). 2**53 - 1
                    # also keeps the number exactly representable as a float64,
                    # so the seed displayed is the seed that gets used.
                    seed = gr.Number(
                        value=random.randint(1, 2**53 - 1),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    # The order of `inputs` must match generate_image()'s positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps],
        outputs=[output_image]
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()