From c1349c930a3cfe1b4edf4113ff20b4351ee9533d Mon Sep 17 00:00:00 2001
From: louistiti
Date: Fri, 24 May 2024 01:26:28 +0800
Subject: [PATCH] feat: support incremental active listening for ASR

---
 server/src/core/stt/parsers/local-parser.ts |  2 +-
 .../tts/synthesizers/local-synthesizer.ts   |  2 +
 tcp_server/src/lib/asr/api.py               | 37 +++++++++++++++----
 tcp_server/src/lib/tcp_server.py            | 37 +++++++++++++------
 tcp_server/src/main.py                      |  3 +-
 5 files changed, 60 insertions(+), 21 deletions(-)

diff --git a/server/src/core/stt/parsers/local-parser.ts b/server/src/core/stt/parsers/local-parser.ts
index 0a57decf..91a66f8d 100644
--- a/server/src/core/stt/parsers/local-parser.ts
+++ b/server/src/core/stt/parsers/local-parser.ts
@@ -22,7 +22,7 @@ export default class LocalParser extends STTParserBase {
    * Read audio buffer and return the transcript (decoded string)
    */
   public async parse(): Promise<string | null> {
-    const wakeWordEventName = 'asr-wake-word-detected'
+    const wakeWordEventName = 'asr-new-speech'
     const endOfOwnerSpeechDetected = 'asr-end-of-owner-speech-detected'
     const wakeWordEventHasListeners =
       PYTHON_TCP_CLIENT.ee.listenerCount(wakeWordEventName) > 0
diff --git a/server/src/core/tts/synthesizers/local-synthesizer.ts b/server/src/core/tts/synthesizers/local-synthesizer.ts
index f0456342..e687bb55 100644
--- a/server/src/core/tts/synthesizers/local-synthesizer.ts
+++ b/server/src/core/tts/synthesizers/local-synthesizer.ts
@@ -59,6 +59,8 @@ export default class LocalSynthesizer extends TTSSynthesizerBase {
     try {
       const duration = await this.getAudioDuration(outputPath)
       TTS.em.emit('saved', duration)
+
+      PYTHON_TCP_CLIENT.emit('leon-speech-audio-ended', duration / 1_000)
     } catch (e) {
       LogHelper.title(this.name)
       LogHelper.warning(`Failed to get audio duration: ${e}`)
diff --git a/tcp_server/src/lib/asr/api.py b/tcp_server/src/lib/asr/api.py
index 2f0fc877..69e49b3e 100644
--- a/tcp_server/src/lib/asr/api.py
+++ b/tcp_server/src/lib/asr/api.py
@@ -11,7 +11,7 @@ class ASR:
     def __init__(self,
                  device='auto',
                  transcription_callback=None,
-                 wake_word_callback=None,
+                 wake_word_or_active_listening_callback=None,
                  end_of_owner_speech_callback=None):
         tic = time.perf_counter()
         self.log('Loading model...')
@@ -36,7 +36,7 @@ class ASR:
         self.compute_type = compute_type

         self.transcription_callback = transcription_callback
-        self.wake_word_callback = wake_word_callback
+        self.wake_word_or_active_listening_callback = wake_word_or_active_listening_callback
         self.end_of_owner_speech_callback = end_of_owner_speech_callback
         self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon", "heilion", "alion", "hyleon"]

@@ -48,6 +48,7 @@ class ASR:
         self.is_voice_activity_detected = False
         self.silence_start_time = 0
         self.is_wake_word_detected = False
+        self.is_active_listening_enabled = False

         self.saved_utterances = []
         self.segment_text = ''
@@ -56,7 +57,14 @@ class ASR:
         self.rate = 16000
         self.chunk = 4096
         self.threshold = 128
-        self.silence_duration = 1.5  # duration of silence in seconds
+        # Duration of silence after which the audio data is considered a new utterance (in seconds)
+        self.silence_duration = 1.5
+        """
+        Duration of silence after which active listening is stopped (in seconds).
+        Once stopped, active listening can be resumed by saying the wake word again
+        """
+        self.base_active_listening_duration = 12
+        self.active_listening_duration = self.base_active_listening_duration
         self.buffer_size = 64  # Size of the circular buffer

         self.audio = pyaudio.PyAudio()
@@ -117,10 +125,15 @@ class ASR:
             if self.is_wake_word_detected:
                 self.utterance.append(self.segment_text)
                 self.transcription_callback(" ".join(self.utterance))
-            if self.detect_wake_word(segment.text):
-                self.log('Wake word detected')
-                self.wake_word_callback(segment.text)
+
+            has_detected_wake_word = self.detect_wake_word(segment.text)
+            if has_detected_wake_word or self.is_active_listening_enabled:
+                if has_detected_wake_word:
+                    self.log('Wake word detected')
+                self.wake_word_or_active_listening_callback(segment.text)
                 self.is_wake_word_detected = True
+                self.is_active_listening_enabled = True
+                self.log('Active listening enabled')
             else:
                 self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))

@@ -155,7 +168,8 @@ class ASR:
         if self.is_voice_activity_detected:
             self.silence_start_time = time.time()
             self.is_voice_activity_detected = False
-        if time.time() - self.silence_start_time > self.silence_duration:  # If silence for SILENCE_DURATION seconds
+        is_end_of_speech = time.time() - self.silence_start_time > self.silence_duration
+        if is_end_of_speech:
             if len(self.utterance) > 0:
                 self.log('Reset')
                 # Take last utterance of the utterance list
@@ -164,9 +178,16 @@ class ASR:
                 if self.is_wake_word_detected:
                     self.saved_utterances.append(" ".join(self.utterance))
                 self.utterance = []
-                self.is_wake_word_detected = False
+                # self.is_wake_word_detected = False
                 self.circular_buffer = []
+
+            should_stop_active_listening = self.is_active_listening_enabled and time.time() - self.silence_start_time > self.active_listening_duration
+            if should_stop_active_listening:
+                self.is_wake_word_detected = False
+                self.is_active_listening_enabled = False
+                self.log('Active listening disabled')
+
             # self.log('Silence detected')

     def stop_recording(self):
diff --git a/tcp_server/src/lib/tcp_server.py b/tcp_server/src/lib/tcp_server.py
index 37a5cb42..0d88a761 100644
--- a/tcp_server/src/lib/tcp_server.py
+++ b/tcp_server/src/lib/tcp_server.py
@@ -56,7 +56,7 @@ class TCPServer:
             device='auto',
             config_path=TTS_MODEL_CONFIG_PATH,
             ckpt_path=TTS_MODEL_PATH
-            )
+        )

     def init_asr(self):
         if not IS_ASR_ENABLED:
@@ -67,8 +67,8 @@ class TCPServer:
             # self.log('Transcription:', utterance)
             pass

-        def clean_up_wake_word_text(text: str) -> str:
-            """Remove everything before the wake word (included), remove punctuation right after it, trim and
+        def clean_up_speech(text: str) -> str:
+            """Remove everything before the wake word if present (wake word included), remove punctuation right after it, trim and
             capitalize the first letter"""
             lowercased_text = text.lower()
             for wake_word in self.asr.wake_words:
@@ -76,7 +76,8 @@ class TCPServer:
                     start_index = lowercased_text.index(wake_word)
                     end_index = start_index + len(wake_word)
                     end_whitespace_index = end_index
-                    while end_whitespace_index < len(text) and (text[end_whitespace_index] in string.whitespace + string.punctuation):
+                    while end_whitespace_index < len(text) and (
+                            text[end_whitespace_index] in string.whitespace + string.punctuation):
                         end_whitespace_index += 1
                     cleaned_text = text[end_whitespace_index:].strip()
                     if cleaned_text:  # Check if cleaned_text is not empty
@@ -85,11 +86,11 @@ class TCPServer:
                 return ""  # Return an empty string if cleaned_text is empty
             return text

-        def wake_word_callback(text):
-            cleaned_text = clean_up_wake_word_text(text)
-            self.log('Wake word detected:', cleaned_text)
+        def wake_word_or_active_listening_callback(text):
+            cleaned_text = clean_up_speech(text)
+            self.log('Cleaned speech:', cleaned_text)
             self.send_tcp_message({
-                'topic': 'asr-wake-word-detected',
+                'topic': 'asr-new-speech',
                 'data': {
                     'text': cleaned_text
                 }
@@ -106,9 +107,9 @@ class TCPServer:

         self.asr = ASR(device='auto',
                        transcription_callback=transcription_callback,
-                       wake_word_callback=wake_word_callback,
+                       wake_word_or_active_listening_callback=wake_word_or_active_listening_callback,
                        end_of_owner_speech_callback=end_of_owner_speech_callback
-            )
+                       )
         self.asr.start_recording()

     def init(self):
@@ -170,7 +171,7 @@ class TCPServer:
             }
         }

-    def tts_synthesize(self, speech: str) -> Union[dict, None]:
+    def tts_synthesize(self, speech: str) -> dict:
         # If TTS is not initialized yet, then wait for 2 seconds before synthesizing
         if not self.tts:
             self.log('TTS is not initialized yet. Waiting for 2 seconds before synthesizing...')
@@ -213,3 +214,17 @@
                 'audioId': audio_id
             }
         }
+
+    def leon_speech_audio_ended(self, audio_duration: float) -> dict:
+        if self.asr:
+            if not audio_duration:
+                audio_duration = 0
+            self.asr.active_listening_duration = self.asr.base_active_listening_duration + audio_duration
+            self.log(f'ASR active listening duration increased to {self.asr.active_listening_duration}s')
+
+        return {
+            'topic': 'asr-active-listening-duration-increased',
+            'data': {
+                'activeListeningDuration': self.asr.active_listening_duration if self.asr else 0
+            }
+        }
diff --git a/tcp_server/src/main.py b/tcp_server/src/main.py
index 4bc22fd6..2cbfe9a5 100644
--- a/tcp_server/src/main.py
+++ b/tcp_server/src/main.py
@@ -26,4 +26,5 @@
 asr_thread.start()

 tcp_server.init_tts()
-tcp_server.init()
+tcp_server_thread = threading.Thread(target=tcp_server.init)
+tcp_server_thread.start()
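
Reviewer note: the incremental active listening introduced above is spread across
api.py (is_active_listening_enabled, active_listening_duration,
should_stop_active_listening) and tcp_server.py (leon_speech_audio_ended). Below is
a minimal, self-contained sketch of that timing logic for reference. It is not code
from this repository: the ActiveListeningWindow class and its method names are
hypothetical.

import time


class ActiveListeningWindow:
    """Hypothetical sketch of the active listening window added by this patch."""

    def __init__(self, base_duration: float = 12.0):
        # Mirrors base_active_listening_duration in api.py
        self.base_duration = base_duration
        self.duration = base_duration
        self.enabled = False
        self.silence_start_time = time.time()

    def on_wake_word_or_speech(self) -> None:
        # The wake word opens the window; any further speech while it is
        # open resets the silence timer and keeps the window alive
        self.enabled = True
        self.silence_start_time = time.time()

    def on_leon_speech_audio_ended(self, audio_duration: float) -> None:
        # Mirrors TCPServer.leon_speech_audio_ended: widen the window by the
        # duration (in seconds) of the audio Leon just played, so active
        # listening cannot expire while Leon itself is speaking
        self.duration = self.base_duration + (audio_duration or 0)

    def should_stop(self) -> bool:
        # Mirrors should_stop_active_listening in api.py: stop once silence
        # has outlasted the current window
        return self.enabled and time.time() - self.silence_start_time > self.duration

Note that local-synthesizer.ts emits duration / 1_000 on the leon-speech-audio-ended
event, which suggests the Node.js side measures audio duration in milliseconds while
the Python side counts in seconds.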