Spaces:
Running
Running
zach
committed on
Commit
·
04e2d2a
1
Parent(s):
ad1ff58
Fix types in app.py
Browse files- pyproject.toml +1 -0
- src/app.py +47 -50
pyproject.toml
CHANGED
|
@@ -40,6 +40,7 @@ ignore = [
|
|
| 40 |
"EM102",
|
| 41 |
"FIX002",
|
| 42 |
"G004",
|
|
|
|
| 43 |
"PLR0913",
|
| 44 |
"PLR2004",
|
| 45 |
"TD002",
|
|
|
|
| 40 |
"EM102",
|
| 41 |
"FIX002",
|
| 42 |
"G004",
|
| 43 |
+
"PLR0912",
|
| 44 |
"PLR0913",
|
| 45 |
"PLR2004",
|
| 46 |
"TD002",
|
src/app.py
CHANGED
|
@@ -11,7 +11,7 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
|
|
| 11 |
# Standard Library Imports
|
| 12 |
import time
|
| 13 |
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
-
from typing import Tuple
|
| 15 |
|
| 16 |
# Third-Party Library Imports
|
| 17 |
import gradio as gr
|
|
@@ -19,7 +19,7 @@ import gradio as gr
|
|
| 19 |
# Local Application Imports
|
| 20 |
from src import constants
|
| 21 |
from src.config import Config, logger
|
| 22 |
-
from src.custom_types import
|
| 23 |
from src.database.database import DBSessionMaker
|
| 24 |
from src.integrations import (
|
| 25 |
AnthropicError,
|
|
@@ -50,7 +50,7 @@ class App:
|
|
| 50 |
def _generate_text(
|
| 51 |
self,
|
| 52 |
character_description: str,
|
| 53 |
-
) -> Tuple[
|
| 54 |
"""
|
| 55 |
Validates the character_description and generates text using Anthropic API.
|
| 56 |
|
|
@@ -59,13 +59,12 @@ class App:
|
|
| 59 |
|
| 60 |
Returns:
|
| 61 |
Tuple containing:
|
| 62 |
-
- The generated text (as a gr.update).
|
| 63 |
-
-
|
| 64 |
|
| 65 |
Raises:
|
| 66 |
gr.Error: On validation or API errors.
|
| 67 |
"""
|
| 68 |
-
|
| 69 |
try:
|
| 70 |
validate_character_description_length(character_description)
|
| 71 |
except ValueError as ve:
|
|
@@ -88,7 +87,7 @@ class App:
|
|
| 88 |
character_description: str,
|
| 89 |
text: str,
|
| 90 |
generated_text_state: str,
|
| 91 |
-
) -> Tuple[
|
| 92 |
"""
|
| 93 |
Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
|
| 94 |
|
|
@@ -98,7 +97,7 @@ class App:
|
|
| 98 |
- Synthesize two Hume outputs (50% chance).
|
| 99 |
|
| 100 |
The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
|
| 101 |
-
Additional metadata such as the
|
| 102 |
|
| 103 |
Args:
|
| 104 |
character_description (str): The description of the character used for generating the voice.
|
|
@@ -108,13 +107,9 @@ class App:
|
|
| 108 |
|
| 109 |
Returns:
|
| 110 |
Tuple containing:
|
| 111 |
-
-
|
| 112 |
-
-
|
| 113 |
-
-
|
| 114 |
-
- str: The raw audio value (relative file path) for option B.
|
| 115 |
-
- ComparisonType: The comparison type between the selected TTS providers.
|
| 116 |
-
- str: Generation ID for option A.
|
| 117 |
-
- str: Generation ID for option B.
|
| 118 |
- bool: Flag indicating whether the text was modified.
|
| 119 |
- str: The original text that was synthesized.
|
| 120 |
- str: The original character description.
|
|
@@ -122,7 +117,6 @@ class App:
|
|
| 122 |
Raises:
|
| 123 |
gr.Error: If any API or unexpected errors occur during the TTS synthesis process.
|
| 124 |
"""
|
| 125 |
-
|
| 126 |
if not text:
|
| 127 |
logger.warning("Skipping text-to-speech due to empty text.")
|
| 128 |
raise gr.Error("Please generate or enter text to synthesize.")
|
|
@@ -134,34 +128,41 @@ class App:
|
|
| 134 |
try:
|
| 135 |
if provider_b == constants.HUME_AI:
|
| 136 |
num_generations = 2
|
| 137 |
-
# If generating 2 Hume outputs, do so in a single API call
|
| 138 |
-
(
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
) = text_to_speech_with_hume(character_description, text, num_generations, self.config)
|
| 144 |
else:
|
| 145 |
with ThreadPoolExecutor(max_workers=2) as executor:
|
| 146 |
num_generations = 1
|
| 147 |
-
# Generate a single Hume output
|
| 148 |
future_audio_a = executor.submit(
|
| 149 |
text_to_speech_with_hume, character_description, text, num_generations, self.config
|
| 150 |
)
|
| 151 |
-
# Generate a second TTS output from the second provider
|
| 152 |
match provider_b:
|
| 153 |
case constants.ELEVENLABS:
|
| 154 |
future_audio_b = executor.submit(
|
| 155 |
text_to_speech_with_elevenlabs, character_description, text, self.config
|
| 156 |
)
|
| 157 |
case _:
|
| 158 |
-
# Additional TTS Providers can be added here
|
| 159 |
raise ValueError(f"Unsupported provider: {provider_b}")
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
|
| 166 |
option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)
|
| 167 |
options_map: OptionMap = create_shuffled_tts_options(option_a, option_b)
|
|
@@ -185,7 +186,7 @@ class App:
|
|
| 185 |
raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f"Unexpected error during TTS generation: {e}")
|
| 188 |
-
raise gr.Error("An unexpected error
|
| 189 |
|
| 190 |
def _vote(
|
| 191 |
self,
|
|
@@ -195,7 +196,7 @@ class App:
|
|
| 195 |
text_modified: bool,
|
| 196 |
character_description: str,
|
| 197 |
text: str,
|
| 198 |
-
) -> Tuple[bool,
|
| 199 |
"""
|
| 200 |
Handles user voting.
|
| 201 |
|
|
@@ -207,16 +208,15 @@ class App:
|
|
| 207 |
'Option A': 'Hume AI',
|
| 208 |
'Option B': 'ElevenLabs',
|
| 209 |
}
|
| 210 |
-
|
| 211 |
|
| 212 |
Returns:
|
| 213 |
A tuple of:
|
| 214 |
- A boolean indicating if the vote was accepted.
|
| 215 |
-
-
|
| 216 |
-
-
|
| 217 |
-
-
|
| 218 |
"""
|
| 219 |
-
|
| 220 |
if not option_map or vote_submitted:
|
| 221 |
return gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 222 |
|
|
@@ -224,7 +224,7 @@ class App:
|
|
| 224 |
selected_provider = option_map[selected_option]["provider"]
|
| 225 |
other_provider = option_map[other_option]["provider"]
|
| 226 |
|
| 227 |
-
# Report voting results to be persisted to results DB
|
| 228 |
submit_voting_results(
|
| 229 |
option_map,
|
| 230 |
selected_option,
|
|
@@ -254,7 +254,7 @@ class App:
|
|
| 254 |
gr.update(interactive=True),
|
| 255 |
)
|
| 256 |
|
| 257 |
-
def _reset_ui(self) -> Tuple[
|
| 258 |
"""
|
| 259 |
Resets UI state before generating new text.
|
| 260 |
|
|
@@ -263,17 +263,20 @@ class App:
|
|
| 263 |
- option_a_audio_player (clear audio)
|
| 264 |
- option_b_audio_player (clear audio)
|
| 265 |
- vote_button_a (disable and reset button text)
|
| 266 |
-
-
|
| 267 |
- option_map_state (reset option map state)
|
| 268 |
- vote_submitted_state (reset submitted vote state)
|
| 269 |
"""
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
| 271 |
return (
|
| 272 |
gr.update(value=None),
|
| 273 |
gr.update(value=None, autoplay=False),
|
| 274 |
gr.update(value=constants.SELECT_OPTION_A, variant="secondary"),
|
| 275 |
gr.update(value=constants.SELECT_OPTION_B, variant="secondary"),
|
| 276 |
-
|
| 277 |
False,
|
| 278 |
)
|
| 279 |
|
|
@@ -282,7 +285,6 @@ class App:
|
|
| 282 |
Builds the input section including the sample character description dropdown, character
|
| 283 |
description input, and generate text button.
|
| 284 |
"""
|
| 285 |
-
|
| 286 |
sample_character_description_dropdown = gr.Dropdown(
|
| 287 |
choices=list(constants.SAMPLE_CHARACTER_DESCRIPTIONS.keys()),
|
| 288 |
label="Choose a sample character description",
|
|
@@ -308,7 +310,6 @@ class App:
|
|
| 308 |
"""
|
| 309 |
Builds the output section including text input, audio players, and vote buttons.
|
| 310 |
"""
|
| 311 |
-
|
| 312 |
text_input = gr.Textbox(
|
| 313 |
label="Input Text",
|
| 314 |
placeholder="Enter or generate text for synthesis...",
|
|
@@ -342,7 +343,6 @@ class App:
|
|
| 342 |
Returns:
|
| 343 |
gr.Blocks: The fully constructed Gradio UI layout.
|
| 344 |
"""
|
| 345 |
-
|
| 346 |
custom_theme = CustomTheme()
|
| 347 |
with gr.Blocks(
|
| 348 |
title="Expressive TTS Arena",
|
|
@@ -384,7 +384,6 @@ class App:
|
|
| 384 |
) = self._build_output_section()
|
| 385 |
|
| 386 |
# --- UI state components ---
|
| 387 |
-
|
| 388 |
# Track character description used for text and voice generation
|
| 389 |
character_description_state = gr.State("")
|
| 390 |
# Track text used for speech synthesis
|
|
@@ -393,10 +392,8 @@ class App:
|
|
| 393 |
generated_text_state = gr.State("")
|
| 394 |
# Track whether text that was used was generated or modified/custom
|
| 395 |
text_modified_state = gr.State()
|
| 396 |
-
|
| 397 |
# Track option map (option A and option B are randomized)
|
| 398 |
-
option_map_state = gr.State()
|
| 399 |
-
|
| 400 |
# Track whether the user has voted for an option
|
| 401 |
vote_submitted_state = gr.State(False)
|
| 402 |
|
|
@@ -506,7 +503,7 @@ class App:
|
|
| 506 |
inputs=[],
|
| 507 |
outputs=[vote_button_a, vote_button_b],
|
| 508 |
).then(
|
| 509 |
-
fn=self.
|
| 510 |
inputs=[
|
| 511 |
vote_submitted_state,
|
| 512 |
option_map_state,
|
|
|
|
| 11 |
# Standard Library Imports
|
| 12 |
import time
|
| 13 |
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
+
from typing import Tuple
|
| 15 |
|
| 16 |
# Third-Party Library Imports
|
| 17 |
import gradio as gr
|
|
|
|
| 19 |
# Local Application Imports
|
| 20 |
from src import constants
|
| 21 |
from src.config import Config, logger
|
| 22 |
+
from src.custom_types import Option, OptionMap
|
| 23 |
from src.database.database import DBSessionMaker
|
| 24 |
from src.integrations import (
|
| 25 |
AnthropicError,
|
|
|
|
| 50 |
def _generate_text(
|
| 51 |
self,
|
| 52 |
character_description: str,
|
| 53 |
+
) -> Tuple[dict, str]:
|
| 54 |
"""
|
| 55 |
Validates the character_description and generates text using Anthropic API.
|
| 56 |
|
|
|
|
| 59 |
|
| 60 |
Returns:
|
| 61 |
Tuple containing:
|
| 62 |
+
- The generated text update (as a dict from gr.update).
|
| 63 |
+
- The generated text string.
|
| 64 |
|
| 65 |
Raises:
|
| 66 |
gr.Error: On validation or API errors.
|
| 67 |
"""
|
|
|
|
| 68 |
try:
|
| 69 |
validate_character_description_length(character_description)
|
| 70 |
except ValueError as ve:
|
|
|
|
| 87 |
character_description: str,
|
| 88 |
text: str,
|
| 89 |
generated_text_state: str,
|
| 90 |
+
) -> Tuple[dict, dict, OptionMap, bool, str, str]:
|
| 91 |
"""
|
| 92 |
Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
|
| 93 |
|
|
|
|
| 97 |
- Synthesize two Hume outputs (50% chance).
|
| 98 |
|
| 99 |
The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
|
| 100 |
+
Additional metadata such as the comparison type, generation IDs, and state information are also returned.
|
| 101 |
|
| 102 |
Args:
|
| 103 |
character_description (str): The description of the character used for generating the voice.
|
|
|
|
| 107 |
|
| 108 |
Returns:
|
| 109 |
Tuple containing:
|
| 110 |
+
- dict: Update for the first audio player (with autoplay enabled).
|
| 111 |
+
- dict: Update for the second audio player.
|
| 112 |
+
- OptionMap: A mapping of option constants to their corresponding TTS providers.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
- bool: Flag indicating whether the text was modified.
|
| 114 |
- str: The original text that was synthesized.
|
| 115 |
- str: The original character description.
|
|
|
|
| 117 |
Raises:
|
| 118 |
gr.Error: If any API or unexpected errors occur during the TTS synthesis process.
|
| 119 |
"""
|
|
|
|
| 120 |
if not text:
|
| 121 |
logger.warning("Skipping text-to-speech due to empty text.")
|
| 122 |
raise gr.Error("Please generate or enter text to synthesize.")
|
|
|
|
| 128 |
try:
|
| 129 |
if provider_b == constants.HUME_AI:
|
| 130 |
num_generations = 2
|
| 131 |
+
# If generating 2 Hume outputs, do so in a single API call.
|
| 132 |
+
result = text_to_speech_with_hume(character_description, text, num_generations, self.config)
|
| 133 |
+
# Enforce that 4 values are returned.
|
| 134 |
+
if not (isinstance(result, tuple) and len(result) == 4):
|
| 135 |
+
raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
|
| 136 |
+
generation_id_a, audio_a, generation_id_b, audio_b = result
|
|
|
|
| 137 |
else:
|
| 138 |
with ThreadPoolExecutor(max_workers=2) as executor:
|
| 139 |
num_generations = 1
|
| 140 |
+
# Generate a single Hume output.
|
| 141 |
future_audio_a = executor.submit(
|
| 142 |
text_to_speech_with_hume, character_description, text, num_generations, self.config
|
| 143 |
)
|
| 144 |
+
# Generate a second TTS output from the second provider.
|
| 145 |
match provider_b:
|
| 146 |
case constants.ELEVENLABS:
|
| 147 |
future_audio_b = executor.submit(
|
| 148 |
text_to_speech_with_elevenlabs, character_description, text, self.config
|
| 149 |
)
|
| 150 |
case _:
|
| 151 |
+
# Additional TTS Providers can be added here.
|
| 152 |
raise ValueError(f"Unsupported provider: {provider_b}")
|
| 153 |
|
| 154 |
+
result_a = future_audio_a.result()
|
| 155 |
+
result_b = future_audio_b.result()
|
| 156 |
+
if isinstance(result_a, tuple) and len(result_a) >= 2:
|
| 157 |
+
generation_id_a, audio_a = result_a[0], result_a[1]
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError("Unexpected return from text_to_speech_with_hume")
|
| 160 |
+
if isinstance(result_b, tuple) and len(result_b) >= 2:
|
| 161 |
+
generation_id_b, audio_b = result_b[0], result_b[1] # type: ignore
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError("Unexpected return from text_to_speech_with_elevenlabs")
|
| 164 |
+
|
| 165 |
+
# Shuffle options so that placement of options in the UI will always be random.
|
| 166 |
option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
|
| 167 |
option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)
|
| 168 |
options_map: OptionMap = create_shuffled_tts_options(option_a, option_b)
|
|
|
|
| 186 |
raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
|
| 187 |
except Exception as e:
|
| 188 |
logger.error(f"Unexpected error during TTS generation: {e}")
|
| 189 |
+
raise gr.Error("An unexpected error occurred. Please try again later.")
|
| 190 |
|
| 191 |
def _vote(
|
| 192 |
self,
|
|
|
|
| 196 |
text_modified: bool,
|
| 197 |
character_description: str,
|
| 198 |
text: str,
|
| 199 |
+
) -> Tuple[bool, dict, dict, dict]:
|
| 200 |
"""
|
| 201 |
Handles user voting.
|
| 202 |
|
|
|
|
| 208 |
'Option A': 'Hume AI',
|
| 209 |
'Option B': 'ElevenLabs',
|
| 210 |
}
|
| 211 |
+
clicked_option_button (str): The button that was clicked.
|
| 212 |
|
| 213 |
Returns:
|
| 214 |
A tuple of:
|
| 215 |
- A boolean indicating if the vote was accepted.
|
| 216 |
+
- A dict update for the selected vote button (showing provider and trophy emoji).
|
| 217 |
+
- A dict update for the unselected vote button (showing provider).
|
| 218 |
+
- A dict update for enabling vote interactions.
|
| 219 |
"""
|
|
|
|
| 220 |
if not option_map or vote_submitted:
|
| 221 |
return gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 222 |
|
|
|
|
| 224 |
selected_provider = option_map[selected_option]["provider"]
|
| 225 |
other_provider = option_map[other_option]["provider"]
|
| 226 |
|
| 227 |
+
# Report voting results to be persisted to results DB.
|
| 228 |
submit_voting_results(
|
| 229 |
option_map,
|
| 230 |
selected_option,
|
|
|
|
| 254 |
gr.update(interactive=True),
|
| 255 |
)
|
| 256 |
|
| 257 |
+
def _reset_ui(self) -> Tuple[dict, dict, dict, dict, OptionMap, bool]:
|
| 258 |
"""
|
| 259 |
Resets UI state before generating new text.
|
| 260 |
|
|
|
|
| 263 |
- option_a_audio_player (clear audio)
|
| 264 |
- option_b_audio_player (clear audio)
|
| 265 |
- vote_button_a (disable and reset button text)
|
| 266 |
+
- vote_button_b (disable and reset button text)
|
| 267 |
- option_map_state (reset option map state)
|
| 268 |
- vote_submitted_state (reset submitted vote state)
|
| 269 |
"""
|
| 270 |
+
default_option_map: OptionMap = {
|
| 271 |
+
"option_a": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
| 272 |
+
"option_b": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
| 273 |
+
}
|
| 274 |
return (
|
| 275 |
gr.update(value=None),
|
| 276 |
gr.update(value=None, autoplay=False),
|
| 277 |
gr.update(value=constants.SELECT_OPTION_A, variant="secondary"),
|
| 278 |
gr.update(value=constants.SELECT_OPTION_B, variant="secondary"),
|
| 279 |
+
default_option_map, # Reset option_map_state as a default OptionMap
|
| 280 |
False,
|
| 281 |
)
|
| 282 |
|
|
|
|
| 285 |
Builds the input section including the sample character description dropdown, character
|
| 286 |
description input, and generate text button.
|
| 287 |
"""
|
|
|
|
| 288 |
sample_character_description_dropdown = gr.Dropdown(
|
| 289 |
choices=list(constants.SAMPLE_CHARACTER_DESCRIPTIONS.keys()),
|
| 290 |
label="Choose a sample character description",
|
|
|
|
| 310 |
"""
|
| 311 |
Builds the output section including text input, audio players, and vote buttons.
|
| 312 |
"""
|
|
|
|
| 313 |
text_input = gr.Textbox(
|
| 314 |
label="Input Text",
|
| 315 |
placeholder="Enter or generate text for synthesis...",
|
|
|
|
| 343 |
Returns:
|
| 344 |
gr.Blocks: The fully constructed Gradio UI layout.
|
| 345 |
"""
|
|
|
|
| 346 |
custom_theme = CustomTheme()
|
| 347 |
with gr.Blocks(
|
| 348 |
title="Expressive TTS Arena",
|
|
|
|
| 384 |
) = self._build_output_section()
|
| 385 |
|
| 386 |
# --- UI state components ---
|
|
|
|
| 387 |
# Track character description used for text and voice generation
|
| 388 |
character_description_state = gr.State("")
|
| 389 |
# Track text used for speech synthesis
|
|
|
|
| 392 |
generated_text_state = gr.State("")
|
| 393 |
# Track whether text that was used was generated or modified/custom
|
| 394 |
text_modified_state = gr.State()
|
|
|
|
| 395 |
# Track option map (option A and option B are randomized)
|
| 396 |
+
option_map_state = gr.State({}) # OptionMap state as a dictionary
|
|
|
|
| 397 |
# Track whether the user has voted for an option
|
| 398 |
vote_submitted_state = gr.State(False)
|
| 399 |
|
|
|
|
| 503 |
inputs=[],
|
| 504 |
outputs=[vote_button_a, vote_button_b],
|
| 505 |
).then(
|
| 506 |
+
fn=self._vote,
|
| 507 |
inputs=[
|
| 508 |
vote_submitted_state,
|
| 509 |
option_map_state,
|