import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any

# Initial setup and CUDA diagnostics
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current GPU:", torch.cuda.get_device_name(0))

# Add the ComfyUI folder to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_path = os.path.join(current_dir, "ComfyUI")
sys.path.append(comfyui_path)
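# (Assumes a ComfyUI checkout is vendored next to this file, e.g. cloned into ./ComfyUI.)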

# Import ComfyUI components
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
from comfy import model_management
import folder_paths

# Directory configuration
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
os.makedirs(output_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# Initialize extra nodes
print("Initializing extra nodes...")
init_extra_nodes()
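# init_extra_nodes() loads ComfyUI's bundled extra node packs and any custom nodes and
# merges them into NODE_CLASS_MAPPINGS (behavior of recent ComfyUI versions; treated as
# an assumption here, since it depends on the vendored ComfyUI revision).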

# Helper function
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

# Download the required model checkpoints
def download_models():
    print("Downloading models...")
    models = [
        ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "models/style_models"),
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
        ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "models/text_encoders"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors.safetensors", "models/diffusion_models"),
        ("google/siglip-so400m-patch14-384", "model.safetensors", "models/clip_vision"),
        ("nftnik/NFTNIK-FLUX.1-dev-LoRA", "NFTNIK_FLUX.1[dev]_LoRA.safetensors", "models/lora")
    ]
    
    for repo_id, filename, local_dir in models:
        os.makedirs(local_dir, exist_ok=True)
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
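    # Note: the FLUX.1 repositories above are gated on the Hugging Face Hub, so
    # hf_hub_download() may require authentication; it picks up a token from the
    # HF_TOKEN environment variable or a cached `huggingface-cli login` (assumption:
    # a token with access to the gated repos has been configured).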

# Download models before initializing
download_models()

# Initialize models
print("Initializing models...")
with torch.inference_mode():
    # Initialize nodes
    intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    dualcliploader_357 = dualcliploader.load_clip(
        clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
        clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
        type="flux",
    )
    stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
    stylemodelloader_441 = stylemodelloader.load_style_model(
        style_model_name="models/style_models/flux1-redux-dev.safetensors"
    )
    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vaeloader_359 = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")

    # Load the models onto the GPU (unwrap to the underlying patcher where available,
    # skipping dict-style loader outputs)
    model_loaders = [dualcliploader_357, vaeloader_359, stylemodelloader_441]
    valid_models = [
        getattr(loader[0], 'patcher', loader[0]) 
        for loader in model_loaders
        if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
    ]
    model_management.load_models_gpu(valid_models)

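# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU allocation for the duration
# of each call to the decorated function; outside Spaces the decorator is expected to be
# a no-op (assumption based on the `spaces` package, not verified here).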
@spaces.GPU
def generate_image(prompt, input_image, lora_weight, progress=gr.Progress(track_tqdm=True)):
    """Função principal de geração com monitoramento de progresso"""
    try:
        with torch.inference_mode():
            # Encode the text prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=get_value_at_index(dualcliploader_357, 0)
            )

            # Load the LoRA
            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            lora_model = loraloadermodelonly.load_lora_model_only(
                lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
                strength_model=lora_weight,
                model=get_value_at_index(stylemodelloader_441, 0)
            )

            # Load the input image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)

            # Decode with the VAE
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=get_value_at_index(lora_model, 0),
                vae=get_value_at_index(vaeloader_359, 0)
            )

            # Save the image
            temp_filename = f"Flux_{random.randint(0, 99999)}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            image_tensor = get_value_at_index(decoded, 0)[0]  # first image in the batch, float values in [0, 1]
            Image.fromarray((image_tensor.cpu().numpy() * 255).clip(0, 255).astype("uint8")).save(temp_path)

            return temp_path
    except Exception as e:
        print(f"Erro ao gerar imagem: {str(e)}")
        return None

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Type your prompt here...", lines=5)
            input_image = gr.Image(label="Input Image", type="filepath")
            lora_weight = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="LoRA Weight")
            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, input_image, lora_weight],
        outputs=[output_image]
    )

if __name__ == "__main__":
    app.launch()
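    # Optional (standard Gradio features, not specific to this app): a temporary public
    # link can be created with app.launch(share=True), and calling app.queue() before
    # launch() lets Gradio queue concurrent generation requests.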