import os
import random
import sys
import torch
import gradio as gr
from pathlib import Path
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any

# 1. Path setup (must run before any ComfyUI imports)
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_path = os.path.join(current_dir, "ComfyUI")
sys.path.append(comfyui_path)

# 2. ComfyUI imports (these require ComfyUI on sys.path)
import folder_paths
from comfy import model_management
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
# 3. Directory setup
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)
# 4. CUDA diagnostics
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current GPU:", torch.cuda.get_device_name(0))
# 5. ComfyUI initialization
print("Initializing ComfyUI...")
init_extra_nodes()
# 6. Helper functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Index into a node output, falling back to its "result" mapping."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
def find_path(name: str, path: str = None) -> str:
    """Walk up the directory tree from `path` (default: cwd) looking for `name`."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")
# 7. Path initialization
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
    """Start a minimal PromptServer/PromptQueue so custom nodes can register."""
    import asyncio
    import execution
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()
# 8. Model downloads
def download_models():
    print("Downloading models...")
    models = [
        ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
        ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "text_encoders"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "diffusion_models"),
        ("google/siglip-so400m-patch14-384", "model.safetensors", "clip_vision"),
    ]

    for repo_id, filename, model_type in models:
        try:
            model_dir = os.path.join(models_dir, model_type)
            os.makedirs(model_dir, exist_ok=True)
            print(f"Downloading {filename} from {repo_id}...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir)
            # Register the directory with folder_paths so the loaders can find it
            folder_paths.add_model_folder_path(model_type, model_dir)
        except Exception as e:
            print(f"Error downloading {filename} from {repo_id}: {str(e)}")
            continue
# 9. Model download and initialization
download_models()
print("Initializing models...")
import_custom_nodes()
# Global variables for preloaded models and constants
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)
# Load CLIP
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
CLIP_MODEL = dualcliploader.load_clip(
    clip_name1="t5xxl_fp16.safetensors",
    clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
    type="flux",
)

# Load VAE
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
VAE_MODEL = vaeloader.load_vae(vae_name="ae.safetensors")

# Load CLIP Vision
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
CLIP_VISION_MODEL = clipvisionloader.load_clip(clip_name="model.safetensors")

# Load Style Model
stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
STYLE_MODEL = stylemodelloader.load_style_model(style_model_name="flux1-redux-dev.safetensors")
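
# Load UNet. generate_image() below references a UNet model that is never
# loaded in this listing; the sketch below fills that gap with ComfyUI's
# UNETLoader, pointing at the flux1-dev.safetensors file downloaded above.
# weight_dtype="default" is an assumption.
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
UNET_MODEL = unetloader.load_unet(
    unet_name="flux1-dev.safetensors",
    weight_dtype="default",
)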
# Initialize samplers
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

# Initialize other nodes
cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
# Preload the big models onto the GPU once at startup
model_loaders = [CLIP_MODEL, VAE_MODEL, CLIP_VISION_MODEL, STYLE_MODEL]
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], "patcher") else loader[0]
    for loader in model_loaders
])
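
# ZeroGPU Spaces only get a GPU inside functions decorated with @spaces.GPU;
# `import spaces` above is otherwise unused, so the decorator was presumably
# lost from this listing. Adding it here is an assumption.
@spaces.GPU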
def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor,
                   weight, seed, width, height, batch_size, steps,
                   progress=gr.Progress(track_tqdm=True)) -> str:
    with torch.inference_mode():
        # Set up the filename prefix via a CR Text node
        clip_switch = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")

        # Encode text
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )

        # Process input image
        loaded_image = loadimage.load_image(image=input_image)

        # Get image size
        size_info = getimagesizeandcount.getsize(
            image=get_value_at_index(loaded_image, 0)
        )

        # Encode with the VAE
        vae_encoded = vaeencode.encode(
            pixels=get_value_at_index(size_info, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Apply Flux guidance
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )

        # Set up empty latent
        empty_latent = emptylatentimage.generate(
            width=width,
            height=height,
            batch_size=batch_size,
        )

        # Set up guidance (the garbled original passed the loaded image here;
        # BasicGuider expects conditioning, i.e. the Flux-guided text encoding)
        guided = basicguider.get_guider(
            model=get_value_at_index(UNET_MODEL, 0),
            conditioning=get_value_at_index(flux_guided, 0),
        )

        # Set up scheduler
        schedule = basicscheduler.get_sigmas(
            scheduler="simple",
            steps=steps,
            denoise=1,
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Generate random noise
        noise = randomnoise.get_noise(noise_seed=seed)

        # Sample
        sampled = samplercustomadvanced.sample(
            noise=get_value_at_index(noise, 0),
            guider=get_value_at_index(guided, 0),
            sampler=get_value_at_index(SAMPLER, 0),
            sigmas=get_value_at_index(schedule, 0),
            latent_image=get_value_at_index(empty_latent, 0),
        )

        # Decode with the VAE
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Save image
        saved = saveimage.save_images(
            filename_prefix=get_value_at_index(clip_switch, 0),
            images=get_value_at_index(decoded, 0),
        )
        saved_path = f"output/{saved['ui']['images'][0]['filename']}"
        return saved_path
# Create Gradio interface
examples = [
    ["", "mona.png", 0.5, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
    ["a woman looking at a house catching fire on the background", "disaster Girl.png", 0.6, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
    ["Istanbul aerial, dramatic photography", "Natasha.png", 0.5, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
]
output_image = gr.Image(label="Generated image")

with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            with gr.Row():
                with gr.Column():
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight",
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance",
                    )
                    downsampling_factor = gr.Slider(
                        minimum=0,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor",
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight",
                    )
                    seed = gr.Number(
                        value=random.randint(1, 2**64),
                        label="Seed",
                        precision=0,
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0,
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0,
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0,
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0,
                    )
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image.render()
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps,
        ],
        outputs=[output_image],
    )
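
    # `examples` above is defined but never attached to the UI in this
    # listing; wiring it up with gr.Examples is an assumption about the
    # original layout.
    gr.Examples(
        examples=examples,
        inputs=[
            prompt_input, input_image, lora_weight, guidance,
            downsampling_factor, weight, seed, width, height,
            batch_size, steps,
        ],
    )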
if __name__ == "__main__":
    app.launch(share=True)