Pijush2023 committed
Commit 0465b6f · verified · 1 Parent(s): a4c99d4

Update app.py

Files changed (1): app.py +11 -336
app.py CHANGED
@@ -722,9 +722,8 @@ def generate_audio_elevenlabs(text):
     return None


- # Parler TTS integration
-
 repo_id = "parler-tts/parler-tts-mini-v1"
+
 parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
 parler_tokenizer = AutoTokenizer.from_pretrained(repo_id)
 parler_feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
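For orientation: in the standard parler-tts API these objects work together, with the voice description tokenized into `input_ids` and the text to speak into `prompt_input_ids`. A minimal non-streaming sketch (the two strings below are illustrative, not taken from app.py):

    # Minimal sketch, assuming the standard parler-tts inference API
    description = "A calm female voice, studio-quality recording."  # illustrative
    text = "Hello from RADAR."                                      # illustrative
    desc_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = parler_tokenizer(text, return_tensors="pt").input_ids.to(device)
    generation = parler_model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
    audio = generation.cpu().numpy().squeeze()  # waveform at parler_model.config.sampling_rate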
@@ -792,200 +791,6 @@ def generate_audio_parler_tts(text):
     logging.debug(f"Audio saved to {combined_audio_path}")
     return combined_audio_path

- # Streaming Parler-TTS with the Base Streamer
-
- import io
- import math
- from queue import Queue
- from threading import Thread
- from typing import Optional
-
- from transformers.generation.streamers import BaseStreamer
-
- class ParlerTTSStreamer(BaseStreamer):
-     def __init__(
-         self,
-         model: ParlerTTSForConditionalGeneration,
-         device: Optional[str] = None,
-         play_steps: Optional[int] = 10,
-         stride: Optional[int] = None,
-         timeout: Optional[float] = None,
-     ):
-         self.decoder = model.decoder
-         self.audio_encoder = model.audio_encoder
-         self.generation_config = model.generation_config
-         self.device = device if device is not None else model.device
-
-         self.play_steps = play_steps
-         if stride is not None:
-             self.stride = stride
-         else:
-             hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
-             self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
-         self.token_cache = None
-         self.to_yield = 0
-
-         self.audio_queue = Queue()
-         self.stop_signal = None
-         self.timeout = timeout
-
-     def apply_delay_pattern_mask(self, input_ids):
-         _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
-             input_ids[:, :1],
-             bos_token_id=self.generation_config.bos_token_id,
-             pad_token_id=self.generation_config.decoder_start_token_id,
-             max_length=input_ids.shape[-1],
-         )
-         input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
-
-         mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
-         input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
-         input_ids = input_ids[None, ...]
-
-         input_ids = input_ids.to(self.audio_encoder.device)
-
-         decode_sequentially = (
-             self.generation_config.bos_token_id in input_ids
-             or self.generation_config.pad_token_id in input_ids
-             or self.generation_config.eos_token_id in input_ids
-         )
-         if not decode_sequentially:
-             output_values = self.audio_encoder.decode(
-                 input_ids,
-                 audio_scales=[None],
-             )
-         else:
-             sample = input_ids[:, 0]
-             sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
-             sample = sample[:, :, sample_mask]
-             output_values = self.audio_encoder.decode(sample[None, ...], [None])
-
-         audio_values = output_values.audio_values[0, 0]
-         return audio_values.cpu().float().numpy()
-
-     def put(self, value):
-         batch_size = value.shape[0] // self.decoder.num_codebooks
-         if batch_size > 1:
-             raise ValueError("ParlerTTSStreamer only supports batch size 1")
-
-         if self.token_cache is None:
-             self.token_cache = value
-         else:
-             self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
-
-         if self.token_cache.shape[-1] % self.play_steps == 0:
-             audio_values = self.apply_delay_pattern_mask(self.token_cache)
-             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
-             self.to_yield += len(audio_values) - self.to_yield - self.stride
-
-     def end(self):
-         if self.token_cache is not None:
-             audio_values = self.apply_delay_pattern_mask(self.token_cache)
-         else:
-             audio_values = np.zeros(self.to_yield)
-
-         self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
-
-     def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
-         self.audio_queue.put(audio, timeout=self.timeout)
-         if stream_end:
-             self.audio_queue.put(self.stop_signal, timeout=self.timeout)
-
-     def __iter__(self):
-         return self
-
-     def __next__(self):
-         value = self.audio_queue.get(timeout=self.timeout)
-         if not isinstance(value, np.ndarray) and value == self.stop_signal:
-             raise StopIteration()
-         else:
-             return value
-
- def numpy_to_mp3(audio_array, sampling_rate):
-     if np.issubdtype(audio_array.dtype, np.floating):
-         max_val = np.max(np.abs(audio_array))
-         audio_array = (audio_array / max_val) * 32767
-         audio_array = audio_array.astype(np.int16)
-
-     audio_segment = AudioSegment(
-         audio_array.tobytes(),
-         frame_rate=sampling_rate,
-         sample_width=audio_array.dtype.itemsize,
-         channels=1
-     )
-
-     mp3_io = io.BytesIO()
-     audio_segment.export(mp3_io, format="mp3", bitrate="320k")
-
-     mp3_bytes = mp3_io.getvalue()
-     mp3_io.close()
-
-     return mp3_bytes
-
- sampling_rate = model.audio_encoder.config.sampling_rate
- frame_rate = model.audio_encoder.config.frame_rate
-
- def generate_base(text, description, play_steps_in_s=2.0):
-     play_steps = int(frame_rate * play_steps_in_s)
-     streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
-
-     inputs = parler_tokenizer(description, return_tensors="pt").to(device)
-     prompt = parler_tokenizer(text, return_tensors="pt").to(device)
-
-     generation_kwargs = dict(
-         input_ids=inputs.input_ids,
-         prompt_input_ids=prompt.input_ids,
-         streamer=streamer,
-         do_sample=True,
-         temperature=1.0,
-         min_new_tokens=10,
-     )
-
-     set_seed(SEED)
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     for new_audio in streamer:
-         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
-         yield numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
-
- css = """
- #share-btn-container {
-     display: flex;
-     padding-left: 0.5rem !important;
-     padding-right: 0.5rem !important;
-     background-color: #000000;
-     justify-content: center;
-     align-items: center;
-     border-radius: 9999px !important;
-     width: 13rem;
-     margin-top: 10px;
-     margin-left: auto;
-     flex: unset !important;
- }
- #share-btn {
-     all: initial;
-     color: #ffffff;
-     font-weight: 600;
-     cursor: pointer;
-     font-family: 'IBM Plex Sans', sans-serif;
-     margin-left: 0.5rem !important;
-     padding-top: 0.25rem !important;
-     padding-bottom: 0.25rem !important;
-     right:0;
- }
- #share-btn * {
-     all: unset !important;
- }
- #share-btn-container div:nth-child(-n+2){
-     width: auto !important;
-     min-height: 0px !important;
- }
- #share-btn-container .wrap {
-     display: none !important;
- }
- """
-



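The streamer deleted above emitted a chunk every `play_steps` decoder steps and trimmed `stride` samples of overlap between consecutive chunks. A back-of-envelope sketch of that sizing, with codec values assumed from parler-tts-mini-v1's DAC config (they are not stated in this diff):

    import math

    sampling_rate = 44100   # assumed audio_encoder.config.sampling_rate
    frame_rate = 86         # assumed audio_encoder.config.frame_rate
    num_codebooks = 9       # assumed decoder.num_codebooks

    play_steps = int(frame_rate * 2.0)                       # 172 decoder steps ~ 2 s of audio
    hop_length = math.floor(sampling_rate / frame_rate)      # 512 samples per codec frame
    stride = hop_length * (play_steps - num_codebooks) // 6  # 13909 samples of overlap trimmed
    print(play_steps, hop_length, stride)                    # 172 512 13909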
@@ -1325,136 +1130,11 @@ def fetch_google_flights(departure_id="JFK", arrival_id="BHM", outbound_date=cur

     return flight_info

- # with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
- #     with gr.Row():
- #         with gr.Column():
- #             state = gr.State()
-
- #             chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
- #             choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
- #             retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
- #             model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")
-
-
- #             # Link the dropdown change to handle_retrieval_mode_change
- #             model_choice.change(fn=handle_retrieval_mode_change, inputs=model_choice, outputs=[retrieval_mode, choice])
-
-
- #             gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
-
- #             chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
- #             tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta"], value="Alpha")
- #             retriever_button = gr.Button("Retriever")
-
- #             clear_button = gr.Button("Clear")
- #             clear_button.click(lambda:[None,None], outputs=[chat_input, state])
-
- #             gr.Markdown("<h1 style='color: red;'>Radar Map</h1>", elem_id="Map-Radar")
- #             location_output = gr.HTML()
- #             audio_output = gr.Audio(interactive=False, autoplay=True)
-
- #             def stop_audio():
- #                 audio_output.stop()
- #                 return None
-
- #             retriever_sequence = (
- #                 retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output], api_name="Ask_Retriever")
- #                 .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query")
- #                 .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response")
- #                 .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder")
- #                 .then(fn=clear_textbox, inputs=[], outputs=[chat_input])
- #             )
-
- #             chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output])
- #             chat_input.submit(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query").then(
- #                 fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response"
- #             ).then(
- #                 fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder"
- #             ).then(
- #                 fn=clear_textbox, inputs=[], outputs=[chat_input]
- #             )
-
- #             audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
- #             audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
-
- #             # retrieval_mode.change(fn=handle_retrieval_mode_change, inputs=retrieval_mode, outputs=[choice, choice])
- #             model_choice.change(fn=handle_retrieval_mode_change, inputs=model_choice, outputs=[choice, retrieval_mode])
-
-
- #         # with gr.Column():
- #         #     weather_output = gr.HTML(value=fetch_local_weather())
- #         #     news_output = gr.HTML(value=fetch_local_news())
- #         #     events_output = gr.HTML(value=fetch_local_events())
-
-

- # demo.queue()
- # demo.launch(share=True)



- # with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
- #     with gr.Row():
- #         with gr.Column():
- #             state = gr.State()
-
- #             chatbot = gr.Chatbot([], elem_id="RADAR:Channel 94.1", bubble_full_width=False)
- #             choice = gr.Radio(label="Select Style", choices=["Details", "Conversational"], value="Conversational")
- #             retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
- #             model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")
-
- #             # Link the dropdown change to handle_model_choice_change
- #             model_choice.change(fn=handle_model_choice_change, inputs=model_choice, outputs=[retrieval_mode, choice, choice])
-
- #             gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
-
- #             chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
- #             tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta"], value="Alpha")
- #             retriever_button = gr.Button("Retriever")
-
- #             clear_button = gr.Button("Clear")
- #             clear_button.click(lambda: [None, None], outputs=[chat_input, state])
-
- #             gr.Markdown("<h1 style='color: red;'>Radar Map</h1>", elem_id="Map-Radar")
- #             location_output = gr.HTML()
- #             audio_output = gr.Audio(interactive=False, autoplay=True)
-
- #             def stop_audio():
- #                 audio_output.stop()
- #                 return None
-
- #             retriever_sequence = (
- #                 retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output], api_name="Ask_Retriever")
- #                 .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query")
- #                 .then(fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response")
- #                 .then(fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder")
- #                 .then(fn=clear_textbox, inputs=[], outputs=[chat_input])
- #             )
-
- #             chat_input.submit(fn=stop_audio, inputs=[], outputs=[audio_output])
- #             chat_input.submit(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query").then(
- #                 fn=bot, inputs=[chatbot, choice, tts_choice, retrieval_mode, model_choice], outputs=[chatbot, audio_output], api_name="generate_voice_response"
- #             ).then(
- #                 fn=show_map_if_details, inputs=[chatbot, choice], outputs=[location_output, location_output], api_name="map_finder"
- #             ).then(
- #                 fn=clear_textbox, inputs=[], outputs=[chat_input]
- #             )
-
- #             audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
- #             audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")
-
- #         # with gr.Column():
- #         #     weather_output = gr.HTML(value=fetch_local_weather())
- #         #     news_output = gr.HTML(value=fetch_local_news())
- #         #     events_output = gr.HTML(value=fetch_local_events())
-
- # demo.queue()
- # demo.launch(share=True)
-
-
-
-
- with gr.Blocks(theme='Pijush2023/scikit-learn-pijush', css=css) as demo:
+ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
     with gr.Row():
         with gr.Column():
             state = gr.State()
@@ -1464,6 +1144,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush', css=css) as demo:
             retrieval_mode = gr.Radio(label="Retrieval Mode", choices=["VDB", "KGF"], value="VDB")
             model_choice = gr.Dropdown(label="Choose Model", choices=["GPT-4o", "Phi-3.5"], value="GPT-4o")

+             # Link the dropdown change to handle_model_choice_change
             model_choice.change(fn=handle_model_choice_change, inputs=model_choice, outputs=[retrieval_mode, choice, choice])

             gr.Markdown("<h1 style='color: red;'>Talk to RADAR</h1>", elem_id="voice-markdown")
@@ -1479,6 +1160,10 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush', css=css) as demo:
             location_output = gr.HTML()
             audio_output = gr.Audio(interactive=False, autoplay=True)

+             def stop_audio():
+                 audio_output.stop()
+                 return None
+
             retriever_sequence = (
                 retriever_button.click(fn=stop_audio, inputs=[], outputs=[audio_output], api_name="Ask_Retriever")
                 .then(fn=add_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], api_name="voice_query")
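The `retriever_sequence` context above relies on Gradio event chaining: `.click()` registers the first handler and each `.then()` runs after the previous one returns. A self-contained illustration of the pattern (component and function names here are hypothetical, not from app.py):

    import gradio as gr

    def add(msg, history):
        # Append the user turn and clear the textbox
        return history + [[msg, None]], ""

    def answer(history):
        # Fill in the assistant turn after add() has finished
        history[-1][1] = "echo: " + history[-1][0]
        return history

    with gr.Blocks() as demo:
        chat = gr.Chatbot([])
        box = gr.Textbox()
        btn = gr.Button("Ask")
        btn.click(fn=add, inputs=[box, chat], outputs=[chat, box]).then(
            fn=answer, inputs=[chat], outputs=[chat]
        )

    demo.queue()
    demo.launch()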
@@ -1499,20 +1184,10 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush', css=css) as demo:
             audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1)
             audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="voice_query_to_text")

-         with gr.Column():
-             with gr.Tab("Base"):
-                 with gr.Row():
-                     with gr.Column():
-                         input_text = gr.Textbox(label="Input Text", lines=2, value="Please surprise me and speak in whatever voice you enjoy.", elem_id="input_text")
-                         description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
-                         play_seconds = gr.Slider(3.0, 7.0, value=3.0, step=2, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps")
-                         run_button = gr.Button("Generate Audio", variant="primary")
-                     with gr.Column():
-                         audio_out = gr.Audio(label="Parler-TTS generation", format="mp3", elem_id="audio_out", streaming=True, autoplay=True)
-
-             inputs = [input_text, description, play_seconds]
-             outputs = [audio_out]
-             run_button.click(fn=generate_base, inputs=inputs, outputs=outputs, queue=True)
+         # with gr.Column():
+         #     weather_output = gr.HTML(value=fetch_local_weather())
+         #     news_output = gr.HTML(value=fetch_local_news())
+         #     events_output = gr.HTML(value=fetch_local_events())

 demo.queue()
 demo.launch(share=True)
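Dropping `css=css` from the `gr.Blocks(...)` call is forced by the hunk at old line 792 above, which deletes the `css = """..."""` definition; keeping the argument would raise a NameError. A condensed sketch of the top-level structure after this commit, assembled from the context lines above (the real file defines many more components and callbacks):

    import gradio as gr

    with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
        with gr.Row():
            with gr.Column():
                state = gr.State()
                # ... chatbot, style/retrieval radios, TTS choice, buttons,
                # map HTML, and audio wiring as shown in the hunks above ...

    demo.queue()
    demo.launch(share=True)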
 