File size: 6,166 Bytes
6647d0d
b371670
6647d0d
81301f1
8b6063a
6647d0d
 
 
 
 
 
81301f1
6647d0d
 
b008c2b
 
81301f1
8b6063a
6647d0d
 
 
81301f1
6647d0d
8b6063a
b371670
 
 
 
 
8b6063a
6647d0d
8b6063a
6647d0d
 
 
 
 
 
 
 
 
 
 
 
b008c2b
6647d0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b008c2b
6647d0d
 
 
 
 
 
 
 
 
 
 
 
 
 
8b6063a
6647d0d
 
 
 
 
 
 
 
81301f1
8b6063a
b371670
 
81301f1
 
 
6647d0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81301f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use async_stream::stream;
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput;
use aws_sdk_transcribestreaming::primitives::Blob;
use aws_sdk_transcribestreaming::types::{
    AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream,
};
use tokio::select;
use tokio::sync::broadcast::Receiver;
use tokio_stream::Stream;
use futures_util::TryStreamExt;
use tracing::{trace, warn};
use crate::asr::{ASR, Event, slice_i16_to_u8_le};

/// Speech recognizer that streams PCM audio to Amazon Transcribe and rebroadcasts the
/// resulting transcripts.
///
/// Construct it with `AwsAsr::from_env`. The transcription itself runs in a background task.
/// This handle only feeds that task audio, hands out transcript subscriptions, and stops the
/// task when dropped.
pub struct AwsAsr {
    // Queue of raw `i16` sample frames consumed by the background task's audio stream.
    speaker_voice_channel: tokio::sync::mpsc::Sender<Vec<i16>>,
    // Broadcast hub for transcript events; `subscribe` hands out receivers from it, and the
    // background task publishes through a clone of this sender.
    speaker_transcript: tokio::sync::broadcast::Sender<Event>,
    // One-shot stop signal for the background task; taken and fired in `Drop`.
    drop_handler: Option<tokio::sync::oneshot::Sender<()>>,
}

impl Debug for AwsAsr {
    /// Prints the fixed tag `AWS_ASR`; none of the channel internals are exposed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("AWS_ASR")
    }
}

impl AwsAsr {
    pub async fn from_env(lang: LanguageCode) -> anyhow::Result<Self> {
        let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
        let transcript_client = aws_sdk_transcribestreaming::Client::new(&config);

        let (speaker_voice_channel, mut speaker_voice_rx) = tokio::sync::mpsc::channel::<Vec<i16>>(128);
        let (speaker_transcript, _) = tokio::sync::broadcast::channel::<Event>(128);
        let shared_speaker_transcript = speaker_transcript.clone();

        let (drop_handler, drop_rx) = tokio::sync::oneshot::channel::<()>();

        tokio::spawn(async move {
            let fut = async {
                let input_stream = stream! {
                    while let Some(raw) = speaker_voice_rx.recv().await {
                        let reshape = slice_i16_to_u8_le(&raw);
                        yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(reshape)).build()));
                    }
                };
                let output = transcript_client
                    .start_stream_transcription()
                    .language_code(lang) //LanguageCode::EnGb
                    .media_sample_rate_hertz(16000)
                    .media_encoding(MediaEncoding::Pcm)
                    .audio_stream(input_stream.into())
                    .send()
                    .await
                    .map_err(|e| StreamTranscriptionError::EstablishStreamError(Box::new(e)))?;

                let output_stream = to_stream(output);
                output_stream
                    .try_for_each(|text| async {
                        let _ = shared_speaker_transcript.send(text);
                        Ok(())
                    })
                    .await?;
                Ok(()) as anyhow::Result<()>
            };
            select! {
                res = fut => {
                    if let Err(e) = res {
                        warn!("Error: {:?}", e);
                    }
                }
                _ = drop_rx => {}
            }
        });

        Ok(Self {
            speaker_voice_channel,
            speaker_transcript,
            drop_handler: Some(drop_handler)
        })
    }
}

impl Drop for AwsAsr {
    /// Signals the background transcription task spawned by `from_env` to stop.
    fn drop(&mut self) {
        let Some(stop_signal) = self.drop_handler.take() else {
            return;
        };
        // If the task has already finished, nobody is listening and the send error is moot.
        let _ = stop_signal.send(());
    }
}


#[async_trait]
impl ASR for AwsAsr {
    /// Queues one frame of `i16` PCM samples for upload to Amazon Transcribe.
    ///
    /// Waits for room when the bounded frame queue is full. Returns an error once the
    /// receiving side of the queue (owned by the background task) has been dropped.
    async fn frame(&mut self, frame: Vec<i16>) -> anyhow::Result<()> {
        self.speaker_voice_channel.send(frame).await?;
        Ok(())
    }

    /// Returns a new receiver that gets transcript events published after this call.
    fn subscribe(&mut self) -> Receiver<Event> {
        let publisher = &self.speaker_transcript;
        publisher.subscribe()
    }
}

/// Adapts the transcript result stream of a started transcription into a stream of [`Event`]s.
///
/// For each `TranscriptEvent`, every result that carries at least one alternative with a
/// transcript yields one `Event`. The `Event` takes its text from the first alternative, and
/// `is_final` is the negation of the result's `is_partial` flag. Results without alternatives
/// or without text are skipped.
///
/// Any event variant other than `TranscriptEvent` yields
/// `Err(StreamTranscriptionError::Unknown)`. A receive error from the SDK is mapped to
/// `StreamTranscriptionError::TranscriptResultStreamError` and propagated with `?`. The
/// stream ends when the SDK reports that no further events will arrive.
///
/// NOTE(review): `from_env` consumes this with `try_for_each`, which stops at the first `Err`.
/// An unrecognized event type therefore currently ends the whole transcription session.
fn to_stream(
    mut output: StartStreamTranscriptionOutput,
) -> impl Stream<Item = Result<Event, StreamTranscriptionError>> {
    stream! {
        while let Some(event) = output
            .transcript_result_stream
            .recv()
            .await
            .map_err(|e| StreamTranscriptionError::TranscriptResultStreamError(Box::new(e)))? {
            let transcript_event = match event {
                TranscriptResultStream::TranscriptEvent(transcript_event) => transcript_event,
                _ => {
                    yield Err(StreamTranscriptionError::Unknown);
                    continue;
                }
            };

            // A missing transcript or missing result list simply produces no events.
            let results = transcript_event
                .transcript
                .and_then(|transcript| transcript.results)
                .unwrap_or_default();

            for result in results {
                let first_text = result
                    .alternatives
                    .as_deref()
                    .and_then(|alternatives| alternatives.first())
                    .and_then(|alternative| alternative.transcript.clone());
                if let Some(text) = first_text {
                    yield Ok(Event {
                        transcript: text,
                        is_final: !result.is_partial,
                    });
                }
            }
        }
    }
}


/// Failures of the Amazon Transcribe streaming session driven by `AwsAsr`.
#[derive(Debug)]
enum StreamTranscriptionError {
    /// Sending the `StartStreamTranscription` request failed; the SDK error is the source.
    EstablishStreamError(Box<dyn Error + Send + Sync>),
    /// Receiving from the transcript result stream failed; the SDK error is the source.
    TranscriptResultStreamError(Box<dyn Error + Send + Sync>),
    /// The result stream delivered an event variant other than `TranscriptEvent`.
    Unknown,
}

impl Display for StreamTranscriptionError {
    /// Writes only the variant label.
    ///
    /// The wrapped cause is exposed through [`Error::source`] instead of being repeated
    /// here. Chain-printing reporters (for example anyhow's `{:?}` output) then show each
    /// message exactly once.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::EstablishStreamError(_) => "EstablishStreamError",
            Self::TranscriptResultStreamError(_) => "TranscriptResultStreamError",
            Self::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}

impl Error for StreamTranscriptionError {
    /// Returns the wrapped SDK error, or `None` for [`StreamTranscriptionError::Unknown`].
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::EstablishStreamError(cause) | Self::TranscriptResultStreamError(cause) => {
                Some(cause.as_ref())
            }
            Self::Unknown => None,
        }
    }
}