# Spaces: Runtime error
import gradio as gr | |
from transformers import pipeline, Wav2Vec2ProcessorWithLM | |
from pyannote.audio import Pipeline | |
from librosa import load, resample | |
from rpunct import RestorePuncts | |
# --- Audio components -------------------------------------------------------
asr_model = 'patrickvonplaten/wav2vec2-base-960h-4-gram'
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
# The ASR pipeline is assembled from the processor's own tokenizer, feature
# extractor and decoder.
asr = pipeline(
    'automatic-speech-recognition',
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)
speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
rpunct = RestorePuncts()

# --- Text components --------------------------------------------------------
sentiment_pipeline = pipeline(
    'text-classification',
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
# A classifier score must exceed this value for sentiment() to report a label.
sentiment_threshold = 0.75

# Audio files offered as clickable examples in the UI.
EXAMPLES = ["example_audio.wav"]
def speech_to_text(speech):
    """Transcribe an audio file and group the words by speaker turn.

    Args:
        speech: Path of the audio file. It is passed to the speaker
            segmentation pipeline and to ``librosa.load``.

    Returns:
        A ``(diarized_output, full_text)`` tuple. ``diarized_output`` is a
        flat list in which every turn that contains words contributes two
        entries: ``(punctuated_text, speaker)`` followed by
        ``('from START-END', None)``. ``full_text`` is the lower-cased
        transcript of the whole recording.
    """
    speaker_output = speaker_segmentation(speech)

    # Decode directly at 16 kHz, the rate the audio was previously resampled
    # to before ASR. Without ``sr`` librosa.load resamples to its own default
    # rate first, so the audio was resampled twice; the former positional
    # ``resample(y, orig, target)`` call is also rejected by newer librosa
    # releases, where those arguments are keyword-only.
    waveform, _ = load(speech, sr=16000)

    text = asr(waveform, return_timestamps="word")
    full_text = text['text'].lower()
    chunks = text['chunks']

    diarized_output = []
    i = 0  # index of the next word chunk not yet assigned to a turn
    speaker_counter = 0
    # New iteration every time the speaker changes
    for turn, _, _ in speaker_output.itertracks(yield_label=True):
        # The diarization label is ignored; speakers are named by strict
        # alternation between two voices.
        speaker = "Speaker 0" if speaker_counter % 2 == 0 else "Speaker 1"

        # Collect every remaining word that ends before this turn does.
        words = []
        while i < len(chunks) and chunks[i]['timestamp'][1] <= turn.end:
            words.append(chunks[i]['text'].lower())
            i += 1
        if not words:
            continue

        # Each word is followed by a single space, trailing one included.
        diarized = rpunct.punctuate(' '.join(words) + ' ')
        diarized_output.extend([
            (diarized, speaker),
            ('from {:.2f}-{:.2f}'.format(turn.start, turn.end), None),
        ])
        # NOTE(review): the counter is advanced only for turns that produced
        # text; the original indentation was lost, so confirm this placement.
        speaker_counter += 1
    return diarized_output, full_text
def sentiment(checked_options, diarized):
    """Run sentiment classification on the selected speaker's utterances.

    Args:
        checked_options: Speaker label to analyse; compared for equality with
            the second item of every entry in ``diarized``.
        diarized: Iterable of two-item ``(text, speaker_label)`` entries.

    Returns:
        A list of ``(text, label)`` tuples, one per entry whose speaker label
        equals ``checked_options``. ``label`` is the classifier's label when
        it is not ``"neutral"`` and its score exceeds ``sentiment_threshold``;
        otherwise it is ``None``.
    """
    customer_id = checked_options
    customer_sentiments = []
    for speaker_speech, speaker_id in diarized:
        if speaker_id != customer_id:
            continue
        output = sentiment_pipeline(speaker_speech)[0]
        is_confident = (
            output["label"] != "neutral"
            and output["score"] > sentiment_threshold
        )
        # list.append takes exactly one argument, so the low-confidence case
        # must append a (text, None) tuple rather than pass two arguments.
        # NOTE(review): the fallback is assumed to belong to the confidence
        # check, not the speaker check; the original indentation was lost.
        customer_sentiments.append(
            (speaker_speech, output["label"] if is_confident else None)
        )
    return customer_sentiments
# Build the Blocks app: inputs on the left, results on the right.
demo = gr.Blocks(enable_queue=True)
demo.encrypt = False

with demo:
    with gr.Row():
        # Left column: audio input, transcribe button and example picker.
        with gr.Column():
            audio = gr.Audio(label="Audio file", type='filepath')
            with gr.Row():
                btn = gr.Button("Transcribe")
            with gr.Row():
                examples = gr.components.Dataset(
                    components=[audio],
                    samples=[EXAMPLES],
                    type="index",
                )
        # Right column: transcripts and the per-speaker sentiment view.
        with gr.Column():
            gr.Markdown("**Diarized Output:**")
            diarized = gr.HighlightedText(lines=5, label="Diarized Output")
            full = gr.Textbox(lines=4, label="Full Transcript")
            check = gr.Radio(
                ["Speaker 0", "Speaker 1"],
                label='Choose speaker for sentiment analysis',
            )
            analyzed = gr.HighlightedText(label="Customer Sentiment")

    # Event wiring.
    btn.click(speech_to_text, audio, [diarized, full])
    check.change(sentiment, [check, diarized], analyzed)

    def cache_example(example):
        """Preprocess one example and transcribe it ahead of time."""
        processed_examples = audio.preprocess_example(example)
        transcription = speech_to_text(processed_examples)
        diarized_output, full_text = transcription
        return processed_examples, diarized_output, full_text

    # Every example is transcribed once at start-up so clicks are instant.
    cache = list(map(cache_example, EXAMPLES))

    def load_example(example_id):
        """Return the precomputed outputs for the example at ``example_id``."""
        return cache[example_id]

    examples._click_no_postprocess(
        load_example,
        inputs=[examples],
        outputs=[audio, diarized, full],
        queue=False,
    )

demo.launch()