1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-11-28 04:04:58 +03:00

feat: support (very) long speech on ASR

This commit is contained in:
louistiti 2024-05-24 18:38:00 +08:00
parent 37fe307035
commit 2fac3881eb
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
7 changed files with 161 additions and 50 deletions

View File

@ -110,7 +110,9 @@ export default class Brain {
!isTalkingWithVoice &&
options.shouldInterrupt
) {
// Tell client to interrupt the current speech
SOCKET_SERVER.socket?.emit('tts-interruption')
// Cancel all the future speeches
TTS.speeches = []
LogHelper.info('Leon got interrupted by voice')
}

View File

@ -1,6 +1,40 @@
import type { ChunkData } from '@/core/tcp-client'
import { STTParserBase } from '@/core/stt/stt-parser-base'
import { LogHelper } from '@/helpers/log-helper'
import { BRAIN, PYTHON_TCP_CLIENT, SOCKET_SERVER } from '@/core'
import { BRAIN, SOCKET_SERVER } from '@/core'
interface EventHandler {
[key: string]: (firstEvent: ChunkData) => void
}
const NEW_SPEECH_EVENT = 'asr-new-speech'
const END_OF_OWNER_SPEECH_DETECTED_EVENT = 'asr-end-of-owner-speech-detected'
const ACTIVE_LISTENING_DURATION_INCREASED_EVENT =
'asr-active-listening-duration-increased'
const EVENT_HANDLERS: EventHandler = {
[NEW_SPEECH_EVENT]: (firstEvent): void => {
/**
* If Leon is talking with voice, then interrupt him
*/
if (BRAIN.isTalkingWithVoice) {
BRAIN.setIsTalkingWithVoice(false, { shouldInterrupt: true })
}
// Send the owner speech to the client
SOCKET_SERVER.socket?.emit('asr-speech', firstEvent.data['text'])
},
[END_OF_OWNER_SPEECH_DETECTED_EVENT]: (firstEvent): void => {
SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
completeSpeech: firstEvent.data['utterance']
})
},
[ACTIVE_LISTENING_DURATION_INCREASED_EVENT]: (): void => {
//
}
}
export default class LocalParser extends STTParserBase {
protected readonly name = 'Local STT Parser'
@ -19,36 +53,75 @@ export default class LocalParser extends STTParserBase {
}
/**
* Read audio buffer and return the transcript (decoded string)
* Parse the string chunk and emit the events to the client
* @param strChunk - The string chunk to parse. E.g. `{"topic": "asr-new-speech", "data": {"text": " the other day I was thinking about the"}}{"topic": "asr-new-speech", "data": {"text": " magic number but"}}`
*/
public async parse(): Promise<string | null> {
const newSpeechEventName = 'asr-new-speech'
const endOfOwnerSpeechDetected = 'asr-end-of-owner-speech-detected'
const newSpeechEventHasListeners =
PYTHON_TCP_CLIENT.ee.listenerCount(newSpeechEventName) > 0
const endOfOwnerSpeechDetectedHasListeners =
PYTHON_TCP_CLIENT.ee.listenerCount(endOfOwnerSpeechDetected) > 0
public async parse(strChunk: string): Promise<string | null> {
const rawEvents = strChunk.match(/{"topic": "asr-[^}]+}/g)
if (!newSpeechEventHasListeners) {
PYTHON_TCP_CLIENT.ee.on(newSpeechEventName, (data) => {
/**
* If Leon is talking with voice, then interrupt him
*/
if (BRAIN.isTalkingWithVoice) {
BRAIN.setIsTalkingWithVoice(false, { shouldInterrupt: true })
}
// Send the owner speech to the client
SOCKET_SERVER.socket?.emit('asr-speech', data.text)
})
if (!rawEvents) {
LogHelper.title(this.name)
LogHelper.error(`No topics found in the chunk: ${strChunk}`)
return null
}
if (!endOfOwnerSpeechDetectedHasListeners) {
PYTHON_TCP_CLIENT.ee.on(endOfOwnerSpeechDetected, (data) => {
SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
completeSpeech: data.utterance
})
})
let events: ChunkData[] = rawEvents.map((topic) => {
return JSON.parse(`${topic}}`)
})
const [firstEvent] = events
if (!firstEvent) {
LogHelper.title(this.name)
LogHelper.error(`No first event found in the chunk: ${strChunk}`)
return null
}
// Verify if all topics are similar to be ready to merge them
const areAllTopicsSimilar = events.every(
(event) => event.topic === firstEvent?.topic
)
if (areAllTopicsSimilar) {
try {
/**
* Merge the topics in one and concat the text
* if all topics are a new speech event
*/
if (firstEvent.topic === NEW_SPEECH_EVENT) {
const mergedText = events
.map((event) => event.data['text'])
.join(' ')
.replace(/\s+/g, ' ')
.trim()
events = [{ topic: NEW_SPEECH_EVENT, data: { text: mergedText } }]
}
/**
* Can handle additional merge here if needed...
*/
} catch (e) {
LogHelper.title(this.name)
LogHelper.error(`Failed to merge the topics: ${e}`)
LogHelper.error(`Events: ${events}`)
return null
}
}
const [updatedEvent]: ChunkData[] = events
if (!updatedEvent) {
LogHelper.title(this.name)
LogHelper.error(`No updated event found in the chunk: ${strChunk}`)
return null
}
const handler = EVENT_HANDLERS[updatedEvent.topic]
if (handler) {
handler(updatedEvent)
} else {
LogHelper.title(this.name)
LogHelper.error(`No handler found for the topic: ${updatedEvent?.topic}`)
}
return null

View File

@ -1,5 +1,5 @@
export abstract class STTParserBase {
protected abstract name: string
protected abstract parse(buffer: Buffer): Promise<string | null>
protected abstract parse(buffer: Buffer | string): Promise<string | null>
}

View File

@ -81,16 +81,6 @@ export default class STT {
)
this._parser = new parser() as STTParser
/**
* If the provider is local, parse an empty buffer to
* initialize the parser
*/
if (STT_PROVIDER === STTProviders.Local) {
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
this._parser?.parse()
}
LogHelper.title('STT')
LogHelper.success('STT initialized')
@ -110,6 +100,8 @@ export default class STT {
}
const buffer = fs.readFileSync(audioFilePath)
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
const transcript = await this._parser?.parse(buffer)
if (transcript && transcript !== '') {

View File

@ -13,9 +13,9 @@ const INTERVAL = IS_PRODUCTION_ENV ? 3000 : 500
// Number of retries to connect to the TCP server
const RETRIES_NB = IS_PRODUCTION_ENV ? 8 : 30
interface ChunkData {
export interface ChunkData {
topic: string
data: unknown
data: Record<string, unknown>
}
type TCPClientName = 'Python'
@ -61,20 +61,29 @@ export default class TCPClient {
LogHelper.title(`${this.name} TCP Client`)
LogHelper.info(`Received data: ${String(chunk)}`)
const data = JSON.parse(String(chunk))
const strChunk = String(chunk)
/**
* If the topic is related to ASR, then parse the data
* If the topic is related to ASR, then parse the data manually
* in the local STT parser
*/
if (data.topic.includes('asr-')) {
if (strChunk.includes('"topic": "asr-')) {
if (STT_PROVIDER === STTProviders.Local) {
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
STT.parser?.parse()
STT.parser?.parse(strChunk)
}
} else {
try {
const data = JSON.parse(strChunk)
this.ee.emit(data.topic, data.data)
} catch (e) {
LogHelper.title(`${this.name} TCP Client`)
LogHelper.error(`Failed to parse the data: ${e}`)
LogHelper.error(`Received data: ${String(chunk)}`)
}
}
this.ee.emit(data.topic, data.data)
})
this.tcpSocket.on('error', (err: NodeJS.ErrnoException) => {

View File

@ -6,6 +6,8 @@ import numpy as np
from faster_whisper import WhisperModel
from ..constants import ASR_MODEL_PATH_FOR_GPU, ASR_MODEL_PATH_FOR_CPU
from ..utils import ThrottledCallback
class ASR:
def __init__(self,
@ -37,6 +39,13 @@ class ASR:
self.transcription_callback = transcription_callback
self.wake_word_or_active_listening_callback = wake_word_or_active_listening_callback
"""
Throttle the wake word or active listening callback to avoid sending too many messages to the client.
The callback is called at most once every x seconds
"""
self.throttled_wake_word_or_active_listening_callback = ThrottledCallback(
wake_word_or_active_listening_callback, 0.5
)
self.end_of_owner_speech_callback = end_of_owner_speech_callback
self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon", "heilion", "alion", "hyleon"]
@ -51,6 +60,7 @@ class ASR:
self.is_active_listening_enabled = False
self.saved_utterances = []
self.segment_text = ''
self.complete_text = ''
self.audio_format = pyaudio.paInt16
self.channels = 1
@ -65,7 +75,11 @@ class ASR:
"""
self.base_active_listening_duration = 12
self.active_listening_duration = self.base_active_listening_duration
self.buffer_size = 64 # Size of the circular buffer
"""
Size of the circular buffer.
Meaning how many audio frames can be stored in the buffer
"""
self.buffer_size = 256
self.audio = pyaudio.PyAudio()
self.stream = None
@ -103,6 +117,8 @@ class ASR:
return False
def process_circular_buffer(self):
self.complete_text = ''
if len(self.circular_buffer) > self.buffer_size:
self.circular_buffer.pop(0)
@ -130,7 +146,7 @@ class ASR:
if has_dected_wake_word or self.is_active_listening_enabled:
if has_dected_wake_word:
self.log('Wake word detected')
self.wake_word_or_active_listening_callback(segment.text)
self.complete_text += segment.text
self.is_wake_word_detected = True
self.is_active_listening_enabled = True
self.log('Active listening enabled')
@ -138,6 +154,8 @@ class ASR:
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
self.segment_text = ''
if self.complete_text:
self.throttled_wake_word_or_active_listening_callback(self.complete_text)
def start_recording(self):
self.stream = self.audio.open(format=self.audio_format,
@ -172,8 +190,11 @@ class ASR:
if is_end_of_speech:
if len(self.utterance) > 0:
self.log('Reset')
# Take last utterance of the utterance list
self.end_of_owner_speech_callback(self.utterance[-1])
# Send the latest up-to-date text
self.wake_word_or_active_listening_callback(self.complete_text)
time.sleep(0.1)
# Notify the end of the owner's speech
self.end_of_owner_speech_callback(self.complete_text)
if self.is_wake_word_detected:
self.saved_utterances.append(" ".join(self.utterance))

View File

@ -0,0 +1,14 @@
import time
class ThrottledCallback:
def __init__(self, callback, min_interval):
self.callback = callback
self.min_interval = min_interval
self.last_call = 0
def __call__(self, *args, **kwargs):
current_time = time.time()
if current_time - self.last_call > self.min_interval:
self.callback(*args, **kwargs)
self.last_call = current_time