bcci committed on
Commit d2fc3ff · verified · 1 Parent(s): 91ae13c

Update app.py

Browse files
Files changed (1)
  1. app.py +34 -34
app.py CHANGED
@@ -10,10 +10,10 @@ from fastapi.responses import StreamingResponse, Response, HTMLResponse
 from fastapi.middleware import Middleware
 from fastapi.middleware.gzip import GZipMiddleware
 
-from kokoro import KPipeline
+from kokoro import StreamKPipeline, KModel, KPipeline  # Import StreamKPipeline and KModel
 
 app = FastAPI(
-    title="Kokoro TTS FastAPI",
+    title="Kokoro Streaming TTS FastAPI",
     middleware=[
         Middleware(GZipMiddleware, compresslevel=9)  # Add GZip compression
     ]
@@ -23,7 +23,8 @@ app = FastAPI(
 # Global Pipeline Instance
 # ------------------------------------------------------------------------------
 # Create one pipeline instance for the entire app.
-pipeline = KPipeline(lang_code="a")
+model = KModel()  # Initialize KModel
+pipeline = StreamKPipeline(lang_code="a", model=model)  # Initialize StreamKPipeline, passing the model
 
 
 # ------------------------------------------------------------------------------
@@ -47,12 +48,12 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
     return header + fmt_chunk + data_chunk_header
 
 
-def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
+def audio_chunk_to_pcm_bytes(audio_chunk: torch.Tensor) -> bytes:
     """
-    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
+    Convert a torch.FloatTensor audio chunk (values in [-1, 1]) to raw 16-bit PCM bytes.
     """
     # Ensure tensor is on CPU and flatten if necessary.
-    audio_np = audio_tensor.cpu().numpy()
+    audio_np = audio_chunk.cpu().numpy()
     if audio_np.ndim > 1:
         audio_np = audio_np.flatten()
     # Scale to int16 range.
@@ -60,9 +61,9 @@ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
     return audio_int16.tobytes()
 
 
-def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
+def audio_chunk_to_opus_bytes(audio_chunk: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
     """
-    Convert a torch.FloatTensor to Opus encoded bytes.
+    Convert a torch.FloatTensor audio chunk to Opus encoded bytes.
     Requires the 'opuslib' package: pip install opuslib
     """
     try:
@@ -70,7 +71,7 @@ def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24
     except ImportError:
         raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
 
-    audio_np = audio_tensor.cpu().numpy()
+    audio_np = audio_chunk.cpu().numpy()
     if audio_np.ndim > 1:
         audio_np = audio_np.flatten()
     # Scale to int16 range. Important for opus.
@@ -94,55 +95,51 @@ def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24
 
 
 # ------------------------------------------------------------------------------
-# Endpoints
+# Streaming TTS Endpoint
 # ------------------------------------------------------------------------------
 
-@app.get("/tts/streaming", summary="True Streaming TTS")
+@app.get("/tts/streaming", summary="Streaming TTS")
 def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
     """
-    True Streaming TTS endpoint that returns a continuous audio stream.
-    It processes text and generates audio token by token (or small chunks as KPipeline yields),
-    providing a more responsive streaming experience.
+    Streaming TTS endpoint that returns a continuous audio stream.
     Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
 
     The endpoint first yields a WAV header (with a dummy length) for WAV,
-    then yields encoded audio data for each token's audio as soon as it is generated.
+    then yields encoded audio data chunks as they are generated.
    """
     sample_rate = 24000
     num_channels = 1
     sample_width = 2  # 16-bit PCM
 
-    def audio_generator():
+    def audio_chunk_generator():
         if format.lower() == "wav":
-            # Yield the WAV header first.
+            # Yield the WAV header first for PCM WAV format.
             header = generate_wav_header(sample_rate, num_channels, sample_width)
             yield header
 
-        try:
-            results = pipeline(text, voice=voice, speed=speed, split_pattern=None)  # split_pattern=None to avoid splitting here, let KPipeline handle
-            for result in results:
-                if result.audio is not None:
-                    if format.lower() == "wav":
-                        yield audio_tensor_to_pcm_bytes(result.audio)
-                    elif format.lower() == "opus":
-                        yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
-                    else:
-                        raise ValueError(f"Unsupported audio format: {format}")
+        # Stream audio chunks from the pipeline.
+        for audio_chunk in pipeline(text=text, voice=voice, speed=speed):
+            if audio_chunk is not None and audio_chunk.numel() > 0:
+                if format.lower() == "wav":
+                    yield audio_chunk_to_pcm_bytes(audio_chunk)
+                elif format.lower() == "opus":
+                    yield audio_chunk_to_opus_bytes(audio_chunk, sample_rate=sample_rate)
                 else:
-                    print("No audio generated for a token/chunk")  # Debugging, remove in production if not needed
-        except Exception as e:
-            print(f"Error during TTS processing: {e}")
-            yield b''  # Important: yield empty bytes to keep stream alive, or handle error sound
+                    raise ValueError(f"Unsupported audio format: {format}")
 
     media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
 
     return StreamingResponse(
-        audio_generator(),
+        audio_chunk_generator(),
         media_type=media_type,
         headers={"Cache-Control": "no-cache"},
     )
 
 
+# ------------------------------------------------------------------------------
+# Full TTS Endpoint (unchanged from your original code)
+# ------------------------------------------------------------------------------
+
 @app.get("/tts/full", summary="Full TTS")
 def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
     """
@@ -186,6 +183,9 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str
         raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
 
 
+# ------------------------------------------------------------------------------
+# HTML Demo Page Endpoint (unchanged from your original code, but updated to call new streaming endpoint)
+# ------------------------------------------------------------------------------
 
 @app.get("/", response_class=HTMLResponse)
 def index():
@@ -199,10 +199,10 @@ def index():
     <!DOCTYPE html>
     <html>
     <head>
-        <title>Kokoro TTS Demo</title>
+        <title>Kokoro Streaming TTS Demo</title>
     </head>
     <body>
-        <h1>Kokoro TTS Demo</h1>
+        <h1>Kokoro Streaming TTS Demo</h1>
         <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
         <label for="voice">Voice:</label>
         <input type="text" id="voice" value="af_heart"><br>
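
For reference, a minimal client sketch for exercising the reworked /tts/streaming endpoint. This is not part of the commit; it assumes the app is served at a placeholder base URL (http://localhost:8000) and uses the requests library. The endpoint path, query parameters (text, voice, speed, format), and the wav/opus format handling come from the diff above.

import requests

# Hypothetical local base URL; adjust to wherever the FastAPI app is deployed.
BASE_URL = "http://localhost:8000"

params = {
    "text": "Hello from Kokoro streaming TTS.",
    "voice": "af_heart",  # default voice from the endpoint signature
    "speed": 1.0,
    "format": "wav",      # "wav" or "opus", per the endpoint's format handling
}

# Stream the response and write audio chunks to disk as they arrive.
with requests.get(f"{BASE_URL}/tts/streaming", params=params, stream=True) as resp:
    resp.raise_for_status()
    with open("output.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)

Equivalently with curl (same assumed base URL): curl -o output.opus "http://localhost:8000/tts/streaming?text=Hello&voice=af_heart&speed=1.0&format=opus"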