Mirror of https://github.com/leon-ai/leon.git — synced 2024-11-27 16:16:48 +03:00

feat(python tcp server): ASR engine communication with core

This commit is contained in:
louistiti 2024-05-21 19:20:09 +08:00
parent ef368f89fb
commit f051c1d2cd
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
8 changed files with 155 additions and 9 deletions

View File

@ -91,6 +91,14 @@ export default class Client {
this.chatbot.createBubble('leon', data)
})
this.socket.on('asr-speech', (data) => {
console.log('Wake word detected', data)
})
this.socket.on('asr-end-of-owner-speech', () => {
console.log('End of owner speech')
})
/**
* Only used for "local" TTS provider as a PoC for now.
* Target to do a better implementation in the future

View File

@ -0,0 +1,48 @@
import { STTParserBase } from '@/core/stt/stt-parser-base'
import { LogHelper } from '@/helpers/log-helper'
import { PYTHON_TCP_CLIENT, SOCKET_SERVER } from '@/core'
export default class LocalParser extends STTParserBase {
  protected readonly name = 'Local STT Parser'

  constructor() {
    super()

    LogHelper.title(this.name)
    LogHelper.success('New instance')

    try {
      LogHelper.success('Parser initialized')
    } catch (e) {
      LogHelper.error(`${this.name} - Failed to initialize: ${e}`)
    }
  }

  /**
   * Read audio buffer and return the transcript (decoded string)
   *
   * For the local provider the transcription itself happens in the Python
   * TCP server; this method only bridges the ASR events coming from the
   * Python TCP client over to the web client socket. Each bridge listener
   * is registered at most once, so repeated calls are safe.
   */
  public async parse(): Promise<string | null> {
    const wakeWordEvent = 'asr-wake-word-detected'
    const speechEndEvent = 'asr-end-of-owner-speech-detected'

    if (PYTHON_TCP_CLIENT.ee.listenerCount(wakeWordEvent) === 0) {
      // Forward the detected wake word text to the connected web client
      PYTHON_TCP_CLIENT.ee.on(wakeWordEvent, (data) => {
        SOCKET_SERVER.socket?.emit('asr-speech', data.text)
      })
    }

    if (PYTHON_TCP_CLIENT.ee.listenerCount(speechEndEvent) === 0) {
      // Forward the owner's complete utterance once silence is detected
      PYTHON_TCP_CLIENT.ee.on(speechEndEvent, (data) => {
        SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
          completeSpeech: data.utterance
        })
      })
    }

    return null
  }
}

View File

@ -9,6 +9,7 @@ import { STTParserNames, STTProviders } from '@/core/stt/types'
import { LogHelper } from '@/helpers/log-helper'
const PROVIDERS_MAP = {
[STTProviders.Local]: STTParserNames.Local,
[STTProviders.GoogleCloudSTT]: STTParserNames.GoogleCloudSTT,
[STTProviders.WatsonSTT]: STTParserNames.WatsonSTT,
[STTProviders.CoquiSTT]: STTParserNames.CoquiSTT
@ -76,6 +77,16 @@ export default class STT {
)
this.parser = new parser() as STTParser
/**
* If the provider is local, parse an empty buffer to
* initialize the parser
*/
if (STT_PROVIDER === STTProviders.Local) {
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
this.parser?.parse()
}
LogHelper.title('STT')
LogHelper.success('STT initialized')

View File

@ -1,20 +1,24 @@
import type LocalParser from '@/core/stt/parsers/local-parser'
import type CoquiSTTParser from '@/core/stt/parsers/coqui-stt-parser'
import type GoogleCloudSTTParser from '@/core/stt/parsers/google-cloud-stt-parser'
import type WatsonSTTParser from '@/core/stt/parsers/watson-stt-parser'
/**
 * Identifiers of the supported speech-to-text providers.
 * Presumably the values match the STT_PROVIDER configuration entry —
 * TODO(review): confirm against the core configuration loader.
 */
export enum STTProviders {
Local = 'local',
GoogleCloudSTT = 'google-cloud-stt',
WatsonSTT = 'watson-stt',
CoquiSTT = 'coqui-stt'
}
/**
 * File/module names of the parser implementations; used to map a
 * provider to the parser module that should be loaded for it.
 */
export enum STTParserNames {
Local = 'local-parser',
GoogleCloudSTT = 'google-cloud-stt-parser',
WatsonSTT = 'watson-stt-parser',
CoquiSTT = 'coqui-stt-parser'
}
// Union of every concrete parser implementation the STT core can hold.
export type STTParser =
| LocalParser
| GoogleCloudSTTParser
| WatsonSTTParser
| CoquiSTTParser

View File

@ -32,11 +32,11 @@ export default class LocalSynthesizer extends TTSSynthesizerBase {
}
public async synthesize(speech: string): Promise<SynthesizeResult | null> {
const eventName = 'tts-receiving-stream'
const eventName = 'tts-audio-streaming'
const eventHasListeners = PYTHON_TCP_CLIENT.ee.listenerCount(eventName) > 0
if (!eventHasListeners) {
PYTHON_TCP_CLIENT.ee.on('tts-audio-streaming', (data: ChunkData) => {
PYTHON_TCP_CLIENT.ee.on(eventName, (data: ChunkData) => {
/**
* Send audio stream chunk by chunk to the client as long as
* the temporary file is being written from the TCP server

View File

@ -1,3 +1,4 @@
import type LocalSynthesizer from '@/core/tts/synthesizers/local-synthesizer'
import type AmazonPollySynthesizer from '@/core/tts/synthesizers/amazon-polly-synthesizer'
import type FliteSynthesizer from '@/core/tts/synthesizers/flite-synthesizer'
import type GoogleCloudTTSSynthesizer from '@/core/tts/synthesizers/google-cloud-tts-synthesizer'
@ -25,6 +26,7 @@ export interface SynthesizeResult {
}
export type TTSSynthesizer =
| LocalSynthesizer
| AmazonPollySynthesizer
| FliteSynthesizer
| GoogleCloudTTSSynthesizer

View File

@ -6,7 +6,11 @@ import numpy as np
from faster_whisper import WhisperModel
class ASR:
def __init__(self, device='auto'):
def __init__(self,
device='auto',
transcription_callback=None,
wake_word_callback=None,
end_of_owner_speech_callback=None):
self.log('Loading model...')
if device == 'auto':
@ -21,7 +25,14 @@ class ASR:
self.log(f'Device: {device}')
self.transcription_callback = transcription_callback
self.wake_word_callback = wake_word_callback
self.end_of_owner_speech_callback = end_of_owner_speech_callback
self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon"]
self.device = device
self.tcp_conn = None
self.utterance = []
self.circular_buffer = []
self.is_voice_activity_detected = False
@ -47,8 +58,8 @@ class ASR:
def detect_wake_word(self, speech: str) -> bool:
    """Return True when any configured wake word occurs in the transcript.

    The transcript is lowercased and trimmed first, so matching is
    case-insensitive. The accepted phrases come from ``self.wake_words``
    (single source of truth) rather than a hard-coded local list, which
    keeps detection consistent with the cleanup logic in the TCP server.
    """
    lowercased_speech = speech.lower().strip()
    for wake_word in self.wake_words:
        if wake_word in lowercased_speech:
            return True
    return False
@ -56,6 +67,7 @@ class ASR:
def process_circular_buffer(self):
if len(self.circular_buffer) > self.buffer_size:
self.circular_buffer.pop(0)
audio_data = np.concatenate(self.circular_buffer)
segments, info = self.model.transcribe(
audio_data,
@ -68,12 +80,17 @@ class ASR:
for segment in segments:
words = segment.text.split()
self.segment_text += ' '.join(words) + ' '
if self.is_wake_word_detected:
self.utterance.append(self.segment_text)
self.transcription_callback(" ".join(self.utterance))
if self.detect_wake_word(segment.text):
self.log('Wake word detected')
self.wake_word_callback(segment.text)
self.is_wake_word_detected = True
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
else:
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
self.segment_text = ''
def start_recording(self):
@ -98,6 +115,7 @@ class ASR:
if rms >= self.threshold: # audio threshold
if not self.is_voice_activity_detected:
self.is_voice_activity_detected = True
self.circular_buffer.append(data_np)
self.process_circular_buffer()
else:
@ -107,6 +125,8 @@ class ASR:
if time.time() - self.silence_start_time > self.silence_duration: # If silence for SILENCE_DURATION seconds
if len(self.utterance) > 0:
self.log('Reset')
# Take last utterance of the utterance list
self.end_of_owner_speech_callback(self.utterance[-1])
if self.is_wake_word_detected:
self.saved_utterances.append(" ".join(self.utterance))

View File

@ -4,6 +4,7 @@ import os
from typing import Union
import time
import re
import string
import lib.nlp as nlp
from .asr import ASR
@ -25,6 +26,13 @@ class TCPServer:
def log(*args, **kwargs):
print('[TCP Server]', *args, **kwargs)
def send_tcp_message(self, data: dict):
    """JSON-encode ``data`` and send it to the connected TCP client.

    Logs and does nothing when no client is currently connected.
    """
    if self.conn:
        payload = json.dumps(data).encode('utf-8')
        self.conn.sendall(payload)
    else:
        self.log('No client connection found. Cannot send message')
def init_tts(self):
if not IS_TTS_ENABLED:
self.log('TTS is disabled')
@ -49,8 +57,53 @@ class TCPServer:
self.log('ASR is disabled')
return
def transcription_callback(utterance):
    # Intentionally a no-op for now: live transcription chunks arrive
    # here but are not forwarded. The log call is kept for debugging.
    # self.log('Transcription:', utterance)
    pass
def clean_up_wake_word_text(text: str) -> str:
    """Remove everything before the wake word (included), remove punctuation
    right after it, trim and capitalize the first letter."""
    separators = string.whitespace + string.punctuation
    lowered = text.lower()
    for wake_word in self.asr.wake_words:
        if wake_word not in lowered:
            continue
        # Jump past the wake word, then past any separators that follow it.
        # Indexing the original text with indices from the lowercased copy
        # is safe here since lower() preserves length for these phrases.
        cursor = lowered.index(wake_word) + len(wake_word)
        while cursor < len(text) and text[cursor] in separators:
            cursor += 1
        remainder = text[cursor:].strip()
        if remainder:
            return remainder[0].upper() + remainder[1:]
        return ""  # Nothing left after the wake word
    return text
def wake_word_callback(text):
    # Strip the wake word prefix, then notify the core over TCP.
    sanitized = clean_up_wake_word_text(text)
    self.log('Wake word detected:', sanitized)
    self.send_tcp_message({
        'topic': 'asr-wake-word-detected',
        'data': {'text': sanitized}
    })
def end_of_owner_speech_callback(utterance):
    # Forward the owner's final utterance to the core once silence is hit.
    self.log('End of owner speech:', utterance)
    self.send_tcp_message({
        'topic': 'asr-end-of-owner-speech-detected',
        'data': {'utterance': utterance}
    })
# TODO: local model path
self.asr = ASR(device='auto')
self.asr = ASR(device='auto',
transcription_callback=transcription_callback,
wake_word_callback=wake_word_callback,
end_of_owner_speech_callback=end_of_owner_speech_callback
)
self.asr.start_recording()
def init(self):
@ -97,7 +150,7 @@ class TCPServer:
method = getattr(self, method)
res = method(data)
self.conn.sendall(json.dumps(res).encode('utf-8'))
self.send_tcp_message(res)
finally:
self.log(f'Client disconnected: {self.addr}')
self.conn.close()
@ -124,7 +177,7 @@ class TCPServer:
audio_id = f'{int(time.time())}_{os.urandom(2).hex()}'
output_file_name = f'{audio_id}.wav'
output_path = os.path.join(TMP_PATH, output_file_name)
speed = 0.88
speed = 0.87
formatted_speech = speech.replace(' - ', '.').replace(',', '.').replace(': ', '. ')
# Clean up emojis