# llama-kokoro / app.py
# Author: khurrameycon
# Commit message: "Update app.py"
# Revision: a3b1df2 (verified)
import gradio as gr
import os
import torch
from huggingface_hub import InferenceClient
# Khurram
# from fastapi import FastAPI, Query
# from pydantic import BaseModel
# import uvicorn
# from fastapi.responses import JSONResponse
#################
# Import eSpeak TTS pipeline
from tts_cli import (
build_model as build_model_espeak,
generate_long_form_tts as generate_long_form_tts_espeak,
)
# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
build_model as build_model_open,
generate_long_form_tts as generate_long_form_tts_open,
)
from pretrained_models import Kokoro
#
# ---------------------------------------------------------------------
# Path to models and voicepacks
# ---------------------------------------------------------------------
# Kokoro checkpoints (.pth) live in MODELS_DIR; voicepacks (.pt) sit in its
# "voices" subfolder.
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = f"{MODELS_DIR}/voices"
# Hugging Face access token read from the environment (None when unset); it is
# handed to the hosted inference client as its api_key.
HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient(api_key=HF_TOKEN)
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def list_files_with_suffix(directory, suffix):
    """Return the sorted names of regular files in *directory* ending with *suffix*.

    Sub-directories are skipped even if their name ends with *suffix*, since
    they cannot be loaded as checkpoints/voicepacks.

    Raises:
        FileNotFoundError: if *directory* does not exist (propagated from os.listdir).
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(suffix) and os.path.isfile(os.path.join(directory, name))
    )


def get_models():
    """Return the sorted ``.pth`` model file names found in MODELS_DIR."""
    return list_files_with_suffix(MODELS_DIR, ".pth")


def get_voices():
    """Return the sorted ``.pt`` voicepack file names found in VOICES_DIR."""
    return list_files_with_suffix(VOICES_DIR, ".pt")
# ---------------------------------------------------------------------
# We'll map engine selection -> (build_model_func, generate_func)
# ---------------------------------------------------------------------
# Phonemizer engine name -> (model builder, long-form TTS generator).
# tts_inference looks up both callables from here by the selected engine.
ENGINES = dict(
    espeak=(build_model_espeak, generate_long_form_tts_espeak),
    openphonemizer=(build_model_open, generate_long_form_tts_open),
)
# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def _ask_llama(text):
    """Ask the Llama chat model about *text* and return its reply as a string.

    The reply is what tts_inference speaks, so the prompt requests a one-line
    answer to keep the audio short.

    Raises:
        RuntimeError: if the model returns no text content.
    """
    messages = [
        {
            "role": "user",
            "content": [
                # A space separates the question from the instruction; without it
                # the prompt read e.g. "What is AI?describe in one line only".
                {"type": "text", "text": f"{text} describe in one line only"},
                # An image can be attached as an extra content part, e.g.:
                # {
                #     "type": "image_url",
                #     "image_url": {
                #         "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                #     },
                # },
            ],
        }
    ]
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        messages=messages,
        max_tokens=500,
    )
    # Attribute access is the documented API of huggingface_hub's output types;
    # dict-style indexing is only kept there for backward compatibility.
    reply = response.choices[0].message.content
    if not reply:
        raise RuntimeError("The language model returned an empty response.")
    return reply


def _load_tts(engine, model_file, voice_file, device):
    """Build the TTS model for *engine* and load the voicepack onto *device*.

    Returns:
        (generate_fn, model, voicepack) for the engine's long-form generator.

    Raises:
        KeyError: if *engine* is not a key of ENGINES.
    """
    build_fn, gen_fn = ENGINES[engine]
    model = build_fn(os.path.join(MODELS_DIR, model_file), device=device)
    # build_fn returns a mapping of named submodules; switch every one that
    # supports it to evaluation mode.
    for submodule in model.values():
        if hasattr(submodule, "eval"):
            submodule.eval()
    voicepack = torch.load(os.path.join(VOICES_DIR, voice_file), map_location=device)
    # Defensive: a plain tensor voicepack has no eval(); call it only if present.
    if hasattr(voicepack, "eval"):
        voicepack.eval()
    return gen_fn, model, voicepack


def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """Answer *text* with Llama, then speak the answer with Kokoro TTS.

    Args:
        text: The user's question; forwarded to the LLM.
        engine: Phonemizer engine, a key of ENGINES ("espeak" or "openphonemizer").
        model_file: Name of a .pth checkpoint inside MODELS_DIR.
        voice_file: Name of a .pt voicepack inside VOICES_DIR.
        speed: Speech speed, passed through to the TTS generator.

    Returns:
        (sample_rate, audio) -- the tuple form gr.Audio accepts as output.

    Raises:
        RuntimeError: if the LLM returns an empty reply.
        KeyError: if *engine* is unknown.
    """
    reply = _ask_llama(text)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): the model and voicepack are rebuilt on every request;
    # caching them per (engine, model_file, voice_file, device) would cut latency.
    gen_fn, model, voicepack = _load_tts(engine, model_file, voice_file, device)
    # Inference only: no autograd graph is needed during synthesis.
    with torch.no_grad():
        audio, _phonemes = gen_fn(model, reply, voicepack, speed=speed)
    # Kokoro v0.19 synthesizes 24 kHz audio; the previous 22050 made playback
    # slower and lower-pitched. TODO confirm if a different checkpoint is used.
    sample_rate = 24000
    return (sample_rate, audio)
#------------------------------------------
# FAST API
#---------------
# app = FastAPI()
# class TTSRequest(BaseModel):
# text: str
# engine: str
# model_file: str
# voice_file: str
# speed: float = 1.0
# @app.post("/tts")
# def generate_tts(request: TTSRequest):
# try:
# sr, audio = tts_inference(
# text="What is Deep SeEK? define in 2 lines",
# engine="openphonemizer",
# model_file="kokoro-v0_19.pth",
# voice_file="af_bella.pt",
# speed=1.0
# )
# return JSONResponse(content={
# "sample_rate": sr,
# "audio_tensor": audio.tolist()
# })
# except Exception as e:
# return JSONResponse(content={"error": str(e)}, status_code=500)
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
###############################
# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app(
    engine="openphonemizer",
    model_file="kokoro-v0_19.pth",
    voice_file="af_bella.pt",
):
    """Build the Gradio Blocks UI for the Llama -> Kokoro TTS demo.

    The phonemizer engine, model checkpoint and voicepack are fixed for the
    session (fed to tts_inference as gr.State values); only the question text
    and the speech speed are user-editable.

    Args:
        engine: Key of ENGINES used for phonemization.
        model_file: .pth file name inside MODELS_DIR.
        voice_file: .pt file name inside VOICES_DIR.

    Returns:
        The gr.Blocks app (not yet launched).

    Raises:
        ValueError: if *engine* is not a key of ENGINES.
        FileNotFoundError: if *model_file* / *voice_file* is missing from its
            folder -- reported at startup instead of on the first click.
    """
    if engine not in ENGINES:
        raise ValueError(f"Unknown engine {engine!r}; expected one of {sorted(ENGINES)}")
    model_list = get_models()
    if model_file not in model_list:
        raise FileNotFoundError(
            f"Model {model_file!r} not found in {MODELS_DIR}; available: {model_list}"
        )
    voice_list = get_voices()
    if voice_file not in voice_list:
        raise FileNotFoundError(
            f"Voice {voice_file!r} not found in {VOICES_DIR}; available: {voice_list}"
        )

    css = """
    h4 {
        text-align: center;
        display:block;
    }
    h2 {
        text-align: center;
        display:block;
    }
    """
    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## LLAMA TTS DEMO - API - GRADIO VISUAL")
        text_input = gr.Textbox(
            label="Enter your question",
            value="What is AI?",
            lines=2,
        )
        # Engine / model / voice pickers are intentionally hidden. To expose
        # them, build gr.Dropdown components from ENGINES, model_list and
        # voice_list and pass them in place of the gr.State inputs below.
        speed_slider = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
        )
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")
        # Input order must match tts_inference's positional parameters.
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                gr.State(engine),
                gr.State(model_file),
                gr.State(voice_file),
                speed_slider,
            ],
            outputs=tts_output,
        )
        gr.Markdown("#### LLAMA - TTS")
    return demo
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Running this file directly builds the demo UI and serves it with
    # Gradio's default launch settings.
    app = create_gradio_app()
    app.launch()