# sachin
# kannada
# f25cff8
import argparse
import codecs
import os
import re
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
import pandas as pd
from tqdm import tqdm
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT
def run_batch_inference(prompt_paths, prompt_texts, texts, languages, categories, model_obj, vocoder, mel_spec_type, remove_silence, speed, output_dir):
    """Synthesize one WAV file per input row and save it under ``output_dir``.

    The first five arguments are parallel sequences: row ``i`` is built from
    the ``i``-th element of each of them (iteration stops at the shortest).

    Rows whose reference audio is not a string pointing at an existing file
    are reported and skipped, so one bad row no longer aborts the whole batch.

    Args:
        prompt_paths: Reference audio file paths, one per row.
        prompt_texts: Reference texts handed to ``preprocess_ref_audio_text``
            together with the matching reference audio.
        texts: Text to synthesize for each row. ``[name]`` voice tags are
            recognised, but only the ``main`` voice exists here, so any tag is
            stripped and the row's reference voice is used.
        languages: Per-row language string; used as the output sub-directory
            name and, upper-cased, as the file-name prefix.
        categories: Per-row category string; upper-cased into the file name.
        model_obj: Model object forwarded to ``infer_process``.
        vocoder: Vocoder object forwarded to ``infer_process``.
        mel_spec_type: Forwarded to ``infer_process`` as ``mel_spec_type``.
        remove_silence: When truthy, ``remove_silence_for_generated_wav`` is
            run on every file after it has been written.
        speed: Forwarded to ``infer_process`` as ``speed``.
        output_dir: Root directory; files are written to
            ``<output_dir>/<language>/<LANGUAGE>_<CATEGORY>_<row index>.wav``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    # Up-front pass: report every unusable reference clip before any inference starts.
    invalid_rows = set()
    for row_idx, ref_audio in enumerate(prompt_paths):
        if not isinstance(ref_audio, str) or not os.path.isfile(ref_audio):
            print(f"Invalid ref_audio: {ref_audio}")
            invalid_rows.add(row_idx)
            print(len(invalid_rows))

    # Keep the original row index so output file names stay stable even when
    # some rows are skipped.
    jobs = [
        (idx, row)
        for idx, row in enumerate(zip(prompt_paths, prompt_texts, texts, languages, categories))
        if idx not in invalid_rows
    ]
    if not jobs:
        return

    # Compiled once; both patterns are reused for every row.
    voice_boundary = re.compile(r"(?=\[\w+\])")
    voice_tag = re.compile(r"\[(\w+)\]")

    # Iterating a list lets tqdm show a real total.
    for idx, (ref_audio, ref_text, text_gen, language, category) in tqdm(jobs):
        main_audio, main_text = preprocess_ref_audio_text(ref_audio, ref_text)
        voices = {"main": {"ref_audio": main_audio, "ref_text": main_text}}
        print("Voice:", "main")
        print("Ref_audio:", main_audio)
        print("Ref_text:", main_text)

        generated_audio_segments = []
        final_sample_rate = None
        # Split in front of every "[tag]" so each chunk starts with its own tag.
        for chunk in voice_boundary.split(text_gen):
            if not chunk.strip():
                continue
            tag = voice_tag.match(chunk)
            if tag:
                voice = tag[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            gen_text = voice_tag.sub("", chunk).strip()
            print(f"Voice: {voice}")
            # The third return value is not needed here.
            audio, final_sample_rate, _ = infer_process(
                voices[voice]["ref_audio"],
                voices[voice]["ref_text"],
                gen_text,
                model_obj,
                vocoder,
                mel_spec_type=mel_spec_type,
                speed=speed,
            )
            generated_audio_segments.append(audio)

        if not generated_audio_segments:
            continue

        final_wave = np.concatenate(generated_audio_segments)
        filename = f"{language.upper()}_{category.upper()}_{idx}.wav"
        outfile_dir = os.path.join(output_dir, language)
        os.makedirs(outfile_dir, exist_ok=True)
        wave_path = Path(outfile_dir) / filename

        # Let soundfile own the file: no second handle is held open while the
        # audio is written or while silence is stripped from it.
        sf.write(str(wave_path), final_wave, final_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(str(wave_path))
        print(f"Generated audio saved to: {wave_path}")
def _parse_bool_flag(value):
    """Turn a command-line token such as ``true``/``0``/``no`` into a bool.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse the command line, load vocoder and model, and run batch inference.

    The CSV given with ``--generate_csv`` must provide the columns
    ``prompt_path``, ``prompt_text``, ``text``, ``language`` and ``category``;
    each row produces one output file via ``run_batch_inference``.

    Exits through ``argparse`` (status 2) when a required option is missing or
    ``--remove_silence`` is given an unrecognised value.
    """
    parser = argparse.ArgumentParser(
        prog="python3 infer-cli.py",
        description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
        epilog="Specify options above to override one or more settings from config.",
    )
    parser.add_argument(
        "-m",
        "--model",
        help="F5-TTS | E2-TTS",
    )
    parser.add_argument(
        "-p",
        "--ckpt_file",
        required=True,
        help="The Checkpoint .pt",
    )
    parser.add_argument(
        "-v",
        "--vocab_file",
        help="The vocab .txt",
    )
    parser.add_argument(
        "-f",
        "--generate_csv",
        type=str,
        required=True,
        help="CSV file with the columns prompt_path, prompt_text, text, language, category",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        required=True,
        help="Path to output folder..",
    )
    # Accepts both the bare flag and an explicit value. Previously any value,
    # including "False", was a non-empty string and therefore enabled removal.
    parser.add_argument(
        "--remove_silence",
        nargs="?",
        const=True,
        default=False,
        type=_parse_bool_flag,
        help="Remove silence.",
    )
    parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
    parser.add_argument(
        "--load_vocoder_from_local",
        action="store_true",
        help=(
            "load vocoder from local. Path: ../checkpoints/vocos-mel-24khz (vocos) "
            "or ../checkpoints/bigvgan_v2_24khz_100band_256x (bigvgan)"
        ),
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Adjust the speed of the audio generation (default: 1.0)",
    )
    args = parser.parse_args()

    # Read texts and prompts to generate
    df = pd.read_csv(args.generate_csv)
    prompt_paths = df["prompt_path"].tolist()
    prompt_texts = df["prompt_text"].tolist()
    texts = df["text"].tolist()
    languages = df["language"].tolist()
    categories = df["category"].tolist()

    # Vocoder: the vocoder name doubles as the mel-spectrogram type.
    mel_spec_type = args.vocoder_name
    local_vocoder_paths = {
        "vocos": "../checkpoints/vocos-mel-24khz",
        "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
    }
    # argparse `choices` guarantees the key exists.
    vocoder_local_path = local_vocoder_paths[args.vocoder_name]
    vocoder = load_vocoder(
        vocoder_name=mel_spec_type,
        is_local=args.load_vocoder_from_local,
        local_path=vocoder_local_path,
    )

    # load models
    # NOTE: --model is only echoed below; the DiT architecture with this
    # configuration is used regardless of its value.
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    print(f"Using {args.model}...")
    ema_model = load_model(
        model_cls,
        model_cfg,
        args.ckpt_file,
        mel_spec_type=mel_spec_type,
        vocab_file=args.vocab_file,
    )

    # Batch inference
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    run_batch_inference(
        prompt_paths,
        prompt_texts,
        texts,
        languages,
        categories,
        ema_model,
        vocoder,
        mel_spec_type,
        args.remove_silence,
        args.speed,
        output_dir,
    )


# Run the CLI only when the file is executed as a script.
if __name__ == "__main__":
    main()