# "Spaces: Running" — Hugging Face Spaces page header captured by the scrape; not part of the program.
import os

import gradio as gr
import torch
from huggingface_hub import InferenceClient

# Khurram
# from fastapi import FastAPI, Query
# from pydantic import BaseModel
# import uvicorn
# from fastapi.responses import JSONResponse
#################
# eSpeak-based TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)
# OpenPhonemizer-based TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)
from pretrained_models import Kokoro  # NOTE(review): not referenced below — confirm before removing

# ---------------------------------------------------------------------
# Paths to models and voicepacks
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = "pretrained_models/Kokoro/voices"

# Hugging Face Inference API client; token is read from the environment
# (HF_TOKEN may be None, in which case unauthenticated access is attempted).
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=HF_TOKEN)
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def get_models(models_dir=None):
    """Return the sorted list of ``.pth`` model filenames in *models_dir*.

    Args:
        models_dir: Directory to scan. Defaults to the module-level
            ``MODELS_DIR``, preserving the original zero-argument call.

    Returns:
        Sorted list of filenames (not full paths) ending in ``.pth``.
    """
    target = MODELS_DIR if models_dir is None else models_dir
    return sorted(f for f in os.listdir(target) if f.endswith(".pth"))
def get_voices(voices_dir=None):
    """Return the sorted list of ``.pt`` voicepack filenames in *voices_dir*.

    Args:
        voices_dir: Directory to scan. Defaults to the module-level
            ``VOICES_DIR``, preserving the original zero-argument call.

    Returns:
        Sorted list of filenames (not full paths) ending in ``.pt``.
    """
    target = VOICES_DIR if voices_dir is None else voices_dir
    return sorted(f for f in os.listdir(target) if f.endswith(".pt"))
# ---------------------------------------------------------------------
# Map engine selection -> (build_model_func, generate_long_form_tts_func)
# so tts_inference can dispatch on a plain string key.
# ---------------------------------------------------------------------
ENGINES = {
    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
    "openphonemizer": (build_model_open, generate_long_form_tts_open),
}
# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """Answer *text* with the Llama model, then synthesize the reply as speech.

    Args:
        text: Input user query string.
        engine: "espeak" or "openphonemizer" (key into ENGINES;
            an unknown key raises KeyError).
        model_file: Selected .pth filename inside MODELS_DIR.
        voice_file: Selected .pt voicepack filename inside VOICES_DIR.
        speed: Speech speed multiplier.

    Returns:
        (sample_rate, audio) tuple, the format gr.Audio expects.
    """
    # 0) Ask Llama to answer the query in one line.
    #    Bug fix: the original concatenated the instruction directly onto the
    #    query ("What is AI?describe in one line only"); a separating space is
    #    added, and the redundant str() wrapper around the literal is dropped.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": text + " describe in one line only",
                }
                # Image input intentionally disabled; re-enable by adding an
                # {"type": "image_url", ...} entry here.
            ],
        }
    ]
    response_from_llama = client.chat.completions.create(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        messages=messages,
        max_tokens=500,
    )

    # 1) Map engine to the correct build_model + generate_long_form_tts pair.
    build_fn, gen_fn = ENGINES[engine]

    # 2) Prepare paths.
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)

    # 3) Decide device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4) Load the model and switch its submodules to eval mode.
    #    NOTE(review): model and voicepack are re-loaded on every request;
    #    consider caching them if request latency matters.
    model = build_fn(model_path, device=device)
    for submodule in model.values():  # only values needed (was .items())
        if hasattr(submodule, "eval"):
            submodule.eval()

    # 5) Load voicepack (trusted local file; torch.load unpickles, so never
    #    point this at untrusted data).
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # 6) Generate TTS from the LLM's reply text.
    reply_text = response_from_llama.choices[0].message["content"]
    audio, phonemes = gen_fn(model, reply_text, voicepack, speed=speed)
    sr = 22050  # assumed sample rate — TODO confirm against the TTS model config
    return (sr, audio)  # Gradio expects (sample_rate, np_array)
#------------------------------------------
# FAST API (alternative entry point, currently disabled)
#---------------
# app = FastAPI()
# class TTSRequest(BaseModel):
#     text: str
#     engine: str
#     model_file: str
#     voice_file: str
#     speed: float = 1.0
# @app.post("/tts")
# def generate_tts(request: TTSRequest):
#     try:
#         sr, audio = tts_inference(
#             text="What is Deep SeEK? define in 2 lines",
#             engine="openphonemizer",
#             model_file="kokoro-v0_19.pth",
#             voice_file="af_bella.pt",
#             speed=1.0
#         )
#         return JSONResponse(content={
#             "sample_rate": sr,
#             "audio_tensor": audio.tolist()
#         })
#     except Exception as e:
#         return JSONResponse(content={"error": str(e)}, status_code=500)
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)
###############################
# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    """Build and return the Gradio Blocks UI for the Llama -> TTS pipeline.

    The engine, model file, and voice file are pinned via gr.State; only the
    question text and the speech speed are user-adjustable.
    """
    # Center-align the markdown headings used below.
    css = """
    h4 {
        text-align: center;
        display:block;
    }
    h2 {
        text-align: center;
        display:block;
    }
    """
    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## LLAMA TTS DEMO - API - GRADIO VISUAL")

        # Row 1: Text input
        text_input = gr.Textbox(
            label="Enter your question",
            value="What is AI?",
            lines=2,
        )

        # Engine/model/voice dropdowns were removed from the UI; fixed values
        # are injected below via gr.State. Restore dropdowns (populated from
        # get_models()/get_voices()) to make them user-selectable again.

        # Row 2: Speed slider
        speed_slider = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
        )

        # Generate button + audio output
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")

        # Connect the button to our inference function.
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                gr.State("openphonemizer"),    # engine
                gr.State("kokoro-v0_19.pth"),  # model file
                gr.State("af_bella.pt"),       # voice file
                speed_slider,
            ],
            outputs=tts_output,
        )

        gr.Markdown("#### LLAMA - TTS")
    return demo
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    app = create_gradio_app()
    app.launch()