mirror of
https://github.com/leon-ai/leon.git
synced 2024-11-27 16:16:48 +03:00
feat(python tcp server): ASR engine communication with core
This commit is contained in:
parent
ef368f89fb
commit
f051c1d2cd
@ -91,6 +91,14 @@ export default class Client {
|
||||
this.chatbot.createBubble('leon', data)
|
||||
})
|
||||
|
||||
this.socket.on('asr-speech', (data) => {
|
||||
console.log('Wake word detected', data)
|
||||
})
|
||||
|
||||
this.socket.on('asr-end-of-owner-speech', () => {
|
||||
console.log('End of owner speech')
|
||||
})
|
||||
|
||||
/**
|
||||
* Only used for "local" TTS provider as a PoC for now.
|
||||
* Target to do a better implementation in the future
|
||||
|
48
server/src/core/stt/parsers/local-parser.ts
Normal file
48
server/src/core/stt/parsers/local-parser.ts
Normal file
@ -0,0 +1,48 @@
|
||||
import { STTParserBase } from '@/core/stt/stt-parser-base'
|
||||
import { LogHelper } from '@/helpers/log-helper'
|
||||
import { PYTHON_TCP_CLIENT, SOCKET_SERVER } from '@/core'
|
||||
|
||||
export default class LocalParser extends STTParserBase {
|
||||
protected readonly name = 'Local STT Parser'
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.success('New instance')
|
||||
|
||||
try {
|
||||
LogHelper.success('Parser initialized')
|
||||
} catch (e) {
|
||||
LogHelper.error(`${this.name} - Failed to initialize: ${e}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read audio buffer and return the transcript (decoded string)
|
||||
*/
|
||||
public async parse(): Promise<string | null> {
|
||||
const wakeWordEventName = 'asr-wake-word-detected'
|
||||
const endOfOwnerSpeechDetected = 'asr-end-of-owner-speech-detected'
|
||||
const wakeWordEventHasListeners =
|
||||
PYTHON_TCP_CLIENT.ee.listenerCount(wakeWordEventName) > 0
|
||||
const endOfOwnerSpeechDetectedHasListeners =
|
||||
PYTHON_TCP_CLIENT.ee.listenerCount(endOfOwnerSpeechDetected) > 0
|
||||
|
||||
if (!wakeWordEventHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on(wakeWordEventName, (data) => {
|
||||
SOCKET_SERVER.socket?.emit('asr-speech', data.text)
|
||||
})
|
||||
}
|
||||
|
||||
if (!endOfOwnerSpeechDetectedHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on(endOfOwnerSpeechDetected, (data) => {
|
||||
SOCKET_SERVER.socket?.emit('asr-end-of-owner-speech', {
|
||||
completeSpeech: data.utterance
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ import { STTParserNames, STTProviders } from '@/core/stt/types'
|
||||
import { LogHelper } from '@/helpers/log-helper'
|
||||
|
||||
const PROVIDERS_MAP = {
|
||||
[STTProviders.Local]: STTParserNames.Local,
|
||||
[STTProviders.GoogleCloudSTT]: STTParserNames.GoogleCloudSTT,
|
||||
[STTProviders.WatsonSTT]: STTParserNames.WatsonSTT,
|
||||
[STTProviders.CoquiSTT]: STTParserNames.CoquiSTT
|
||||
@ -76,6 +77,16 @@ export default class STT {
|
||||
)
|
||||
this.parser = new parser() as STTParser
|
||||
|
||||
/**
|
||||
* If the provider is local, parse an empty buffer to
|
||||
* initialize the parser
|
||||
*/
|
||||
if (STT_PROVIDER === STTProviders.Local) {
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-expect-error
|
||||
this.parser?.parse()
|
||||
}
|
||||
|
||||
LogHelper.title('STT')
|
||||
LogHelper.success('STT initialized')
|
||||
|
||||
|
@ -1,20 +1,24 @@
|
||||
import type LocalParser from '@/core/stt/parsers/local-parser'
|
||||
import type CoquiSTTParser from '@/core/stt/parsers/coqui-stt-parser'
|
||||
import type GoogleCloudSTTParser from '@/core/stt/parsers/google-cloud-stt-parser'
|
||||
import type WatsonSTTParser from '@/core/stt/parsers/watson-stt-parser'
|
||||
|
||||
/**
 * Supported STT providers; values are compared against the configured
 * STT_PROVIDER during STT initialization.
 */
export enum STTProviders {
  Local = 'local',
  GoogleCloudSTT = 'google-cloud-stt',
  WatsonSTT = 'watson-stt',
  CoquiSTT = 'coqui-stt'
}

/**
 * Parser module basenames; each value is the kebab-case file name of the
 * matching parser under server/src/core/stt/parsers/ (e.g. local-parser.ts).
 */
export enum STTParserNames {
  Local = 'local-parser',
  GoogleCloudSTT = 'google-cloud-stt-parser',
  WatsonSTT = 'watson-stt-parser',
  CoquiSTT = 'coqui-stt-parser'
}

/**
 * Union of every concrete STT parser implementation the STT core may
 * instantiate.
 */
export type STTParser =
  | LocalParser
  | GoogleCloudSTTParser
  | WatsonSTTParser
  | CoquiSTTParser
|
||||
|
@ -32,11 +32,11 @@ export default class LocalSynthesizer extends TTSSynthesizerBase {
|
||||
}
|
||||
|
||||
public async synthesize(speech: string): Promise<SynthesizeResult | null> {
|
||||
const eventName = 'tts-receiving-stream'
|
||||
const eventName = 'tts-audio-streaming'
|
||||
const eventHasListeners = PYTHON_TCP_CLIENT.ee.listenerCount(eventName) > 0
|
||||
|
||||
if (!eventHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on('tts-audio-streaming', (data: ChunkData) => {
|
||||
PYTHON_TCP_CLIENT.ee.on(eventName, (data: ChunkData) => {
|
||||
/**
|
||||
* Send audio stream chunk by chunk to the client as long as
|
||||
* the temporary file is being written from the TCP server
|
||||
|
@ -1,3 +1,4 @@
|
||||
import type LocalSynthesizer from '@/core/tts/synthesizers/local-synthesizer'
|
||||
import type AmazonPollySynthesizer from '@/core/tts/synthesizers/amazon-polly-synthesizer'
|
||||
import type FliteSynthesizer from '@/core/tts/synthesizers/flite-synthesizer'
|
||||
import type GoogleCloudTTSSynthesizer from '@/core/tts/synthesizers/google-cloud-tts-synthesizer'
|
||||
@ -25,6 +26,7 @@ export interface SynthesizeResult {
|
||||
}
|
||||
|
||||
export type TTSSynthesizer =
|
||||
| LocalSynthesizer
|
||||
| AmazonPollySynthesizer
|
||||
| FliteSynthesizer
|
||||
| GoogleCloudTTSSynthesizer
|
||||
|
@ -6,7 +6,11 @@ import numpy as np
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
class ASR:
|
||||
def __init__(self, device='auto'):
|
||||
def __init__(self,
|
||||
device='auto',
|
||||
transcription_callback=None,
|
||||
wake_word_callback=None,
|
||||
end_of_owner_speech_callback=None):
|
||||
self.log('Loading model...')
|
||||
|
||||
if device == 'auto':
|
||||
@ -21,7 +25,14 @@ class ASR:
|
||||
|
||||
self.log(f'Device: {device}')
|
||||
|
||||
self.transcription_callback = transcription_callback
|
||||
self.wake_word_callback = wake_word_callback
|
||||
self.end_of_owner_speech_callback = end_of_owner_speech_callback
|
||||
|
||||
self.wake_words = ["ok leon", "okay leon", "hi leon", "hey leon", "hello leon"]
|
||||
|
||||
self.device = device
|
||||
self.tcp_conn = None
|
||||
self.utterance = []
|
||||
self.circular_buffer = []
|
||||
self.is_voice_activity_detected = False
|
||||
@ -47,8 +58,8 @@ class ASR:
|
||||
|
||||
def detect_wake_word(self, speech: str) -> bool:
|
||||
lowercased_speech = speech.lower().strip()
|
||||
wake_words = ["ok leon", "okay leon", "hi leon", "hey leon"]
|
||||
for wake_word in wake_words:
|
||||
|
||||
for wake_word in self.wake_words:
|
||||
if wake_word in lowercased_speech:
|
||||
return True
|
||||
return False
|
||||
@ -56,6 +67,7 @@ class ASR:
|
||||
def process_circular_buffer(self):
|
||||
if len(self.circular_buffer) > self.buffer_size:
|
||||
self.circular_buffer.pop(0)
|
||||
|
||||
audio_data = np.concatenate(self.circular_buffer)
|
||||
segments, info = self.model.transcribe(
|
||||
audio_data,
|
||||
@ -68,12 +80,17 @@ class ASR:
|
||||
for segment in segments:
|
||||
words = segment.text.split()
|
||||
self.segment_text += ' '.join(words) + ' '
|
||||
|
||||
if self.is_wake_word_detected:
|
||||
self.utterance.append(self.segment_text)
|
||||
self.transcription_callback(" ".join(self.utterance))
|
||||
if self.detect_wake_word(segment.text):
|
||||
self.log('Wake word detected')
|
||||
self.wake_word_callback(segment.text)
|
||||
self.is_wake_word_detected = True
|
||||
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||
else:
|
||||
self.log("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||
|
||||
self.segment_text = ''
|
||||
|
||||
def start_recording(self):
|
||||
@ -98,6 +115,7 @@ class ASR:
|
||||
if rms >= self.threshold: # audio threshold
|
||||
if not self.is_voice_activity_detected:
|
||||
self.is_voice_activity_detected = True
|
||||
|
||||
self.circular_buffer.append(data_np)
|
||||
self.process_circular_buffer()
|
||||
else:
|
||||
@ -107,6 +125,8 @@ class ASR:
|
||||
if time.time() - self.silence_start_time > self.silence_duration: # If silence for SILENCE_DURATION seconds
|
||||
if len(self.utterance) > 0:
|
||||
self.log('Reset')
|
||||
# Take last utterance of the utterance list
|
||||
self.end_of_owner_speech_callback(self.utterance[-1])
|
||||
|
||||
if self.is_wake_word_detected:
|
||||
self.saved_utterances.append(" ".join(self.utterance))
|
||||
|
@ -4,6 +4,7 @@ import os
|
||||
from typing import Union
|
||||
import time
|
||||
import re
|
||||
import string
|
||||
|
||||
import lib.nlp as nlp
|
||||
from .asr import ASR
|
||||
@ -25,6 +26,13 @@ class TCPServer:
|
||||
def log(*args, **kwargs):
|
||||
print('[TCP Server]', *args, **kwargs)
|
||||
|
||||
def send_tcp_message(self, data: dict):
    """Serialize a dict as JSON and push it to the connected TCP client.

    When no client is connected, the message is dropped and a log line is
    emitted instead.
    """
    if self.conn:
        payload = json.dumps(data).encode('utf-8')
        self.conn.sendall(payload)
    else:
        self.log('No client connection found. Cannot send message')
|
||||
|
||||
def init_tts(self):
|
||||
if not IS_TTS_ENABLED:
|
||||
self.log('TTS is disabled')
|
||||
@ -49,8 +57,53 @@ class TCPServer:
|
||||
self.log('ASR is disabled')
|
||||
return
|
||||
|
||||
def transcription_callback(utterance):
|
||||
# self.log('Transcription:', utterance)
|
||||
pass
|
||||
|
||||
def clean_up_wake_word_text(text: str) -> str:
    """Strip the wake word (and everything before it) from a transcript.

    Scans the configured wake words in declaration order and, on the first
    match (case-insensitive), drops everything up to and including the wake
    word plus any whitespace/punctuation directly after it, trims the rest
    and capitalizes its first letter. Returns '' when nothing remains after
    the wake word, and the input unchanged when no wake word is found.
    """
    lowercased_text = text.lower()
    for wake_word in self.asr.wake_words:
        if wake_word in lowercased_text:
            # Indices are computed on the lowercased copy but applied to the
            # original text — assumes lower() preserves length, which holds
            # for the ASCII wake words configured in ASR.
            start_index = lowercased_text.index(wake_word)
            end_index = start_index + len(wake_word)
            end_whitespace_index = end_index
            # Skip whitespace and punctuation immediately after the wake word
            # (e.g. "hey leon, ..." -> start at the word after the comma)
            while end_whitespace_index < len(text) and (text[end_whitespace_index] in string.whitespace + string.punctuation):
                end_whitespace_index += 1
            cleaned_text = text[end_whitespace_index:].strip()
            if cleaned_text:  # Check if cleaned_text is not empty
                return cleaned_text[0].upper() + cleaned_text[1:]
            else:
                return ""  # Return an empty string if cleaned_text is empty
    return text
|
||||
|
||||
def wake_word_callback(text):
    # Invoked by ASR when a wake word is detected in a transcribed segment;
    # strips the wake-word prefix and notifies the core over the TCP
    # connection so it can react (the client listens on 'asr-speech').
    cleaned_text = clean_up_wake_word_text(text)
    self.log('Wake word detected:', cleaned_text)
    self.send_tcp_message({
        'topic': 'asr-wake-word-detected',
        'data': {
            'text': cleaned_text
        }
    })
|
||||
|
||||
def end_of_owner_speech_callback(utterance):
    # Invoked by ASR once silence has lasted long enough after the owner
    # spoke; forwards the last captured utterance to the core over TCP.
    self.log('End of owner speech:', utterance)
    self.send_tcp_message({
        'topic': 'asr-end-of-owner-speech-detected',
        'data': {
            'utterance': utterance
        }
    })
|
||||
|
||||
# TODO: local model path
|
||||
self.asr = ASR(device='auto')
|
||||
self.asr = ASR(device='auto',
|
||||
transcription_callback=transcription_callback,
|
||||
wake_word_callback=wake_word_callback,
|
||||
end_of_owner_speech_callback=end_of_owner_speech_callback
|
||||
)
|
||||
self.asr.start_recording()
|
||||
|
||||
def init(self):
|
||||
@ -97,7 +150,7 @@ class TCPServer:
|
||||
method = getattr(self, method)
|
||||
res = method(data)
|
||||
|
||||
self.conn.sendall(json.dumps(res).encode('utf-8'))
|
||||
self.send_tcp_message(res)
|
||||
finally:
|
||||
self.log(f'Client disconnected: {self.addr}')
|
||||
self.conn.close()
|
||||
@ -124,7 +177,7 @@ class TCPServer:
|
||||
audio_id = f'{int(time.time())}_{os.urandom(2).hex()}'
|
||||
output_file_name = f'{audio_id}.wav'
|
||||
output_path = os.path.join(TMP_PATH, output_file_name)
|
||||
speed = 0.88
|
||||
speed = 0.87
|
||||
|
||||
formatted_speech = speech.replace(' - ', '.').replace(',', '.').replace(': ', '. ')
|
||||
# Clean up emojis
|
||||
|
Loading…
Reference in New Issue
Block a user