mirror of https://github.com/leon-ai/leon.git synced 2024-11-27 16:16:48 +03:00

feat: support incremental active listening for ASR

louistiti 2024-05-24 01:26:28 +08:00
parent c54ad18cc3
commit c1349c930a
5 changed files with 60 additions and 21 deletions


@@ -22,7 +22,7 @@ export default class LocalParser extends STTParserBase {
    * Read audio buffer and return the transcript (decoded string)
    */
   public async parse(): Promise<string | null> {
-    const wakeWordEventName = 'asr-wake-word-detected'
+    const wakeWordEventName = 'asr-new-speech'
     const endOfOwnerSpeechDetected = 'asr-end-of-owner-speech-detected'
     const wakeWordEventHasListeners =
       PYTHON_TCP_CLIENT.ee.listenerCount(wakeWordEventName) > 0


@@ -59,6 +59,8 @@ export default class LocalSynthesizer extends TTSSynthesizerBase {
       try {
         const duration = await this.getAudioDuration(outputPath)
         TTS.em.emit('saved', duration)
+        PYTHON_TCP_CLIENT.emit('leon-speech-audio-ended', duration / 1_000)
       } catch (e) {
         LogHelper.title(this.name)
         LogHelper.warning(`Failed to get audio duration: ${e}`)
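The synthesizer now tells the Python side how long the spoken reply is, converting the duration from milliseconds to seconds (duration / 1_000). Below is a minimal sketch of how the receiving end might hand this message to the leon_speech_audio_ended handler added further down; the dispatch table is an assumption, as this diff does not show how the TCP server routes topics to methods.

# Hypothetical routing sketch: the handler name and the payload (a duration
# in seconds) come from this diff, but the dispatch mechanism does not.
def route_message(server, message: dict) -> dict:
    handlers = {
        'leon-speech-audio-ended': server.leon_speech_audio_ended
    }
    # message['data'] would carry the duration / 1_000 sent by the synthesizer
    return handlers[message['topic']](message['data'])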


@@ -11,7 +11,7 @@ class ASR:
     def __init__(self,
                  device='auto',
                  transcription_callback=None,
-                 wake_word_callback=None,
+                 wake_word_or_active_listening_callback=None,
                  end_of_owner_speech_callback=None):
         tic = time.perf_counter()
         self.log('Loading model...')
@@ -36,7 +36,7 @@ class ASR:
         self.compute_type = compute_type
         self.transcription_callback = transcription_callback
-        self.wake_word_callback = wake_word_callback
+        self.wake_word_or_active_listening_callback = wake_word_or_active_listening_callback
         self.end_of_owner_speech_callback = end_of_owner_speech_callback
         self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon", "heilion", "alion", "hyleon"]
@@ -48,6 +48,7 @@ class ASR:
         self.is_voice_activity_detected = False
         self.silence_start_time = 0
         self.is_wake_word_detected = False
+        self.is_active_listening_enabled = False
         self.saved_utterances = []
         self.segment_text = ''
@@ -56,7 +57,14 @@ class ASR:
         self.rate = 16000
         self.chunk = 4096
         self.threshold = 128
-        self.silence_duration = 1.5  # duration of silence in seconds
+        # Duration of silence after which the audio data is considered as a new utterance (in seconds)
+        self.silence_duration = 1.5
+        """
+        Duration of silence after which the active listening is stopped (in seconds).
+        Once stopped, the active listening can be resumed by saying the wake word again
+        """
+        self.base_active_listening_duration = 12
+        self.active_listening_duration = self.base_active_listening_duration
         self.buffer_size = 64  # Size of the circular buffer
         self.audio = pyaudio.PyAudio()
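For scale, under these settings each PyAudio chunk covers 4096 / 16000 = 0.256 s of audio, so the 1.5 s utterance threshold amounts to roughly six silent chunks and the 12 s base active-listening window to roughly 47. A quick check of those numbers:

RATE = 16_000   # samples per second (self.rate above)
CHUNK = 4_096   # samples per read (self.chunk above)

chunk_seconds = CHUNK / RATE
print(chunk_seconds)        # 0.256 s of audio per chunk
print(1.5 / chunk_seconds)  # ~5.9 silent chunks close an utterance
print(12 / chunk_seconds)   # ~46.9 silent chunks end active listening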
@@ -117,10 +125,15 @@ class ASR:
                 if self.is_wake_word_detected:
                     self.utterance.append(self.segment_text)
                     self.transcription_callback(" ".join(self.utterance))
-                if self.detect_wake_word(segment.text):
-                    self.log('Wake word detected')
-                    self.wake_word_callback(segment.text)
+                has_detected_wake_word = self.detect_wake_word(segment.text)
+                if has_detected_wake_word or self.is_active_listening_enabled:
+                    if has_detected_wake_word:
+                        self.log('Wake word detected')
+                    self.wake_word_or_active_listening_callback(segment.text)
                     self.is_wake_word_detected = True
+                    self.is_active_listening_enabled = True
+                    self.log('Active listening enabled')
             else:
                 self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
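This is the behavioral core of the commit: transcribed speech now reaches the callback either when the wake word is detected or while active listening is already enabled, and any accepted speech (re)arms the mode. A stripped-down sketch of that gate, with a plain function and a dict standing in for the ASR instance:

def gate(text, state, detect_wake_word, callback):
    # state is a dict standing in for the ASR instance's attributes
    has_detected_wake_word = detect_wake_word(text)
    if has_detected_wake_word or state['is_active_listening_enabled']:
        state['is_wake_word_detected'] = True
        # Any accepted speech re-arms active listening
        state['is_active_listening_enabled'] = True
        callback(text)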
@@ -155,7 +168,8 @@ class ASR:
             if self.is_voice_activity_detected:
                 self.silence_start_time = time.time()
                 self.is_voice_activity_detected = False
-            if time.time() - self.silence_start_time > self.silence_duration:  # If silence for SILENCE_DURATION seconds
+            is_end_of_speech = time.time() - self.silence_start_time > self.silence_duration
+            if is_end_of_speech:
                 if len(self.utterance) > 0:
                     self.log('Reset')
                     # Take last utterance of the utterance list
@@ -164,9 +178,16 @@ class ASR:
                     if self.is_wake_word_detected:
                         self.saved_utterances.append(" ".join(self.utterance))
                     self.utterance = []
-                    self.is_wake_word_detected = False
+                    # self.is_wake_word_detected = False
                     self.circular_buffer = []
+                should_stop_active_listening = self.is_active_listening_enabled and time.time() - self.silence_start_time > self.active_listening_duration
+                if should_stop_active_listening:
+                    self.is_wake_word_detected = False
+                    self.is_active_listening_enabled = False
+                    self.log('Active listening disabled')
                 # self.log('Silence detected')

     def stop_recording(self):

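Two timers now run off the same silence_start_time: 1.5 s of silence closes the current utterance, while active listening only drops after active_listening_duration (12 s by default), so the owner can keep talking across turns without repeating the wake word. A toy timeline under the default values:

silence_duration = 1.5          # seconds of silence that end an utterance
active_listening_duration = 12  # seconds of silence that end active listening

for silence in (0.5, 2.0, 13.0):  # seconds elapsed since silence_start_time
    is_end_of_speech = silence > silence_duration
    should_stop_active_listening = silence > active_listening_duration
    print(silence, is_end_of_speech, should_stop_active_listening)
# 0.5 s: still speaking; 2.0 s: utterance closed, still listening;
# 13.0 s: active listening disabled until the wake word is said again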

@@ -56,7 +56,7 @@ class TCPServer:
             device='auto',
             config_path=TTS_MODEL_CONFIG_PATH,
             ckpt_path=TTS_MODEL_PATH
-            )
+        )

     def init_asr(self):
         if not IS_ASR_ENABLED:
@@ -67,8 +67,8 @@ class TCPServer:
             # self.log('Transcription:', utterance)
             pass

-        def clean_up_wake_word_text(text: str) -> str:
-            """Remove everything before the wake word (included), remove punctuation right after it, trim and
+        def clean_up_speech(text: str) -> str:
+            """Remove everything before the wake word if present (included), remove punctuation right after it, trim and
             capitalize the first letter"""
             lowercased_text = text.lower()
             for wake_word in self.asr.wake_words:
@@ -76,7 +76,8 @@ class TCPServer:
                     start_index = lowercased_text.index(wake_word)
                     end_index = start_index + len(wake_word)
                     end_whitespace_index = end_index
-                    while end_whitespace_index < len(text) and (text[end_whitespace_index] in string.whitespace + string.punctuation):
+                    while end_whitespace_index < len(text) and (
+                            text[end_whitespace_index] in string.whitespace + string.punctuation):
                         end_whitespace_index += 1
                     cleaned_text = text[end_whitespace_index:].strip()
                     if cleaned_text:  # Check if cleaned_text is not empty
@@ -85,11 +86,11 @@ class TCPServer:
                     return ""  # Return an empty string if cleaned_text is empty
             return text

-        def wake_word_callback(text):
-            cleaned_text = clean_up_wake_word_text(text)
-            self.log('Wake word detected:', cleaned_text)
+        def wake_word_or_active_listening_callback(text):
+            cleaned_text = clean_up_speech(text)
+            self.log('Cleaned speech:', cleaned_text)
             self.send_tcp_message({
-                'topic': 'asr-wake-word-detected',
+                'topic': 'asr-new-speech',
                 'data': {
                     'text': cleaned_text
                 }
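Since clean_up_speech now runs on every accepted utterance rather than only on wake-word hits, text without a wake word passes through untouched. A self-contained version of the logic above for experimentation; the wake-word list is trimmed, and the final capitalization step, elided from this hunk, is reconstructed from the docstring:

import string

WAKE_WORDS = ["ok leon", "okay leon", "hey leon"]  # subset of the list in the diff

def clean_up_speech(text: str) -> str:
    """Remove everything up to and including the wake word if present,
    strip following punctuation/whitespace, trim and capitalize."""
    lowercased_text = text.lower()
    for wake_word in WAKE_WORDS:
        if wake_word in lowercased_text:
            start_index = lowercased_text.index(wake_word)
            end_index = start_index + len(wake_word)
            while end_index < len(text) and (
                    text[end_index] in string.whitespace + string.punctuation):
                end_index += 1
            cleaned_text = text[end_index:].strip()
            if cleaned_text:
                return cleaned_text[0].upper() + cleaned_text[1:]
            return ""
    return text

print(clean_up_speech("so, okay leon, what's the weather?"))  # "What's the weather?"
print(clean_up_speech("turn off the lights"))                 # passes through unchanged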
@@ -106,9 +107,9 @@ class TCPServer:
         self.asr = ASR(device='auto',
                        transcription_callback=transcription_callback,
-                       wake_word_callback=wake_word_callback,
+                       wake_word_or_active_listening_callback=wake_word_or_active_listening_callback,
                        end_of_owner_speech_callback=end_of_owner_speech_callback
-                       )
+        )

         self.asr.start_recording()

     def init(self):
@@ -170,7 +171,7 @@ class TCPServer:
             }
         }

-    def tts_synthesize(self, speech: str) -> Union[dict, None]:
+    def tts_synthesize(self, speech: str) -> dict:
         # If TTS is not initialized yet, then wait for 2 seconds before synthesizing
         if not self.tts:
             self.log('TTS is not initialized yet. Waiting for 2 seconds before synthesizing...')
@@ -213,3 +214,17 @@ class TCPServer:
                 'audioId': audio_id
             }
         }
+
+    def leon_speech_audio_ended(self, audio_duration: float) -> dict:
+        if self.asr:
+            if not audio_duration:
+                audio_duration = 0
+            self.asr.active_listening_duration = self.asr.base_active_listening_duration + audio_duration
+            self.log(f'ASR active listening duration increased to {self.asr.active_listening_duration}s')
+
+        return {
+            'topic': 'asr-active-listening-duration-increased',
+            'data': {
+                'activeListeningDuration': self.asr.active_listening_duration
+            }
+        }
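This handler is the incremental part of the feature: each time Leon speaks, the window is recomputed as the 12 s base plus the reply's audio duration (not accumulated across replies), so active listening cannot time out while Leon is still talking. For example:

base_active_listening_duration = 12  # seconds (default from the ASR class)
audio_duration = 3.4                 # length of Leon's reply, sent by the synthesizer

active_listening_duration = base_active_listening_duration + audio_duration
print(active_listening_duration)     # 15.4 s of silence now allowed before listening stops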


@@ -26,4 +26,5 @@ asr_thread.start()

 tcp_server.init_tts()
-tcp_server.init()
+tcp_server_thread = threading.Thread(target=tcp_server.init)
+tcp_server_thread.start()
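Running the TCP server loop on its own thread frees the main thread instead of blocking in tcp_server.init(). A standalone illustration of the pattern; serve() is a placeholder, not a Leon function, and daemon=True is used only so the demo exits on its own:

import threading
import time

def serve():
    # Placeholder for a blocking accept loop such as tcp_server.init()
    while True:
        time.sleep(1)

server_thread = threading.Thread(target=serve, daemon=True)
server_thread.start()
print('main thread continues immediately')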