mirror of
https://github.com/leon-ai/leon.git
synced 2024-11-27 16:16:48 +03:00
feat: support (very) long speech on ASR
This commit is contained in:
parent
37fe307035
commit
2fac3881eb
@ -110,7 +110,9 @@ export default class Brain {
|
||||
!isTalkingWithVoice &&
|
||||
options.shouldInterrupt
|
||||
) {
|
||||
// Tell client to interrupt the current speech
|
||||
SOCKET_SERVER.socket?.emit('tts-interruption')
|
||||
// Cancel all the future speeches
|
||||
TTS.speeches = []
|
||||
LogHelper.info('Leon got interrupted by voice')
|
||||
}
|
||||
|
@ -1,6 +1,40 @@
|
||||
import type { ChunkData } from '@/core/tcp-client'
|
||||
import { STTParserBase } from '@/core/stt/stt-parser-base'
|
||||
import { LogHelper } from '@/helpers/log-helper'
|
||||
import { BRAIN, PYTHON_TCP_CLIENT, SOCKET_SERVER } from '@/core'
|
||||
import { BRAIN, SOCKET_SERVER } from '@/core'
|
||||
|
||||
interface EventHandler {
|
||||
[key: string]: (firstEvent: ChunkData) => void
|
||||
}
|
||||
|
||||
const NEW_SPEECH_EVENT = 'asr-new-speech'
|
||||
const END_OF_OWNER_SPEECH_DETECTED_EVENT = 'asr-end-of-owner-speech-detected'
|
||||
const ACTIVE_LISTENING_DURATION_INCREASED_EVENT =
|
||||
'asr-active-listening-duration-increased'
|
||||
|
||||
const EVENT_HANDLERS: EventHandler = {
|
||||
[NEW_SPEECH_EVENT]: (firstEvent): void => {
|
||||
/**
|
||||
* If Leon is talking with voice, then interrupt him
|
||||
*/
|
||||
if (BRAIN.isTalkingWithVoice) {
|
||||
BRAIN.setIsTalkingWithVoice(false, { shouldInterrupt: true })
|
||||
}
|
||||
|
||||
// Send the owner speech to the client
|
||||
SOCKET_SERVER.socket?.emit('asr-speech', firstEvent.data['text'])
|
||||
},
|
||||
|
||||
[END_OF_OWNER_SPEECH_DETECTED_EVENT]: (firstEvent): void => {
|
||||
SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
|
||||
completeSpeech: firstEvent.data['utterance']
|
||||
})
|
||||
},
|
||||
|
||||
[ACTIVE_LISTENING_DURATION_INCREASED_EVENT]: (): void => {
|
||||
//
|
||||
}
|
||||
}
|
||||
|
||||
export default class LocalParser extends STTParserBase {
|
||||
protected readonly name = 'Local STT Parser'
|
||||
@ -19,36 +53,75 @@ export default class LocalParser extends STTParserBase {
|
||||
}
|
||||
|
||||
/**
|
||||
* Read audio buffer and return the transcript (decoded string)
|
||||
* Parse the string chunk and emit the events to the client
|
||||
* @param strChunk - The string chunk to parse. E.g. `{"topic": "asr-new-speech", "data": {"text": " the other day I was thinking about the"}}{"topic": "asr-new-speech", "data": {"text": " magic number but"}}`
|
||||
*/
|
||||
public async parse(): Promise<string | null> {
|
||||
const newSpeechEventName = 'asr-new-speech'
|
||||
const endOfOwnerSpeechDetected = 'asr-end-of-owner-speech-detected'
|
||||
const newSpeechEventHasListeners =
|
||||
PYTHON_TCP_CLIENT.ee.listenerCount(newSpeechEventName) > 0
|
||||
const endOfOwnerSpeechDetectedHasListeners =
|
||||
PYTHON_TCP_CLIENT.ee.listenerCount(endOfOwnerSpeechDetected) > 0
|
||||
public async parse(strChunk: string): Promise<string | null> {
|
||||
const rawEvents = strChunk.match(/{"topic": "asr-[^}]+}/g)
|
||||
|
||||
if (!newSpeechEventHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on(newSpeechEventName, (data) => {
|
||||
/**
|
||||
* If Leon is talking with voice, then interrupt him
|
||||
*/
|
||||
if (BRAIN.isTalkingWithVoice) {
|
||||
BRAIN.setIsTalkingWithVoice(false, { shouldInterrupt: true })
|
||||
}
|
||||
|
||||
// Send the owner speech to the client
|
||||
SOCKET_SERVER.socket?.emit('asr-speech', data.text)
|
||||
})
|
||||
if (!rawEvents) {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.error(`No topics found in the chunk: ${strChunk}`)
|
||||
return null
|
||||
}
|
||||
|
||||
if (!endOfOwnerSpeechDetectedHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on(endOfOwnerSpeechDetected, (data) => {
|
||||
SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
|
||||
completeSpeech: data.utterance
|
||||
})
|
||||
})
|
||||
let events: ChunkData[] = rawEvents.map((topic) => {
|
||||
return JSON.parse(`${topic}}`)
|
||||
})
|
||||
const [firstEvent] = events
|
||||
|
||||
if (!firstEvent) {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.error(`No first event found in the chunk: ${strChunk}`)
|
||||
return null
|
||||
}
|
||||
|
||||
// Verify if all topics are similar to be ready to merge them
|
||||
const areAllTopicsSimilar = events.every(
|
||||
(event) => event.topic === firstEvent?.topic
|
||||
)
|
||||
if (areAllTopicsSimilar) {
|
||||
try {
|
||||
/**
|
||||
* Merge the topics in one and concat the text
|
||||
* if all topics are a new speech event
|
||||
*/
|
||||
if (firstEvent.topic === NEW_SPEECH_EVENT) {
|
||||
const mergedText = events
|
||||
.map((event) => event.data['text'])
|
||||
.join(' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim()
|
||||
|
||||
events = [{ topic: NEW_SPEECH_EVENT, data: { text: mergedText } }]
|
||||
}
|
||||
|
||||
/**
|
||||
* Can handle additional merge here if needed...
|
||||
*/
|
||||
} catch (e) {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.error(`Failed to merge the topics: ${e}`)
|
||||
LogHelper.error(`Events: ${events}`)
|
||||
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const [updatedEvent]: ChunkData[] = events
|
||||
|
||||
if (!updatedEvent) {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.error(`No updated event found in the chunk: ${strChunk}`)
|
||||
return null
|
||||
}
|
||||
|
||||
const handler = EVENT_HANDLERS[updatedEvent.topic]
|
||||
if (handler) {
|
||||
handler(updatedEvent)
|
||||
} else {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.error(`No handler found for the topic: ${updatedEvent?.topic}`)
|
||||
}
|
||||
|
||||
return null
|
||||
|
@ -1,5 +1,5 @@
|
||||
export abstract class STTParserBase {
|
||||
protected abstract name: string
|
||||
|
||||
protected abstract parse(buffer: Buffer): Promise<string | null>
|
||||
protected abstract parse(buffer: Buffer | string): Promise<string | null>
|
||||
}
|
||||
|
@ -81,16 +81,6 @@ export default class STT {
|
||||
)
|
||||
this._parser = new parser() as STTParser
|
||||
|
||||
/**
|
||||
* If the provider is local, parse an empty buffer to
|
||||
* initialize the parser
|
||||
*/
|
||||
if (STT_PROVIDER === STTProviders.Local) {
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-expect-error
|
||||
this._parser?.parse()
|
||||
}
|
||||
|
||||
LogHelper.title('STT')
|
||||
LogHelper.success('STT initialized')
|
||||
|
||||
@ -110,6 +100,8 @@ export default class STT {
|
||||
}
|
||||
|
||||
const buffer = fs.readFileSync(audioFilePath)
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-expect-error
|
||||
const transcript = await this._parser?.parse(buffer)
|
||||
|
||||
if (transcript && transcript !== '') {
|
||||
|
@ -13,9 +13,9 @@ const INTERVAL = IS_PRODUCTION_ENV ? 3000 : 500
|
||||
// Number of retries to connect to the TCP server
|
||||
const RETRIES_NB = IS_PRODUCTION_ENV ? 8 : 30
|
||||
|
||||
interface ChunkData {
|
||||
export interface ChunkData {
|
||||
topic: string
|
||||
data: unknown
|
||||
data: Record<string, unknown>
|
||||
}
|
||||
type TCPClientName = 'Python'
|
||||
|
||||
@ -61,20 +61,29 @@ export default class TCPClient {
|
||||
LogHelper.title(`${this.name} TCP Client`)
|
||||
LogHelper.info(`Received data: ${String(chunk)}`)
|
||||
|
||||
const data = JSON.parse(String(chunk))
|
||||
const strChunk = String(chunk)
|
||||
|
||||
/**
|
||||
* If the topic is related to ASR, then parse the data
|
||||
* If the topic is related to ASR, then parse the data manually
|
||||
* in the local STT parser
|
||||
*/
|
||||
if (data.topic.includes('asr-')) {
|
||||
if (strChunk.includes('"topic": "asr-')) {
|
||||
if (STT_PROVIDER === STTProviders.Local) {
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-expect-error
|
||||
STT.parser?.parse()
|
||||
STT.parser?.parse(strChunk)
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
const data = JSON.parse(strChunk)
|
||||
|
||||
this.ee.emit(data.topic, data.data)
|
||||
} catch (e) {
|
||||
LogHelper.title(`${this.name} TCP Client`)
|
||||
LogHelper.error(`Failed to parse the data: ${e}`)
|
||||
LogHelper.error(`Received data: ${String(chunk)}`)
|
||||
}
|
||||
}
|
||||
|
||||
this.ee.emit(data.topic, data.data)
|
||||
})
|
||||
|
||||
this.tcpSocket.on('error', (err: NodeJS.ErrnoException) => {
|
||||
|
@ -6,6 +6,8 @@ import numpy as np
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from ..constants import ASR_MODEL_PATH_FOR_GPU, ASR_MODEL_PATH_FOR_CPU
|
||||
from ..utils import ThrottledCallback
|
||||
|
||||
|
||||
class ASR:
|
||||
def __init__(self,
|
||||
@ -37,6 +39,13 @@ class ASR:
|
||||
|
||||
self.transcription_callback = transcription_callback
|
||||
self.wake_word_or_active_listening_callback = wake_word_or_active_listening_callback
|
||||
"""
|
||||
Throttle the wake word or active listening callback to avoid sending too many messages to the client.
|
||||
The callback is called at most once every x seconds
|
||||
"""
|
||||
self.throttled_wake_word_or_active_listening_callback = ThrottledCallback(
|
||||
wake_word_or_active_listening_callback, 0.5
|
||||
)
|
||||
self.end_of_owner_speech_callback = end_of_owner_speech_callback
|
||||
|
||||
self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon", "heilion", "alion", "hyleon"]
|
||||
@ -51,6 +60,7 @@ class ASR:
|
||||
self.is_active_listening_enabled = False
|
||||
self.saved_utterances = []
|
||||
self.segment_text = ''
|
||||
self.complete_text = ''
|
||||
|
||||
self.audio_format = pyaudio.paInt16
|
||||
self.channels = 1
|
||||
@ -65,7 +75,11 @@ class ASR:
|
||||
"""
|
||||
self.base_active_listening_duration = 12
|
||||
self.active_listening_duration = self.base_active_listening_duration
|
||||
self.buffer_size = 64 # Size of the circular buffer
|
||||
"""
|
||||
Size of the circular buffer.
|
||||
Meaning how many audio frames can be stored in the buffer
|
||||
"""
|
||||
self.buffer_size = 256
|
||||
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
@ -103,6 +117,8 @@ class ASR:
|
||||
return False
|
||||
|
||||
def process_circular_buffer(self):
|
||||
self.complete_text = ''
|
||||
|
||||
if len(self.circular_buffer) > self.buffer_size:
|
||||
self.circular_buffer.pop(0)
|
||||
|
||||
@ -130,7 +146,7 @@ class ASR:
|
||||
if has_dected_wake_word or self.is_active_listening_enabled:
|
||||
if has_dected_wake_word:
|
||||
self.log('Wake word detected')
|
||||
self.wake_word_or_active_listening_callback(segment.text)
|
||||
self.complete_text += segment.text
|
||||
self.is_wake_word_detected = True
|
||||
self.is_active_listening_enabled = True
|
||||
self.log('Active listening enabled')
|
||||
@ -138,6 +154,8 @@ class ASR:
|
||||
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||
|
||||
self.segment_text = ''
|
||||
if self.complete_text:
|
||||
self.throttled_wake_word_or_active_listening_callback(self.complete_text)
|
||||
|
||||
def start_recording(self):
|
||||
self.stream = self.audio.open(format=self.audio_format,
|
||||
@ -172,8 +190,11 @@ class ASR:
|
||||
if is_end_of_speech:
|
||||
if len(self.utterance) > 0:
|
||||
self.log('Reset')
|
||||
# Take last utterance of the utterance list
|
||||
self.end_of_owner_speech_callback(self.utterance[-1])
|
||||
# Send the latest up-to-date text
|
||||
self.wake_word_or_active_listening_callback(self.complete_text)
|
||||
time.sleep(0.1)
|
||||
# Notify the end of the owner's speech
|
||||
self.end_of_owner_speech_callback(self.complete_text)
|
||||
|
||||
if self.is_wake_word_detected:
|
||||
self.saved_utterances.append(" ".join(self.utterance))
|
||||
|
14
tcp_server/src/lib/utils.py
Normal file
14
tcp_server/src/lib/utils.py
Normal file
@ -0,0 +1,14 @@
|
||||
import time
|
||||
|
||||
|
||||
class ThrottledCallback:
    """Wrap a callback so it fires at most once per ``min_interval`` seconds.

    Calls arriving inside the throttle window are silently dropped (not
    deferred) — suitable for rate-limiting high-frequency ASR progress
    events sent to the client.
    """

    def __init__(self, callback, min_interval):
        # callback: callable invoked with the original arguments when allowed
        # min_interval: minimum number of seconds between forwarded calls
        self.callback = callback
        self.min_interval = min_interval
        # Monotonic timestamp of the last forwarded call; 0 guarantees
        # the very first invocation always goes through
        self.last_call = 0

    def __call__(self, *args, **kwargs):
        # Use a monotonic clock so wall-clock adjustments (NTP, DST,
        # manual changes) can neither break nor bypass the throttle window
        current_time = time.monotonic()
        if current_time - self.last_call > self.min_interval:
            self.callback(*args, **kwargs)
            self.last_call = current_time
|
Loading…
Reference in New Issue
Block a user