# ComfyUI custom nodes for TangoFlux text-to-audio: model loader, sampler,
# and VAE decode/playback.
import os
import logging
import json
import random
import torch
import torchaudio
import re
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from huggingface_hub import snapshot_download
from comfy.utils import load_torch_file, ProgressBar
import folder_paths
from tangoflux.model import TangoFlux
from .teacache import teacache_forward
log = logging.getLogger("TangoFlux")
# Default location for TangoFlux weights under ComfyUI's models directory.
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
# Register a "tangoflux" folder type with ComfyUI so the weights are
# discoverable; keep any paths a user has already registered for it.
if "tangoflux" not in folder_paths.folder_names_and_paths:
    current_paths = [TANGOFLUX_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
folder_paths.folder_names_and_paths["tangoflux"] = (
    current_paths,
    folder_paths.supported_pt_extensions,
)
# The text encoder (e.g. flan-t5) is cached under models/text_encoders.
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
class TangoFluxLoader:
    """ComfyUI node that downloads (if needed) and loads the TangoFlux
    model and its Oobleck VAE.

    The loaded model/VAE are cached on the node instance; a full reload
    only happens on first use or when the TeaCache toggle changes.
    TeaCache is applied by monkey-patching FluxTransformer2DModel.forward
    at the *class* level, which affects every instance process-wide.
    """
    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input schema: TeaCache toggle plus its relative-L1 threshold.
        return {
            "required": {
                "enable_teacache": ("BOOLEAN", {"default": False}),
                "rel_l1_thresh": (
                    "FLOAT",
                    {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
                ),
            },
        }
    RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
    RETURN_NAMES = ("model", "vae")
    OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")
    CATEGORY = "TangoFlux"
    FUNCTION = "load_tangoflux"
    DESCRIPTION = "Load TangoFlux model"
    def __init__(self):
        # Cached model/VAE plus the settings they were loaded with.
        self.model = None
        self.vae = None
        self.enable_teacache = False
        self.rel_l1_thresh = 0.25
        # Keep the unpatched forward so TeaCache can be switched off again.
        self.original_forward = FluxTransformer2DModel.forward
    def load_tangoflux(
        self,
        enable_teacache=False,
        rel_l1_thresh=0.25,
        tangoflux_path=TANGOFLUX_DIR,
        text_encoder_path=TEXT_ENCODER_DIR,
        device="cuda",
    ):
        """Return (model, vae), downloading/loading them on first use.

        A full reload happens when no model is cached or enable_teacache
        flipped; a threshold-only change just updates the class attribute
        the TeaCache forward reads.
        """
        if self.model is None or self.enable_teacache != enable_teacache:
            pbar = ProgressBar(6)
            # Fetch TangoFlux config + weights from the HF Hub (cached locally).
            snapshot_download(
                repo_id="declare-lab/TangoFlux",
                allow_patterns=["*.json", "*.safetensors"],
                local_dir=tangoflux_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)
            log.info("Loading config")
            with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
                config = json.load(f)
            pbar.update(1)
            # Sanitize the repo id (e.g. "google/flan-t5-large") into a
            # filesystem-safe directory name for the local encoder copy.
            text_encoder = re.sub(
                r'[<>:"/\\|?*]',
                "-",
                config.get("text_encoder_name", "google/flan-t5-large"),
            )
            text_encoder_path = os.path.join(text_encoder_path, text_encoder)
            snapshot_download(
                repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                local_dir=text_encoder_path,
                local_dir_use_symlinks=False,
            )
            pbar.update(1)
            log.info("Loading TangoFlux models")
            # Drop any previously cached model before loading new weights.
            del self.model
            self.model = None
            model_weights = load_torch_file(
                os.path.join(tangoflux_path, "tangoflux.safetensors"),
                device=torch.device(device),
            )
            pbar.update(1)
            # Swap the transformer class's forward globally to (de)activate
            # TeaCache before the model is constructed.
            if enable_teacache:
                log.info("Enabling TeaCache")
                FluxTransformer2DModel.forward = teacache_forward
            else:
                log.info("Disabling TeaCache")
                FluxTransformer2DModel.forward = self.original_forward
            model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
            model.load_state_dict(model_weights, strict=False)
            model.to(device)
            if enable_teacache:
                # TeaCache bookkeeping lives on the transformer *class*
                # (shared by all instances) — this is what teacache_forward
                # reads and mutates during sampling.
                model.transformer.__class__.enable_teacache = True
                model.transformer.__class__.cnt = 0
                model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
                model.transformer.__class__.accumulated_rel_l1_distance = 0
                model.transformer.__class__.previous_modulated_input = None
                model.transformer.__class__.previous_residual = None
            pbar.update(1)
            self.model = model
            del model
            # Remember which settings the cached model was built with.
            self.enable_teacache = enable_teacache
            self.rel_l1_thresh = rel_l1_thresh
            if self.vae is None:
                log.info("Loading TangoFlux VAE")
                vae_weights = load_torch_file(
                    os.path.join(tangoflux_path, "vae.safetensors")
                )
                self.vae = AutoencoderOobleck()
                self.vae.load_state_dict(vae_weights)
                self.vae.to(device)
            pbar.update(1)
        if self.enable_teacache == True and self.rel_l1_thresh != rel_l1_thresh:
            # Threshold changed while TeaCache stayed on: update in place
            # without reloading the model.
            self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
            self.rel_l1_thresh = rel_l1_thresh
        return (self.model, self.vae)
class TangoFluxSampler:
    """ComfyUI node that generates TangoFlux audio latents from a text prompt."""
    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input schema for the sampling parameters.
        return {
            "required": {
                "model": ("TANGOFLUX_MODEL",),
                "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
                "guidance_scale": (
                    "FLOAT",
                    {"default": 3, "min": 1, "max": 100, "step": 1},
                ),
                "duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
        }
    RETURN_TYPES = ("TANGOFLUX_LATENTS",)
    RETURN_NAMES = ("latents",)
    OUTPUT_TOOLTIPS = "TangoFlux Sample"
    CATEGORY = "TangoFlux"
    FUNCTION = "sample"
    DESCRIPTION = "Sampler for TangoFlux"
    def sample(
        self,
        model,
        prompt,
        steps=50,
        guidance_scale=3,
        duration=10,
        seed=0,
        batch_size=1,
        device="cuda",
    ):
        """Run flow-matching inference with the given prompt and settings.

        Returns a single-element tuple holding a dict with the raw latents
        and the requested duration, for downstream decode nodes.
        """
        pbar = ProgressBar(steps)
        with torch.no_grad():
            model.to(device)
            # TeaCache needs to know the total step count for its skip
            # schedule. The attribute only exists when the loader enabled
            # TeaCache, so probe with getattr instead of the previous bare
            # except (which would also have hidden unrelated errors).
            if getattr(model.transformer.__class__, "enable_teacache", False):
                model.transformer.__class__.num_steps = steps
            log.info("Generating latents with TangoFlux")
            latents = model.inference_flow(
                prompt,
                duration=duration,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                seed=seed,
                num_samples_per_prompt=batch_size,
                callback_on_step_end=lambda: pbar.update(1),
            )
        return ({"latents": latents, "duration": duration},)
class TangoFluxVAEDecodeAndPlay:
    """ComfyUI output node: decode TangoFlux latents with the Oobleck VAE
    and save one audio file per batch item for playback in the UI."""
    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input schema: VAE + latents plus output-file options.
        return {
            "required": {
                "vae": ("TANGOFLUX_VAE",),
                "latents": ("TANGOFLUX_LATENTS",),
                "filename_prefix": ("STRING", {"default": "TangoFlux"}),
                "format": (
                    ["wav", "mp3", "flac", "aac", "wma"],
                    {"default": "wav"},
                ),
                "save_output": ("BOOLEAN", {"default": True}),
            },
        }
    RETURN_TYPES = ()
    OUTPUT_NODE = True
    CATEGORY = "TangoFlux"
    FUNCTION = "play"
    DESCRIPTION = "Decoder and Player for TangoFlux"
    def decode(self, vae, latents):
        """Decode a batch of latents to waveforms.

        Latents are decoded one at a time (bounding peak memory) and the
        resulting waveforms are concatenated on the batch dimension.
        """
        decoded = [
            vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
            for latent in latents
        ]
        return torch.cat(decoded, dim=0)
    def play(
        self,
        vae,
        latents,
        filename_prefix="TangoFlux",
        format="wav",
        save_output=True,
        device="cuda",
    ):
        """Decode `latents` and write audio files; return a UI result dict.

        latents is the dict produced by TangoFluxSampler:
        {"latents": tensor batch, "duration": seconds}.
        """
        audios = []
        # NOTE(review): latents is still the wrapper dict here, so this
        # sizes the bar as len(dict)+2; kept as-is (progress-only effect).
        pbar = ProgressBar(len(latents) + 2)
        if save_output:
            output_dir = folder_paths.get_output_directory()
            prefix_append = ""
            output_type = "output"
        else:
            # Temp previews get a random suffix so repeated runs don't collide.
            # (Fixed the garbled alphabet: it duplicated 'p' and dropped 'u'..'w'.)
            output_dir = folder_paths.get_temp_directory()
            prefix_append = "_temp_" + "".join(
                random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)
            )
            output_type = "temp"
        filename_prefix += prefix_append
        full_output_folder, filename, counter, subfolder, _ = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        os.makedirs(full_output_folder, exist_ok=True)
        pbar.update(1)
        duration = latents["duration"]
        latents = latents["latents"]
        vae.to(device)
        log.info("Decoding Tangoflux latents")
        waves = self.decode(vae, latents)
        pbar.update(1)
        # Use the VAE's sampling rate both for trimming and for saving so the
        # written file plays at the rate the audio was generated at (the save
        # call previously hard-coded 44100).
        sampling_rate = vae.config.sampling_rate
        waveform_end = int(duration * sampling_rate)
        for wave in waves:
            # Trim decoded audio to the requested duration.
            wave = wave[:, :waveform_end]
            # Fix: name files from the prefix returned by get_save_image_path
            # (was a hard-coded placeholder, so every file got the same name).
            file = f"{filename}_{counter:05}_.{format}"
            torchaudio.save(
                os.path.join(full_output_folder, file),
                wave,
                sample_rate=sampling_rate,
            )
            counter += 1
            audios.append(
                {"filename": file, "subfolder": subfolder, "type": output_type}
            )
        pbar.update(1)
        return {
            "ui": {"audios": audios},
        }
# Registry consumed by ComfyUI to expose these nodes in the graph editor.
NODE_CLASS_MAPPINGS = {
    "TangoFluxLoader": TangoFluxLoader,
    "TangoFluxSampler": TangoFluxSampler,
    "TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
}