# (scraped page chrome — Spaces header, commit hashes, line-number gutter — removed)
"""
utils.py
Functions:
- get_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- get_audio: Get the audio from the TTS model from HF Spaces.
"""
import os
import requests
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav
# Fireworks-hosted Llama 3.1 405B Instruct, used for script generation.
MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Jina reader proxy: prefixing a URL returns that page's readable text.
JINA_URL = "https://r.jina.ai/"
# OpenAI-compatible client pointed at the Fireworks inference endpoint;
# authenticates via the FIREWORKS_API_KEY environment variable.
client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.getenv("FIREWORKS_API_KEY"),
)
# Gradio client for the MeloTTS text-to-speech Space on Hugging Face.
hf_client = Client("mrfakename/MeloTTS")
# download and load all models
# NOTE: Bark model preload happens at import time — network/disk heavy.
preload_models()
def generate_script(system_prompt: str, input_text: str, output_model):
    """Get the dialogue from the LLM, then ask it to improve its own draft.

    Args:
        system_prompt: Instructions sent as the system message.
        input_text: Source material the dialogue is based on.
        output_model: Pydantic model class used to validate the JSON reply.

    Returns:
        An instance of ``output_model`` holding the improved dialogue.

    Raises:
        pydantic.ValidationError: If the model returns invalid JSON even
            after one corrective retry.
    """

    def _call_and_validate(prompt: str, text: str):
        # Call the LLM once and parse the JSON reply; on a parse failure,
        # retry once with the error message appended so the model can
        # self-correct.
        try:
            response = call_llm(prompt, text, output_model)
            return output_model.model_validate_json(
                response.choices[0].message.content
            )
        except ValidationError as e:
            error_message = f"Failed to parse dialogue JSON: {e}"
            retry_prompt = f"{prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
            response = call_llm(retry_prompt, text, output_model)
            return output_model.model_validate_json(
                response.choices[0].message.content
            )

    # First draft.
    dialogue = _call_and_validate(system_prompt, input_text)

    # Second pass: ask the model to polish the draft. The same
    # validate-with-retry guard now protects this call too (previously
    # only the first pass was guarded).
    system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{dialogue}."
    return _call_and_validate(
        system_prompt_with_dialogue, "Please improve the dialogue."
    )
def call_llm(system_prompt: str, text: str, dialogue_format):
    """Call the LLM with the given prompt and dialogue format.

    The response is constrained to a JSON object matching the schema of
    ``dialogue_format`` (a pydantic model class); the raw API response
    object is returned unparsed.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    # Constrain the output to the caller-supplied pydantic schema.
    json_response_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return client.chat.completions.create(
        model=MODEL_ID,
        messages=chat_messages,
        max_tokens=16_384,
        temperature=0.1,
        response_format=json_response_format,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    Fetches the page through the Jina reader proxy, which extracts the
    readable text from the page.

    Args:
        url: The page URL to fetch.

    Returns:
        The plain-text content extracted by the reader service.

    Raises:
        requests.HTTPError: If the reader service returns an error status.
        requests.Timeout: If no response arrives within 60 seconds.
    """
    full_url = f"{JINA_URL}{url}"
    response = requests.get(full_url, timeout=60)
    # Fail loudly on 4xx/5xx instead of silently returning an error
    # page's body as if it were the article text.
    response.raise_for_status()
    return response.text
def generate_podcast_audio(
    text: str, speaker: str, language: str, use_advanced_audio: bool
) -> str:
    """Synthesize one line of podcast dialogue and return the audio file path.

    Args:
        text: The sentence(s) to speak.
        speaker: "Host (Jane)" or "Guest" — selects the voice preset.
        language: Language code such as "EN" — TODO confirm the full set
            accepted by both Bark and MeloTTS.
        use_advanced_audio: If True, synthesize locally with Bark;
            otherwise call the MeloTTS HF Space.

    Returns:
        Path to the audio file written locally (Bark) or the result path
        returned by the MeloTTS Space.
    """
    if use_advanced_audio:
        # Bark voice preset: speaker 1 for the host, speaker 3 otherwise.
        preset = f"v2/{language}_speaker_{'1' if speaker == 'Host (Jane)' else '3'}"
        audio_array = generate_audio(text, history_prompt=preset)
        # BUG FIX: scipy's write_wav emits RIFF/WAV data, so the file must
        # carry a .wav extension — the original named it .mp3, mislabeling
        # the container.
        file_path = f"audio_{language}_{speaker}.wav"
        # save audio to disk
        write_wav(file_path, SAMPLE_RATE, audio_array)
        return file_path

    # MeloTTS path: choose accent and speaking rate per speaker.
    if speaker == "Guest":
        accent = "EN-US" if language == "EN" else language
        speed = 0.9
    else:  # host
        accent = "EN-Default" if language == "EN" else language
        # Host speaks slightly faster in non-English languages.
        speed = 1.1 if language != "EN" else 1

    # Generate audio via the remote Space.
    return hf_client.predict(
        text=text,
        language=language,
        speaker=accent,
        speed=speed,
        api_name="/synthesize",
    )