Commit 6efcad4
sachin committed · Parent(s): 7222675

test

Files changed:
- requirements.txt +162 -1
- src/server/main.py +97 -236
requirements.txt CHANGED

@@ -33,4 +33,165 @@ uvicorn
 fastapi
 pydub
 python-multipart
-hf_xet
+hf_xet
+
+
+accelerate==1.6.0
+aiofiles==23.2.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.16
+aiosignal==1.3.2
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.9.0
+async-timeout==5.0.1
+attrs==25.3.0
+audioread==3.0.1
+bitsandbytes==0.45.5
+boto3==1.37.29
+botocore==1.37.29
+cached_path==1.7.1
+cachetools==5.5.2
+certifi==2025.1.31
+cffi==1.17.1
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.1
+cycler==0.12.1
+datasets==3.5.0
+decorator==5.2.1
+dill==0.3.8
+docker-pycreds==0.4.0
+einops==0.8.1
+einx==0.3.0
+ema-pytorch==0.7.7
+encodec==0.1.1
+exceptiongroup==1.2.2
+f5-tts==1.1.0
+fastapi==0.115.12
+ffmpy==0.5.0
+filelock==3.18.0
+fonttools==4.57.0
+frozendict==2.4.6
+frozenlist==1.5.0
+fsspec==2024.12.0
+gitdb==4.0.12
+GitPython==3.1.44
+google-api-core==2.24.2
+google-auth==2.38.0
+google-cloud-core==2.4.3
+google-cloud-storage==2.19.0
+google-crc32c==1.7.1
+google-resumable-media==2.7.2
+googleapis-common-protos==1.69.2
+gradio==5.14.0
+gradio_client==1.7.0
+groovy==0.1.2
+h11==0.14.0
+httpcore==1.0.7
+httpx==0.28.1
+huggingface-hub==0.27.1
+hydra-core==1.3.2
+idna==3.10
+jieba==0.42.1
+Jinja2==3.1.6
+jmespath==1.0.1
+joblib==1.4.2
+kiwisolver==1.4.8
+lazy_loader==0.4
+librosa==0.11.0
+llvmlite==0.44.0
+loguru==0.7.3
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.10.1
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.3.2
+multiprocess==0.70.16
+networkx==3.4.2
+numba==0.61.0
+numpy==1.26.4
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+orjson==3.10.16
+packaging==24.2
+pandas==2.2.3
+pillow==11.1.0
+platformdirs==4.3.7
+pooch==1.8.2
+propcache==0.3.1
+proto-plus==1.26.1
+protobuf==5.29.4
+psutil==7.0.0
+pyarrow==19.0.1
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
+pycparser==2.22
+pydantic==2.10.6
+pydantic_core==2.27.2
+pydub==0.25.1
+Pygments==2.19.1
+pyparsing==3.2.3
+pypinyin==0.54.0
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rsa==4.9
+ruff==0.11.4
+s3transfer==0.11.4
+safehttpx==0.1.6
+safetensors==0.5.3
+scikit-learn==1.6.1
+scipy==1.15.2
+semantic-version==2.10.0
+sentry-sdk==2.25.1
+setproctitle==1.3.5
+shellingham==1.5.4
+six==1.17.0
+smmap==5.0.2
+sniffio==1.3.1
+soundfile==0.13.1
+soxr==0.5.0.post1
+starlette==0.46.1
+sympy==1.13.1
+threadpoolctl==3.6.0
+tokenizers==0.21.1
+tomli==2.2.1
+tomlkit==0.13.2
+torch==2.6.0
+torchaudio==2.6.0
+torchdiffeq==0.2.5
+tqdm==4.67.1
+transformers==4.50.3
+transformers-stream-generator==0.0.5
+triton==3.2.0
+typer==0.15.2
+typing-inspection==0.4.0
+typing_extensions==4.13.1
+tzdata==2025.2
+urllib3==2.3.0
+uvicorn==0.34.0
+vocos==0.1.0
+wandb==0.19.9
+websockets==14.2
+x-transformers==2.2.8
+xxhash==3.5.0
+yarl==1.19.0
src/server/main.py CHANGED

@@ -30,7 +30,6 @@ import zipfile
 import soundfile as sf
 import torch
 from fastapi import Body, FastAPI, HTTPException, Response
-from parler_tts import ParlerTTSForConditionalGeneration
 from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
 import numpy as np
 from config import SPEED, ResponseFormat, config
@@ -64,98 +63,12 @@ if torch.cuda.is_available():
 else:
     print("CUDA is not available on this system.")
 
-class TTSModelManager:
-    def __init__(self):
-        self.model_tokenizer: OrderedDict[
-            str, tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]
-        ] = OrderedDict()
-        self.max_length = 50
-
-    def load_model(
-        self, model_name: str
-    ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-
-        model_name = "ai4bharat/indic-parler-tts"
-        attn_implementation = "flash_attention_2"
-
-        model = ParlerTTSForConditionalGeneration.from_pretrained(
-            model_name,
-            attn_implementation=attn_implementation
-        ).to(device, dtype=torch_dtype)
-
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
-
-        # Set pad tokens
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        if description_tokenizer.pad_token is None:
-            description_tokenizer.pad_token = description_tokenizer.eos_token
-
-        # TODO - temporary disable -torch.compile
-        '''
-        # Update model configuration
-        model.config.pad_token_id = tokenizer.pad_token_id
-        # Update for deprecation: use max_batch_size instead of batch_size
-        if hasattr(model.generation_config.cache_config, 'max_batch_size'):
-            model.generation_config.cache_config.max_batch_size = 1
-        model.generation_config.cache_implementation = "static"
-        '''
-        # Compile the model
-        compile_mode = "default"
-        #compile_mode = "reduce-overhead"
-
-        model.forward = torch.compile(model.forward, mode=compile_mode)
-
-        # Warmup
-        warmup_inputs = tokenizer("Warmup text for compilation",
-                                  return_tensors="pt",
-                                  padding="max_length",
-                                  max_length=self.max_length).to(device)
-
-        model_kwargs = {
-            "input_ids": warmup_inputs["input_ids"],
-            "attention_mask": warmup_inputs["attention_mask"],
-            "prompt_input_ids": warmup_inputs["input_ids"],
-            "prompt_attention_mask": warmup_inputs["attention_mask"],
-        }
-
-        n_steps = 1 if compile_mode == "default" else 2
-        for _ in range(n_steps):
-            _ = model.generate(**model_kwargs)
-
-        logger.info(
-            f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
-        )
-        return model, tokenizer, description_tokenizer
-
-    def get_or_load_model(
-        self, model_name: str
-    ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
-        if model_name not in self.model_tokenizer:
-            logger.info(f"Model {model_name} isn't already loaded")
-            if len(self.model_tokenizer) == config.max_models:
-                logger.info("Unloading the oldest loaded model")
-                del self.model_tokenizer[next(iter(self.model_tokenizer))]
-            self.model_tokenizer[model_name] = self.load_model(model_name)
-        return self.model_tokenizer[model_name]
-
-tts_model_manager = TTSModelManager()
-
-@asynccontextmanager
-async def lifespan(_: FastAPI):
-    if not config.lazy_load_model:
-        tts_model_manager.get_or_load_model(config.model)
-    yield
-
 app = FastAPI(
     title="Dhwani API",
     description="AI Chat API supporting Indian languages",
     version="1.0.0",
     redirect_slashes=False,
-    lifespan=lifespan
+    #lifespan=lifespan
 )
 
 def chunk_text(text, chunk_size):
@@ -165,158 +78,106 @@ def chunk_text(text, chunk_size):
         chunks.append(' '.join(words[i:i + chunk_size]))
     return chunks
 
-@app.post("/v1/audio/speech")
-async def generate_audio(
-    input: Annotated[str, Body()] = config.input,
-    voice: Annotated[str, Body()] = config.voice,
-    model: Annotated[str, Body()] = config.model,
-    response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
-    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
-) -> StreamingResponse:
-    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
-    if speed != SPEED:
-        logger.warning(
-            "Specifying speed isn't supported by this model. Audio will be generated with the default speed"
-        )
-    start = time.perf_counter()
-
-    chunk_size = 15
-    all_chunks = chunk_text(input, chunk_size)
-
-    if len(all_chunks) <= chunk_size:
-        desc_inputs = description_tokenizer(voice,
-                                            return_tensors="pt",
-                                            padding="max_length",
-                                            max_length=tts_model_manager.max_length).to(device)
-        prompt_inputs = tokenizer(input,
-                                  return_tensors="pt",
-                                  padding="max_length",
-                                  max_length=tts_model_manager.max_length).to(device)
-
-        input_ids = desc_inputs["input_ids"]
-        attention_mask = desc_inputs["attention_mask"]
-        prompt_input_ids = prompt_inputs["input_ids"]
-        prompt_attention_mask = prompt_inputs["attention_mask"]
-
-        generation = tts.generate(
-            input_ids=input_ids,
-            prompt_input_ids=prompt_input_ids,
-            attention_mask=attention_mask,
-            prompt_attention_mask=prompt_attention_mask
-        ).to(torch.float32)
-
-        audio_arr = generation.cpu().float().numpy().squeeze()
-    else:
-        all_descriptions = [voice] * len(all_chunks)
-        description_inputs = description_tokenizer(all_descriptions,
-                                                   return_tensors="pt",
-                                                   padding=True).to(device)
-        prompts = tokenizer(all_chunks,
-                            return_tensors="pt",
-                            padding=True).to(device)
-
-        set_seed(0)
-        generation = tts.generate(
-            input_ids=description_inputs["input_ids"],
-            attention_mask=description_inputs["attention_mask"],
-            prompt_input_ids=prompts["input_ids"],
-            prompt_attention_mask=prompts["attention_mask"],
-            do_sample=True,
-            return_dict_in_generate=True,
-        )
-
-        chunk_audios = []
-        for i, audio in enumerate(generation.sequences):
-            audio_data = audio[:generation.audios_length[i]].cpu().float().numpy().squeeze()
-            chunk_audios.append(audio_data)
-        audio_arr = np.concatenate(chunk_audios)
-
-    device_str = str(device)
-    logger.info(
-        f"Took {time.perf_counter() - start:.2f} seconds to generate audio for {len(input.split())} words using {device_str.upper()}"
-    )
 
+import io
+import torch
+import requests
+import tempfile
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, HTTPException
+from transformers import AutoModel
+from pydantic import BaseModel
+from typing import Optional
+from starlette.responses import StreamingResponse
+
+
+tts_repo_id = "ai4bharat/IndicF5"
+tts_model = AutoModel.from_pretrained(tts_repo_id, trust_remote_code=True)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Device:", device)
+tts_model = tts_model.to(device)
+
+EXAMPLES = [
+    {
+        "audio_name": "KAN_F (Happy)",
+        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
+        "ref_text": "ನಮ್ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
+        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
+    },
+]
+
+
+# Pydantic model for request body
+class SynthesizeRequest(BaseModel):
+    text: str  # Text to synthesize (expected in Kannada)
+    ref_audio_name: str  # Dropdown of audio names from EXAMPLES
+    ref_text: Optional[str] = None  # Optional, defaults to example ref_text if not provided
+
+# Function to load audio from URL
+def load_audio_from_url(url: str):
+    response = requests.get(url)
+    if response.status_code == 200:
+        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
+        return sample_rate, audio_data
+    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
+
+# Function to synthesize speech
+def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
+    # Find the matching example
+    ref_audio_url = None
+    for example in EXAMPLES:
+        if example["audio_name"] == ref_audio_name:
+            ref_audio_url = example["audio_url"]
+            if not ref_text:
+                ref_text = example["ref_text"]
+            break
+
+    if not ref_audio_url:
+        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
+
+    if not text.strip():
+        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+
+    if not ref_text or not ref_text.strip():
+        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
+
+    # Load reference audio from URL
+    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
+
+    # Save reference audio to a temporary file
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
+        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
+        temp_audio.flush()
 
+    # Generate speech
+    audio = tts_model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
+
+    # Normalize output
+    if audio.dtype == np.int16:
+        audio = audio.astype(np.float32) / 32768.0
+
+    # Save generated audio to a BytesIO buffer
+    buffer = io.BytesIO()
+    sf.write(buffer, audio, 24000, format='WAV')
+    buffer.seek(0)
+
+    return buffer
+
+
+@app.post("/v1/audio/speech")
+async def synthesize(request: SynthesizeRequest):
+    # If ref_text is not provided, it will default to the example's ref_text in synthesize_speech
+    audio_buffer = synthesize_speech(request.text, request.ref_audio_name, request.ref_text)
+
+    # Return the audio as a streaming response
+    return StreamingResponse(
+        audio_buffer,
+        media_type="audio/wav",
+        headers={"Content-Disposition": "attachment; filename=synthesized_speech.wav"}
     )
 
+
 # Supported language codes
 SUPPORTED_LANGUAGES = {
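
For reference, here is a minimal sketch of a client for the reworked endpoint. The route, the JSON fields, the "KAN_F (Happy)" reference name, and the sample Kannada text all come from the diff above; the host and port are hypothetical and depend on how the server is deployed.

import requests

# Hypothetical base URL; adjust to wherever this FastAPI app is served.
url = "http://localhost:8000/v1/audio/speech"

payload = {
    # Kannada text to synthesize (the synth_text sample from EXAMPLES)
    "text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ.",
    # Must match an audio_name entry in EXAMPLES
    "ref_audio_name": "KAN_F (Happy)",
    # "ref_text" is omitted, so the server falls back to the example's ref_text
}

response = requests.post(url, json=payload)
response.raise_for_status()

# The endpoint streams back a WAV file.
with open("synthesized_speech.wav", "wb") as f:
    f.write(response.content)

Note that after this change the IndicF5 model is loaded at import time (the lifespan hook is commented out), so the download and device placement cost is paid at startup rather than on the first request.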