1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-08-16 21:50:33 +03:00

feat: prepare TCP server and core to communicate for the new TTS engine

This commit is contained in:
louistiti 2024-05-18 10:42:31 +08:00
parent a7cab344f8
commit 9a065175ca
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
7 changed files with 53 additions and 13 deletions

View File

@ -54,7 +54,7 @@
"build:nodejs-bridge": "tsx scripts/build-binaries.js nodejs-bridge",
"build:python-bridge": "tsx scripts/build-binaries.js python-bridge",
"build:tcp-server": "tsx scripts/build-binaries.js tcp-server",
"start:tcp-server": "cross-env HF_HUB_VERBOSITY=debug PIPENV_PIPFILE=tcp_server/src/Pipfile pipenv run python tcp_server/src/main.py",
"start:tcp-server": "cross-env PIPENV_PIPFILE=tcp_server/src/Pipfile pipenv run python tcp_server/src/main.py",
"start": "cross-env LEON_NODE_ENV=production node server/dist/pre-check.js && node server/dist/index.js",
"python-bridge": "cross-env PIPENV_PIPFILE=bridges/python/src/Pipfile pipenv run python bridges/python/src/main.py server/src/intent-object.sample.json",
"train": "tsx scripts/train/run-train.js",

View File

@ -34,6 +34,29 @@ import { LogHelper } from '@/helpers/log-helper'
detached: IS_DEVELOPMENT_ENV
}
)
global.pythonTCPServerProcess.stdout.on('data', (data: Buffer) => {
LogHelper.title('Python TCP Server')
LogHelper.info(data.toString())
})
global.pythonTCPServerProcess.stderr.on('data', (data: Buffer) => {
const formattedData = data.toString().trim()
const skipError = [
'RuntimeWarning:',
'FutureWarning:',
'UserWarning:',
'<00:00',
'00:00<',
'CUDNN_STATUS_NOT_SUPPORTED',
'cls.seq_relationship.weight'
]
if (skipError.some((error) => formattedData.includes(error))) {
return
}
LogHelper.title('Python TCP Server')
LogHelper.error(data.toString())
})
// Connect the Python TCP client to the Python TCP server
PYTHON_TCP_CLIENT.connect()

View File

@ -3,13 +3,20 @@ import sys
IS_RAN_FROM_BINARY = getattr(sys, 'frozen', False)
SRC_PATH = os.path.join(os.getcwd(), 'tcp_server', 'src') if not IS_RAN_FROM_BINARY else '.'
EXECUTABLE_DIR_PATH = os.path.dirname(sys.executable) if IS_RAN_FROM_BINARY else '.'
LIB_PATH = os.path.join(os.getcwd(), 'tcp_server', 'src', 'lib')
if IS_RAN_FROM_BINARY:
LIB_PATH = os.path.join(os.path.dirname(sys.executable), 'lib', 'lib')
TMP_PATH = os.path.join(LIB_PATH, 'tmp')
# TTS
TTS_MODEL_VERSION = 'V1'
TTS_MODEL_NAME = f'EN-Leon-{TTS_MODEL_VERSION}'
TTS_MODEL_ITERATION = '486000'
TTS_MODEL_NAME = f'EN-Leon-{TTS_MODEL_VERSION}-G_{TTS_MODEL_ITERATION}'
TTS_MODEL_FILE_NAME = f'{TTS_MODEL_NAME}.pth'
TTS_LIB_PATH = os.path.join(SRC_PATH, 'lib', 'tts')
TTS_LIB_PATH = os.path.join(LIB_PATH, 'tts')
TTS_MODEL_FOLDER_PATH = os.path.join(TTS_LIB_PATH, 'models')
TTS_MODEL_CONFIG_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, 'config.json')
TTS_MODEL_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, TTS_MODEL_FILE_NAME)

View File

@ -6,7 +6,7 @@ import time
import lib.nlp as nlp
from .tts.api import TTS
from .constants import TTS_MODEL_CONFIG_PATH, TTS_MODEL_PATH, IS_TTS_ENABLED
from .constants import TTS_MODEL_CONFIG_PATH, TTS_MODEL_PATH, IS_TTS_ENABLED, TMP_PATH
class TCPServer:
@ -23,8 +23,6 @@ class TCPServer:
print('[TCP Server]', *args, **kwargs)
def init_tts(self):
print('IS_TTS_ENABLED', IS_TTS_ENABLED)
# TODO: FIX IT
if not IS_TTS_ENABLED:
self.log('TTS is disabled')
return
@ -45,14 +43,10 @@ class TCPServer:
text = 'Hello, I am Leon. How can I help you?'
speaker_ids = self.tts.hps.data.spk2id
output_path = 'output.wav'
output_path = os.path.join(TMP_PATH, 'output.wav')
speed = 1.0
tic = time.perf_counter()
self.tts.tts_to_file(text, speaker_ids['EN-Leon-V1'], output_path, speed=speed)
toc = time.perf_counter()
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
self.tts.tts_to_file(text, speaker_ids['EN-Leon-V1'], output_path, speed=speed, quiet=True)
def init(self):
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use

View File

View File

@ -89,6 +89,8 @@ class TTS(nn.Module):
return texts
def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
tic = time.perf_counter()
self.log(f"Generating audio for:\n{text}")
language = self.language
texts = self.split_sentences_into_pieces(text, language, quiet)
audio_list = []
@ -133,6 +135,8 @@ class TTS(nn.Module):
audio_list.append(audio)
torch.cuda.empty_cache()
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
toc = time.perf_counter()
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
if output_path is None:
return audio

View File

@ -1,8 +1,10 @@
from cx_Freeze import setup, Executable
import sysconfig
import sys
import os
from version import __version__
from lib.constants import TMP_PATH
"""
Increase the recursion limit to avoid RecursionError
@ -10,6 +12,16 @@ Increase the recursion limit to avoid RecursionError
"""
sys.setrecursionlimit(sys.getrecursionlimit() * 10)
"""
Delete content of all temporary directory. Only keep ".gitkeep" file.
"""
print(f"Deleting content of {TMP_PATH}")
for root, dirs, files in os.walk(TMP_PATH):
for file in files:
if file != '.gitkeep':
os.remove(os.path.join(root, file))
print(f"Deleted content of {TMP_PATH}")
"""
Instead of injecting everything from a package,
it's recommended to only include the necessary files via the