nftnik commited on
Commit
97e7f7b
·
verified ·
1 Parent(s): ebf6853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -75
app.py CHANGED
@@ -1,86 +1,92 @@
1
  import os
2
  import sys
3
-
4
- # Adicionar o caminho da pasta ComfyUI ao sys.path
5
- current_dir = os.path.dirname(os.path.abspath(__file__))
6
- comfyui_path = os.path.join(current_dir, "ComfyUI")
7
- sys.path.append(comfyui_path)
8
-
9
  import random
10
  import torch
11
  from pathlib import Path
12
  from PIL import Image
13
  import gradio as gr
14
  from huggingface_hub import hf_hub_download
15
- from nodes import NODE_CLASS_MAPPINGS
16
- from comfy import model_management
17
- import folder_paths
18
 
 
 
 
19
  print("CUDA disponível:", torch.cuda.is_available())
20
  print("Quantidade de GPUs:", torch.cuda.device_count())
21
  if torch.cuda.is_available():
22
  print("GPU atual:", torch.cuda.get_device_name(0))
 
 
 
 
 
 
23
 
24
- # Diretório base e de saída
 
 
 
 
 
 
 
 
 
25
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
26
  output_dir = os.path.join(BASE_DIR, "output")
27
  os.makedirs(output_dir, exist_ok=True)
28
  folder_paths.set_output_directory(output_dir)
29
 
30
- # Baixar e carregar os modelos necessários
31
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev",
32
- filename="flux1-redux-dev.safetensors",
33
- local_dir="models/style_models")
34
-
35
- hf_hub_download(repo_id="comfyanonymous/flux_text_encoders",
36
- filename="t5xxl_fp16.safetensors",
37
- local_dir="models/text_encoders")
38
-
39
- hf_hub_download(repo_id="zer0int/CLIP-GmP-ViT-L-14",
40
- filename="ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
41
- local_dir="models/text_encoders")
42
-
43
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
44
- filename="ae.safetensors",
45
- local_dir="models/vae")
46
-
47
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
48
- filename="flux1-dev.safetensors.safetensors",
49
- local_dir="models/diffusion_models")
50
-
51
- hf_hub_download(repo_id="google/siglip-so400m-patch14-384",
52
- filename="model.safetensors",
53
- local_dir="models/clip_vision")
54
-
55
- hf_hub_download(repo_id="nftnik/NFTNIK-FLUX.1-dev-LoRA",
56
- filename="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
57
- local_dir="models/lora")
58
-
59
- # Inicializar os nós e pré-carregar os modelos
60
- intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
61
- dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
62
- dualcliploader_357 = dualcliploader.load_clip(
63
- clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
64
- clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
65
- type="flux",
66
- )
67
- stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
68
- stylemodelloader_441 = stylemodelloader.load_style_model(
69
- style_model_name="models/style_models/flux1-redux-dev.safetensors"
70
- )
71
- vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
72
- vaeloader_359 = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")
73
-
74
- # Lista de modelos para carregamento na GPU
75
- model_loaders = [dualcliploader_357, vaeloader_359, stylemodelloader_441]
76
- valid_models = [
77
- getattr(loader[0], 'patcher', loader[0])
78
- for loader in model_loaders
79
- if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
80
- ]
81
- model_management.load_models_gpu(valid_models)
82
-
83
- # Função para importar nodes personalizados
84
  def import_custom_nodes():
85
  import asyncio
86
  import execution
@@ -89,13 +95,13 @@ def import_custom_nodes():
89
 
90
  loop = asyncio.new_event_loop()
91
  asyncio.set_event_loop(loop)
92
-
93
  server_instance = server.PromptServer(loop)
94
  execution.PromptQueue(server_instance)
95
  init_extra_nodes()
96
 
97
- # Função principal de geração
98
- def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
 
99
  import_custom_nodes()
100
  try:
101
  with torch.inference_mode():
@@ -103,7 +109,7 @@ def generate_image(prompt, input_image, lora_weight, guidance, downsampling_fact
103
  cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
104
  encoded_text = cliptextencode.encode(
105
  text=prompt,
106
- clip=dualcliploader_357[0]
107
  )
108
 
109
  # Carregar LoRA
@@ -111,23 +117,24 @@ def generate_image(prompt, input_image, lora_weight, guidance, downsampling_fact
111
  lora_model = loraloadermodelonly.load_lora_model_only(
112
  lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
113
  strength_model=lora_weight,
114
- model=stylemodelloader_441[0]
115
  )
116
 
117
- # Processar imagem de entrada
118
  loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
119
  loaded_image = loadimage.load_image(image=input_image)
120
 
121
- # Decodificar e salvar
122
  vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
123
  decoded = vaedecode.decode(
124
- samples=lora_model[0],
125
- vae=vaeloader_359[0]
126
  )
127
 
 
128
  temp_filename = f"Flux_{random.randint(0, 99999)}.png"
129
  temp_path = os.path.join(output_dir, temp_filename)
130
- Image.fromarray((decoded[0] * 255).astype("uint8")).save(temp_path)
131
 
132
  return temp_path
133
  except Exception as e:
@@ -154,4 +161,7 @@ with gr.Blocks() as app:
154
  )
155
 
156
  if __name__ == "__main__":
157
- app.launch()
 
 
 
 
1
  import os
2
  import sys
 
 
 
 
 
 
3
  import random
4
  import torch
5
  from pathlib import Path
6
  from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
+ import spaces
10
+ from typing import Union, Sequence, Mapping, Any
 
11
 
12
+ # Configuração inicial e diagnóstico CUDA
13
+ print("Python version:", sys.version)
14
+ print("Torch version:", torch.__version__)
15
  print("CUDA disponível:", torch.cuda.is_available())
16
  print("Quantidade de GPUs:", torch.cuda.device_count())
17
  if torch.cuda.is_available():
18
  print("GPU atual:", torch.cuda.get_device_name(0))
19
+ else:
20
+ print("CUDA não está disponível. Verificando por que:")
21
+ try:
22
+ torch.cuda.init()
23
+ except Exception as e:
24
+ print("Erro ao inicializar CUDA:", str(e))
25
 
26
+ # Adicionar o caminho da pasta ComfyUI ao sys.path
27
+ current_dir = os.path.dirname(os.path.abspath(__file__))
28
+ comfyui_path = os.path.join(current_dir, "ComfyUI")
29
+ sys.path.append(comfyui_path)
30
+
31
+ from nodes import NODE_CLASS_MAPPINGS
32
+ from comfy import model_management
33
+ import folder_paths
34
+
35
+ # Configuração de diretórios
36
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
37
  output_dir = os.path.join(BASE_DIR, "output")
38
  os.makedirs(output_dir, exist_ok=True)
39
  folder_paths.set_output_directory(output_dir)
40
 
41
+ # Helper function
42
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
43
+ try:
44
+ return obj[index]
45
+ except KeyError:
46
+ return obj["result"][index]
47
+
48
+ # Baixar modelos necessários
49
+ def download_models():
50
+ models = [
51
+ ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "models/style_models"),
52
+ ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
53
+ ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "models/text_encoders"),
54
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
55
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors.safetensors", "models/diffusion_models"),
56
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/clip_vision"),
57
+ ("nftnik/NFTNIK-FLUX.1-dev-LoRA", "NFTNIK_FLUX.1[dev]_LoRA.safetensors", "models/lora")
58
+ ]
59
+
60
+ for repo_id, filename, local_dir in models:
61
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
62
+
63
+ # Inicializar modelos
64
+ print("Inicializando modelos...")
65
+ with torch.inference_mode():
66
+ # Initialize nodes
67
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
68
+ dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
69
+ dualcliploader_357 = dualcliploader.load_clip(
70
+ clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
71
+ clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
72
+ type="flux",
73
+ )
74
+ stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
75
+ stylemodelloader_441 = stylemodelloader.load_style_model(
76
+ style_model_name="models/style_models/flux1-redux-dev.safetensors"
77
+ )
78
+ vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
79
+ vaeloader_359 = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")
80
+
81
+ # Carregar modelos na GPU
82
+ model_loaders = [dualcliploader_357, vaeloader_359, stylemodelloader_441]
83
+ valid_models = [
84
+ getattr(loader[0], 'patcher', loader[0])
85
+ for loader in model_loaders
86
+ if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
87
+ ]
88
+ model_management.load_models_gpu(valid_models)
89
+
 
 
 
 
 
90
  def import_custom_nodes():
91
  import asyncio
92
  import execution
 
95
 
96
  loop = asyncio.new_event_loop()
97
  asyncio.set_event_loop(loop)
 
98
  server_instance = server.PromptServer(loop)
99
  execution.PromptQueue(server_instance)
100
  init_extra_nodes()
101
 
102
+ @spaces.GPU
103
+ def generate_image(prompt, input_image, lora_weight, progress=gr.Progress(track_tqdm=True)):
104
+ """Função principal de geração com monitoramento de progresso"""
105
  import_custom_nodes()
106
  try:
107
  with torch.inference_mode():
 
109
  cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
110
  encoded_text = cliptextencode.encode(
111
  text=prompt,
112
+ clip=get_value_at_index(dualcliploader_357, 0)
113
  )
114
 
115
  # Carregar LoRA
 
117
  lora_model = loraloadermodelonly.load_lora_model_only(
118
  lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
119
  strength_model=lora_weight,
120
+ model=get_value_at_index(stylemodelloader_441, 0)
121
  )
122
 
123
+ # Processar imagem
124
  loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
125
  loaded_image = loadimage.load_image(image=input_image)
126
 
127
+ # Decodificar
128
  vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
129
  decoded = vaedecode.decode(
130
+ samples=get_value_at_index(lora_model, 0),
131
+ vae=get_value_at_index(vaeloader_359, 0)
132
  )
133
 
134
+ # Salvar imagem
135
  temp_filename = f"Flux_{random.randint(0, 99999)}.png"
136
  temp_path = os.path.join(output_dir, temp_filename)
137
+ Image.fromarray((get_value_at_index(decoded, 0) * 255).astype("uint8")).save(temp_path)
138
 
139
  return temp_path
140
  except Exception as e:
 
161
  )
162
 
163
  if __name__ == "__main__":
164
+ # Download models at startup
165
+ download_models()
166
+ # Launch the app
167
+ app.launch()