import gradio as gr
import os
import torch
from huggingface_hub import InferenceClient

# Khurram
# from fastapi import FastAPI, Query
# from pydantic import BaseModel
# import uvicorn
# from fastapi.responses import JSONResponse
#################

# Import eSpeak TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)

# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)

# from pretrained_models import Kokoro  # unused: pretrained_models/ holds checkpoints, not a package

# ---------------------------------------------------------------------
# Paths to models and voicepacks
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = "pretrained_models/Kokoro/voices"

HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=HF_TOKEN)


# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def get_models():
    return sorted(f for f in os.listdir(MODELS_DIR) if f.endswith(".pth"))


def get_voices():
    return sorted(f for f in os.listdir(VOICES_DIR) if f.endswith(".pt"))


# ---------------------------------------------------------------------
# Map engine selection -> (build_model_func, generate_func)
# ---------------------------------------------------------------------
ENGINES = {
    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
    "openphonemizer": (build_model_open, generate_long_form_tts_open),
}


# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """
    text: input string
    engine: "espeak" or "openphonemizer"
    model_file: selected .pth from the models folder
    voice_file: selected .pt from the voices folder
    speed: speech speed
    """
    # 0) Get the response to the user query from Llama
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": text + " Describe in one line only.",
                },
                # {
                #     "type": "image_url",
                #     "image_url": {
                #         "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                #     }
                # }
            ],
        }
    ]
    response_from_llama = client.chat.completions.create(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        messages=messages,
        max_tokens=500,
    )
    llama_reply = response_from_llama.choices[0].message.content

    # 1) Map engine to the correct build_model + generate_long_form_tts
    build_fn, gen_fn = ENGINES[engine]

    # 2) Prepare paths
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)

    # 3) Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4) Load model and put every submodule in eval mode
    model = build_fn(model_path, device=device)
    for subm in model.values():
        if hasattr(subm, "eval"):
            subm.eval()

    # 5) Load voicepack
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # 6) Generate TTS from the Llama reply
    audio, phonemes = gen_fn(model, llama_reply, voicepack, speed=speed)
    sr = 22050  # adjust to the pipeline's actual output rate (Kokoro commonly uses 24 kHz)
    return (sr, audio)  # Gradio expects (sample_rate, np_array)


# ------------------------------------------
# FAST API
# ---------------
# app = FastAPI()

# class TTSRequest(BaseModel):
#     text: str
#     engine: str
#     model_file: str
#     voice_file: str
#     speed: float = 1.0
# @app.post("/tts")
# def generate_tts(request: TTSRequest):
#     try:
#         sr, audio = tts_inference(
#             text=request.text,
#             engine=request.engine,
#             model_file=request.model_file,
#             voice_file=request.voice_file,
#             speed=request.speed,
#         )
#         return JSONResponse(content={
#             "sample_rate": sr,
#             "audio_tensor": audio.tolist(),
#         })
#     except Exception as e:
#         return JSONResponse(content={"error": str(e)}, status_code=500)

# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)
###############################


# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    model_list = get_models()
    voice_list = get_voices()

    css = """
    h4 { text-align: center; display: block; }
    h2 { text-align: center; display: block; }
    """

    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## LLAMA TTS DEMO - API - GRADIO VISUAL")

        # Row 1: Text input
        text_input = gr.Textbox(
            label="Enter your question",
            value="What is AI?",
            lines=2,
        )

        # Row 2: Engine selection (fixed to "openphonemizer" below)
        # engine_dropdown = gr.Dropdown(
        #     choices=["espeak", "openphonemizer"],
        #     value="openphonemizer",
        #     label="Phonemizer",
        # )

        # Row 3: Model dropdown (fixed to "kokoro-v0_19.pth" below)
        # model_dropdown = gr.Dropdown(
        #     choices=model_list,
        #     value=model_list[0] if model_list else None,
        #     label="Model (.pth)",
        # )

        # Row 4: Voice dropdown (fixed to "af_bella.pt" below)
        # voice_dropdown = gr.Dropdown(
        #     choices=voice_list,
        #     value=voice_list[0] if voice_list else None,
        #     label="Voice (.pt)",
        # )

        # Row 5: Speed slider
        speed_slider = gr.Slider(
            minimum=0.5,
            maximum=2.0,
            value=1.0,
            step=0.1,
            label="Speech Speed",
        )

        # Generate button + audio output
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")

        # Connect the button to our inference function
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                gr.State("openphonemizer"),    # engine_dropdown
                gr.State("kokoro-v0_19.pth"),  # model_dropdown
                gr.State("af_bella.pt"),       # voice_dropdown
                speed_slider,
            ],
            outputs=tts_output,
        )

        gr.Markdown("#### LLAMA - TTS")

    return demo


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()
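
# ---------------------------------------------------------------------
# Usage sketch (not part of the app): call tts_inference directly for a
# quick smoke test without the UI. This assumes HF_TOKEN is set, that
# the checkpoint and voicepack named below exist under MODELS_DIR /
# VOICES_DIR, and that the optional soundfile package is installed to
# save the result.
# ---------------------------------------------------------------------
# import soundfile as sf
#
# sr, audio = tts_inference(
#     text="What is AI?",
#     engine="openphonemizer",
#     model_file="kokoro-v0_19.pth",
#     voice_file="af_bella.pt",
#     speed=1.0,
# )
# sf.write("tts_output.wav", audio, sr)  # write the mono float array at rate sr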