|
import gradio as gr |
|
import whisper |
|
import numpy as np |
|
import openai |
|
import os |
|
from gtts import gTTS |
|
import json |
|
import hashlib |
|
import random |
|
import string |
|
import uuid |
|
from datetime import date,datetime |
|
from huggingface_hub import Repository, upload_file |
|
import shutil |
|
from helpers import dict_origin |
|
|
|
# Hugging Face write token, read from the environment (None when the variable is unset).
HF_TOKEN_WRITE = os.environ.get("HF_TOKEN_WRITE")

# Report only whether the token is configured: printing the secret itself
# would leak it into the application logs.
print("HF_TOKEN_WRITE", "is set" if HF_TOKEN_WRITE else "is NOT set")

# Date stamp (YYYYMMDD) captured once, when this module is loaded.
today = date.today()
today_ymd = today.strftime("%Y%m%d")
|
|
|
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given string *name*."""
    return "".join(("Hello ", name, "!!"))
|
|
|
# Custom stylesheet later handed to gr.Blocks(css=...). The encoding is given
# explicitly so the result does not depend on the platform's default encoding.
with open('app.css', 'r', encoding='utf-8') as f:
    css_file = f.read()
|
|
|
# Markdown source for the page heading (rendered with gr.Markdown).
markdown = "\n# Polish ASR BIGOS workspace\n"
|
|
|
|
|
# Hugging Face dataset repository that receives the uploads.
REPO_NAME = "goodmike31/working-db"
WORKING_DATASET_REPO_URL = "https://huggingface.co/datasets/goodmike31/working-db"

# Top-level folder used for paths inside the dataset repository.
REPOSITORY_DIR = "data"

# Local folder for saved files; created at import time if it is missing.
LOCAL_DIR = "data_local"
os.makedirs(name=LOCAL_DIR, exist_ok=True)
|
|
|
def dump_json(thing, file):
    """Serialise *thing* as JSON into the path *file*.

    The file is opened in ``w+`` mode with UTF-8 encoding, so any previous
    content is replaced.
    """
    with open(file, mode='w+', encoding="utf8") as out:
        json.dump(obj=thing, fp=out)
|
|
|
def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(32)]
    return ''.join(picked)
|
|
|
def get_prompts(domain, type, size, language_code):
    """Return the promptset and its first prompt.

    The arguments are only logged; the promptset is currently a fixed
    two-item placeholder list.

    Returns:
        A ``(promptset, first_prompt)`` tuple.
    """
    print(f"Retrieving prompts for domain {domain} with method: {type} for language_code {language_code} of size {size}")
    prompts = ["test1", "test2"]
    return prompts, prompts[0]
|
|
|
def save_recording_and_meta(project_name, recording, transcript, language_code, spk_age, spk_accent, spk_city, spk_gender, spk_nativity, promptset=None, prompt_number=None):
    """Save a recording and its metadata locally, then upload both to the Hub dataset.

    The audio file is copied to ``LOCAL_DIR/<project>/<YYYYMMDD>/audio`` under a
    fresh UUID name, a JSON metadata file is written next to it under ``meta``,
    and both files are uploaded to ``REPO_NAME`` below ``REPOSITORY_DIR``.

    Args:
        project_name: Sub-folder name used locally and in the repository.
        recording: Path of the recorded audio file to store.
        transcript: Text associated with the recording; surrounding whitespace
            is stripped and ``None`` is treated as an empty string.
        language_code: Language code stored in the metadata.
        spk_age, spk_accent, spk_city, spk_gender, spk_nativity: Speaker
            attributes; empty or ``None`` values are stored as ``'unknown'``.
        promptset: Currently unused. Optional so that callers which have no
            promptset (the playground tab) can omit it.
        prompt_number: Currently unused. Optional for the same reason.

    Returns:
        The fixed list ``["Next prompt", 1, None]``.

    Raises:
        ValueError: If no recording was provided.
    """
    # Repository paths must use "/" separators regardless of the host OS.
    import posixpath

    if not recording:
        raise ValueError("No recording to save: record audio before saving.")

    def _or_unknown(value):
        # Empty/missing speaker attributes are stored under a common marker.
        return value if value not in (None, '') else 'unknown'

    transcript = (transcript or '').strip()

    # Take the date at call time; a module-level stamp goes stale once a
    # long-running server passes midnight.
    now = datetime.now()
    day_ymd = now.strftime("%Y%m%d")
    timestamp_str = now.strftime("%d/%m/%Y %H:%M:%S")

    save_root_dir = os.path.join(LOCAL_DIR, project_name, day_ymd)
    save_dir_audio = os.path.join(save_root_dir, "audio")
    save_dir_meta = os.path.join(save_root_dir, "meta")
    os.makedirs(save_dir_audio, exist_ok=True)
    os.makedirs(save_dir_meta, exist_ok=True)

    # One UUID names both the audio file and its metadata file.
    uuid_name = str(uuid.uuid4())
    audio_fn = uuid_name + ".wav"
    audio_output_fp = os.path.join(save_dir_audio, audio_fn)

    print(f"Saving {recording} as {audio_output_fp}")
    shutil.copy2(recording, audio_output_fp)

    # NOTE(review): there is no separator between the UUID and "metadata.jsonl";
    # kept as-is so new files match the names already produced by this function.
    meta_fn = uuid_name + 'metadata.jsonl'
    json_file_path = os.path.join(save_dir_meta, meta_fn)

    metadata = {
        'id': uuid_name,
        'audio_file': audio_fn,
        'language_code': language_code,
        'transcript': transcript,
        'age': _or_unknown(spk_age),
        'gender': _or_unknown(spk_gender),
        'accent': _or_unknown(spk_accent),
        'nativity': _or_unknown(spk_nativity),
        'city': _or_unknown(spk_city),
        "date": day_ymd,
        "timestamp": timestamp_str,
    }
    dump_json(metadata, json_file_path)

    # Upload the audio first, then its metadata, mirroring the local layout.
    repo_root = posixpath.join(REPOSITORY_DIR, project_name, day_ymd)
    uploads = (
        (audio_output_fp, posixpath.join(repo_root, "audio", audio_fn)),
        (json_file_path, posixpath.join(repo_root, "meta", meta_fn)),
    )
    for local_path, repo_path in uploads:
        upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=REPO_NAME,
            repo_type='dataset',
            token=HF_TOKEN_WRITE,
        )

    print(f"Recording {audio_fn} and meta file {meta_fn} successfully saved to repo!")

    return ["Next prompt", 1, None]
|
|
|
|
|
def whisper_model_change(radio_whisper_model):
    """Load the Whisper model named by *radio_whisper_model* and return it."""
    loaded = whisper.load_model(radio_whisper_model)
    return loaded
|
|
|
def prompt_gpt_assistant(input_text, api_key, temperature):
    """Send *input_text* to ``gpt-3.5-turbo`` and return the reply text.

    Args:
        input_text: User message; when empty, only the system message is sent.
        api_key: OpenAI API key, assigned to ``openai.api_key``.
        temperature: Sampling temperature forwarded to the completion call.

    Returns:
        The content of the first choice in the completion response.
    """
    # NOTE(review): this sets the key on the openai module itself, so it is
    # shared by every caller in the process — confirm that is intended.
    openai.api_key = api_key

    conversation = [{"role": "system", "content": "You are a helpful assistant"}]
    if input_text:
        conversation += [{"role": "user", "content": input_text}]

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
        temperature=temperature,
    )

    first_choice = response.choices[0]
    return first_choice.message.content
|
|
|
def voicebot_pipeline(audio, language_code="pl", whisper_model=None, whisper_model_type="base", api_key=None, temperature=0):
    """Run the full voicebot chain: speech -> text -> ChatGPT reply -> speech.

    The three stages each require more than one argument, so the extra
    settings are accepted here and forwarded. The defaults mirror the initial
    values of the UI state in this file ("pl", "base", temperature 0).

    Args:
        audio: Path of the recorded audio file to transcribe.
        language_code: Language used for both transcription and synthesis.
        whisper_model: Already-loaded Whisper model, or ``None`` to load one.
        whisper_model_type: Whisper model name used when no model is passed.
        api_key: OpenAI API key forwarded to ``prompt_gpt_assistant``.
        temperature: ChatGPT sampling temperature.

    Returns:
        Whatever ``synthesize_speech`` returns for the ChatGPT reply.
    """
    asr_out = transcribe(audio, language_code, whisper_model, whisper_model_type)
    gpt_out = prompt_gpt_assistant(asr_out, api_key, temperature)
    tts_out = synthesize_speech(gpt_out, language_code)
    return tts_out
|
|
|
def transcribe(audio, language_code, whisper_model, whisper_model_type):
    """Transcribe an audio file with Whisper and return the recognised text.

    Args:
        audio: Path of the audio file; when empty or ``None`` an empty string
            is returned without loading any model.
        language_code: Language passed to Whisper's decoding options.
        whisper_model: Loaded Whisper model; when falsy, one is created with
            ``init_whisper_model(whisper_model_type)``.
        whisper_model_type: Whisper model name used for loading and logging.

    Returns:
        The decoded text, or ``""`` when there is no audio.
    """
    if not audio:
        # Nothing has been recorded yet; loading a missing path would fail.
        return ""

    if not whisper_model:
        whisper_model = init_whisper_model(whisper_model_type)

    print(f"Transcribing {audio} for language_code {language_code} and model {whisper_model_type}")
    samples = whisper.load_audio(audio)
    samples = whisper.pad_or_trim(samples)

    # The spectrogram must live on the same device as the model, otherwise
    # decoding fails when the model was loaded onto a GPU.
    mel = whisper.log_mel_spectrogram(samples).to(whisper_model.device)

    options = whisper.DecodingOptions(language=language_code, without_timestamps=True, fp16=False)
    result = whisper.decode(whisper_model, mel, options)
    return result.text
|
|
|
def init_whisper_model(whisper_model_type):
    """Log the requested model name, then load and return that Whisper model."""
    print("Initializing whisper model", whisper_model_type, sep="\n")
    return whisper.load_model(whisper_model_type)
|
|
|
def synthesize_speech(text, language_code):
    """Synthesize *text* with gTTS and return the path of the resulting MP3.

    Args:
        text: Text to speak; when empty or whitespace-only nothing is
            synthesized.
        language_code: Language code passed to gTTS.

    Returns:
        Path of a newly created ``.mp3`` file, or ``None`` when there is no
        text to speak.
    """
    import tempfile

    if not text or not text.strip():
        return None

    speech = gTTS(text=text, lang=language_code, slow=False)

    # A unique file per call: a single fixed file name would be overwritten
    # when two requests are synthesized at the same time. The handle is closed
    # before gTTS writes to the path.
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
        output_path = tmp.name

    speech.save(output_path)
    return output_path
|
|
|
# Create the Blocks app with the custom stylesheet and declare its state.
with gr.Blocks(css=css_file) as block:

    # Prompt collection state.
    language_code = gr.State("pl")
    domain = gr.State()
    prompts_type = gr.State()
    promptset = gr.State("test.prompts.txt")
    prompt_history = gr.State()
    current_prompt = gr.State()
    prompt_number = gr.State()
    finished_recording = gr.State()

    # Model and service settings.
    temperature = gr.State(0)
    whisper_model_type = gr.State("base")
    whisper_model = gr.State()
    openai_api_key = gr.State()
    google_api_key = gr.State()
    azure_api_key = gr.State()
    project_name = gr.State("voicebot")

    # Speaker attributes, all starting as "unknown".
    spk_age = gr.State("unknown")
    spk_accent = gr.State("unknown")
    spk_city = gr.State("unknown")
    spk_gender = gr.State("unknown")
    spk_nativity = gr.State("unknown")

    # Alphabetically sorted city names from the imported dict_origin mapping.
    cities = sorted(dict_origin["Poland"]["cities"])
|
|
|
|
|
|
|
def change_domain(choice):
    """Log the selected promptset domain and return it unchanged."""
    print("Changing promptset domain to", choice, sep="\n")
    return choice
|
|
|
def change_prompts_type(choice):
    """Log the selected promptset type and return it unchanged."""
    print("Changing promptset type to", choice, sep="\n")
    return choice
|
|
|
def change_nativity(choice):
    """Log the selected speaker nativity and return it unchanged."""
    print("Changing speaker nativity to", choice, sep="\n")
    return choice
|
|
|
def change_accent(choice):
    """Log the selected speaker accent and return it unchanged."""
    print("Changing speaker accent to", choice, sep="\n")
    return choice
|
|
|
def change_age(choice):
    """Log the selected speaker age range and return it unchanged."""
    print("Changing speaker age to", choice, sep="\n")
    return choice
|
|
|
def change_city(choice):
    """Log the selected speaker home city and return it unchanged."""
    print("Changing speaker city to", choice, sep="\n")
    return choice
|
|
|
def change_gender(choice):
    """Log the selected speaker gender and return it unchanged."""
    print("Changing speaker gender to", choice, sep="\n")
    return choice
|
|
|
def change_language(choice):
    """Map the language radio selection to a language code.

    Args:
        choice: ``"Polish"`` or ``"English"``. Any other value, including
            ``None``, falls back to Polish, as the language selector's help
            text ("If none is selected, Polish is used") states.

    Returns:
        ``"pl"`` or ``"en"``.
    """
    codes = {"Polish": "pl", "English": "en"}

    if choice in codes:
        language_code = codes[choice]
        print(f"Switching to {choice}")
    else:
        # Previously no code was assigned on this path and the return failed.
        language_code = "pl"
        print(f"Unrecognised language choice {choice!r}; falling back to Polish")

    print("language_code")
    print(language_code)
    return language_code
|
|
|
def change_whisper_model(choice):
    """Log the selected Whisper model name and load that model.

    Returns:
        A two-item list: the model name and the model object returned by
        ``init_whisper_model``.
    """
    print("Switching Whisper model", choice, sep="\n")
    loaded_model = init_whisper_model(choice)
    return [choice, loaded_model]
|
|
|
# Page layout and event wiring. Everything below must be created inside the
# Blocks context so that components and listeners are registered on `block`.
# NOTE(review): the nesting of rows/accordions is reconstructed from statement
# order; confirm it against the intended layout.
with block:

    def _save_playground_recording(*args):
        """Save the playground recording and discard the helper's return value.

        save_recording_and_meta returns three values meant for the batch tab's
        widgets; the playground button has no outputs to receive them.
        """
        save_recording_and_meta(*args)

    gr.Markdown(markdown)

    with gr.Tabs():
        with gr.TabItem('General settings'):
            radio_lang = gr.Radio(["Polish", "English"], label="Language", info="If none is selected, Polish is used")
            radio_asr_type = gr.Radio(["Local", "Cloud"], label="Select ASR type", info="Cloud models are faster and more accurate, but costs money")

            with gr.Accordion(label="Local ASR settings", open=False):
                radio_whisper_model = gr.Radio(["tiny", "base", "small", "medium", "large"], label="Whisper ASR model (local)", info="Larger models are more accurate, but slower. Default - base")

            with gr.Accordion(label="Cloud ASR settings", open=False):
                radio_cloud_asr = gr.Radio(["Whisper", "Google", "Azure"], label="Select Cloud ASR provider", info="You need to provide API keys for specific service")

            with gr.Accordion(label="Cloud API Keys", open=False):
                # API keys are secrets, so the inputs are masked.
                # NOTE(review): the three boxes share elem_id "pw"; kept because
                # the stylesheet may target it.
                gr.HTML("<p class=\"apikey\">Open AI API Key:</p>")
                openai_api_key = gr.Textbox(label="", elem_id="pw", type="password")
                gr.HTML("<p class=\"apikey\">Google Cloud API Key:</p>")
                google_api_key = gr.Textbox(label="", elem_id="pw", type="password")
                gr.HTML("<p class=\"apikey\">Azure Cloud API Key:</p>")
                azure_api_key = gr.Textbox(label="", elem_id="pw", type="password")

            with gr.Accordion(label="Chat GPT settings", open=False):
                slider_temp = gr.Slider(minimum=0, maximum=2, step=0.2, label="ChatGPT temperature")

        with gr.TabItem('Speaker information'):
            with gr.Row():
                dropdown_spk_nativity = gr.Dropdown(["Polish", "Other"], label="Your native language", info="")
                dropdown_spk_gender = gr.Dropdown(["Male", "Female", "Other", "Prefer not to say"], label="Your gender", info="")
                dropdown_spk_age = gr.Dropdown(["under 20", "20-29", "30-39", "40-49", "50-59", "over 60"], label="Your age range", info="")
                dropdown_spk_origin_city = gr.Dropdown(cities, label="Your home city", visible=True, info="Specify the closest city your place of birth and upbringing")

            # Each dropdown updates its own state; nativity previously wrote
            # into spk_age and so overwrote the speaker's age.
            dropdown_spk_nativity.change(fn=change_nativity, inputs=dropdown_spk_nativity, outputs=spk_nativity)
            dropdown_spk_gender.change(fn=change_gender, inputs=dropdown_spk_gender, outputs=spk_gender)
            dropdown_spk_age.change(fn=change_age, inputs=dropdown_spk_age, outputs=spk_age)
            dropdown_spk_origin_city.change(fn=change_city, inputs=dropdown_spk_origin_city, outputs=spk_city)

        with gr.TabItem('Voicebot playground'):
            mic_recording = gr.Audio(source="microphone", type="filepath", label='Record your voice')
            with gr.Row():
                button_transcribe = gr.Button("Transcribe speech")
                button_save_audio_and_trans = gr.Button("Save audio recording and transcription")

            out_asr = gr.Textbox(placeholder="ASR output",
                                 lines=2,
                                 max_lines=5,
                                 show_label=False)

            with gr.Row():
                button_prompt_gpt = gr.Button("Prompt ChatGPT")
                button_save_gpt_response = gr.Button("Save ChatGPT response")

            out_gpt = gr.Textbox(placeholder="ChatGPT output",
                                 lines=4,
                                 max_lines=10,
                                 show_label=False)

            with gr.Row():
                button_synth_speech = gr.Button("Synthesize speech")
                button_save_synth_audio = gr.Button("Save synthetic audio")

            synth_recording = gr.Audio()

            # The save helper takes 11 arguments, so promptset and
            # prompt_number are passed as well; the wrapper drops its return
            # value because this button updates no widgets.
            button_save_audio_and_trans.click(_save_playground_recording, inputs=[project_name, mic_recording, out_asr, language_code, spk_age, spk_accent, spk_city, spk_gender, spk_nativity, promptset, prompt_number], outputs=[])
            button_transcribe.click(transcribe, inputs=[mic_recording, language_code, whisper_model, whisper_model_type], outputs=out_asr)
            button_prompt_gpt.click(prompt_gpt_assistant, inputs=[out_asr, openai_api_key, slider_temp], outputs=out_gpt)
            button_synth_speech.click(synthesize_speech, inputs=[out_gpt, language_code], outputs=synth_recording)

            radio_lang.change(fn=change_language, inputs=radio_lang, outputs=language_code)
            radio_whisper_model.change(fn=change_whisper_model, inputs=radio_whisper_model, outputs=[whisper_model_type, whisper_model])

        with gr.TabItem('Batch audio collection'):
            with gr.Accordion(label="Promptset settings"):
                radio_prompts_domain = gr.Dropdown(["Bridge"], label="Select promptset domain", info="")
                # This selector chooses the promptset type; its label previously
                # read "Language".
                radio_promptset_type = gr.Radio(["New promptset generation", "Existing promptset use"], label="Promptset type", value="Existing promptset use", info="New promptset is generated using ChatGPT")
                var_promptset_size = gr.Textbox(label="Specify number of prompts (min 10, max 200)")
                button_get_prompts = gr.Button("Save settings and get first prompt to record")

            prompt_text = gr.Textbox(placeholder='Prompt to be recorded', label="Prompt to be read during recording")
            speech_recording = gr.Audio(source="microphone", label="Select 'record from microphone' and read prompt displayed above", type="filepath")

            radio_prompts_domain.change(fn=change_domain, inputs=radio_prompts_domain, outputs=domain)
            radio_promptset_type.change(fn=change_prompts_type, inputs=radio_promptset_type, outputs=prompts_type)

            button_save_and_next = gr.Button("Save audio recording and move to the next prompt")
            button_get_prompts.click(get_prompts, inputs=[radio_prompts_domain, radio_promptset_type, var_promptset_size, language_code], outputs=[promptset, prompt_text])

            # The three values returned by the save helper update the prompt
            # text, the prompt number and the recorder.
            button_save_and_next.click(save_recording_and_meta, inputs=[project_name, speech_recording, prompt_text, language_code, spk_age, spk_accent, spk_city, spk_gender, spk_nativity, promptset, prompt_number], outputs=[prompt_text, prompt_number, speech_recording])

block.launch()