1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-07-14 21:40:34 +03:00

feat(server): complete first version of the new TTS engine

This commit is contained in:
louistiti 2024-05-18 21:45:46 +08:00
parent 9a065175ca
commit f776913ca3
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
7 changed files with 212 additions and 23 deletions

View File

@ -16,6 +16,7 @@ export default class Client {
this.chatbot = new Chatbot()
this._recorder = {}
this._suggestions = []
// this._ttsAudioContextes = {}
}
set input(newInput) {
@ -90,6 +91,26 @@ export default class Client {
this.chatbot.createBubble('leon', data)
})
/**
* Only used for "local" TTS provider as a PoC for now.
* Target to do a better implementation in the future
* with streaming support
*/
this.socket.on('tts-stream', (data) => {
// const { audioId, chunk } = data
const { chunk } = data
const ctx = new AudioContext()
// this._ttsAudioContextes[audioId] = ctx
const source = ctx.createBufferSource()
ctx.decodeAudioData(chunk, (buffer) => {
source.buffer = buffer
source.connect(ctx.destination)
source.start(0)
})
})
this.socket.on('audio-forwarded', (data, cb) => {
const ctx = new AudioContext()
const source = ctx.createBufferSource()

View File

@ -19,6 +19,7 @@ import {
} from '@/constants'
import { OSTypes } from '@/types'
import { LogHelper } from '@/helpers/log-helper'
import { LoaderHelper } from '@/helpers/loader-helper'
import { SystemHelper } from '@/helpers/system-helper'
/**
@ -56,6 +57,8 @@ BUILD_TARGETS.set('tcp-server', {
dotVenvPath: path.join(PYTHON_TCP_SERVER_SRC_PATH, '.venv')
})
;(async () => {
LoaderHelper.start()
const { argv } = process
const givenBuildTarget = argv[2].toLowerCase()

View File

@ -0,0 +1,83 @@
import fs from 'node:fs'
import type { LongLanguageCode } from '@/types'
import type { SynthesizeResult } from '@/core/tts/types'
import { LANG } from '@/constants'
import { PYTHON_TCP_CLIENT, SOCKET_SERVER, TTS } from '@/core'
import { TTSSynthesizerBase } from '@/core/tts/tts-synthesizer-base'
import { LogHelper } from '@/helpers/log-helper'
// Payload received from the Python TCP server once a synthesized
// audio file has been written to disk and is ready to be streamed
interface ChunkData {
outputPath: string // Absolute path of the temporary WAV file to stream then delete
audioId: string // Unique identifier of this synthesis (timestamp + random suffix)
}
export default class LocalSynthesizer extends TTSSynthesizerBase {
  protected readonly name = 'Local TTS Synthesizer'
  protected readonly lang = LANG as LongLanguageCode

  constructor(lang: LongLanguageCode) {
    super()

    LogHelper.title(this.name)
    LogHelper.success('New instance')

    try {
      this.lang = lang

      LogHelper.success('Synthesizer initialized')
    } catch (e) {
      LogHelper.error(`${this.name} - Failed to initialize: ${e}`)
    }
  }

  /**
   * Ask the Python TCP server to synthesize the given speech, then forward
   * the resulting audio to the connected client over the socket server.
   * Returns a placeholder result (empty path) since the audio is pushed
   * through the 'tts-stream' socket event rather than saved for the caller.
   * @param speech Text to synthesize
   */
  public async synthesize(speech: string): Promise<SynthesizeResult | null> {
    const eventName = 'tts-audio-streaming'
    /**
     * Bug fix: guard the subscription with the listener count of the SAME
     * event we subscribe to below. The previous code checked
     * 'tts-receiving-stream' (which never has listeners), so a duplicate
     * 'tts-audio-streaming' handler was registered on every synthesize call,
     * emitting the same audio to the client multiple times
     */
    const eventHasListeners = PYTHON_TCP_CLIENT.ee.listenerCount(eventName) > 0

    if (!eventHasListeners) {
      PYTHON_TCP_CLIENT.ee.on(eventName, (data: ChunkData) => {
        /**
         * Send audio stream chunk by chunk to the client as long as
         * the temporary file is being written from the TCP server
         */
        const { outputPath, audioId } = data
        const stream = fs.createReadStream(outputPath)
        const chunks: Buffer[] = []

        // Without an 'error' listener, a failed read would crash the process
        stream.on('error', (e) => {
          LogHelper.title(this.name)
          LogHelper.warning(`Failed to read tmp audio file: ${e}`)
        })

        stream.on('data', (chunk: Buffer) => {
          chunks.push(chunk)
        })

        stream.on('end', async () => {
          // PoC: buffer the whole file, then emit it in one socket message
          const completeStream = Buffer.concat(chunks)

          SOCKET_SERVER.socket?.emit('tts-stream', {
            chunk: completeStream,
            audioId
          })

          try {
            const duration = await this.getAudioDuration(outputPath)

            TTS.em.emit('saved', duration)
          } catch (e) {
            LogHelper.title(this.name)
            LogHelper.warning(`Failed to get audio duration: ${e}`)
          }

          // The tmp file has been fully forwarded; remove it
          try {
            fs.unlinkSync(outputPath)
          } catch (e) {
            LogHelper.warning(`Failed to delete tmp audio file: ${e}`)
          }
        })
      })
    }

    // TODO: support mood to control speed and pitch
    PYTHON_TCP_CLIENT.emit('tts-synthesize', speech)

    // Empty path signals the caller that audio is streamed, not file-based
    return {
      audioFilePath: '',
      duration: 500
    }
  }
}

View File

@ -16,6 +16,7 @@ interface Speech {
}
const PROVIDERS_MAP = {
[TTSProviders.Local]: TTSSynthesizers.Local,
[TTSProviders.GoogleCloudTTS]: TTSSynthesizers.GoogleCloudTTS,
[TTSProviders.WatsonTTS]: TTSSynthesizers.WatsonTTS,
[TTSProviders.AmazonPolly]: TTSSynthesizers.AmazonPolly,
@ -104,6 +105,11 @@ export default class TTS {
if (this.synthesizer) {
const result = await this.synthesizer.synthesize(speech.text)
// Support custom TTS providers such as the local synthesizer
if (result?.audioFilePath === '') {
return
}
if (!result) {
LogHelper.error(
'The TTS synthesizer failed to synthesize the speech as the result is null'
@ -120,7 +126,7 @@ export default class TTS {
duration
},
(confirmation: string) => {
if (confirmation === 'audio-received') {
if (confirmation === 'audio-received' && audioFilePath !== '') {
fs.unlinkSync(audioFilePath)
}
}

View File

@ -4,6 +4,7 @@ import type GoogleCloudTTSSynthesizer from '@/core/tts/synthesizers/google-cloud
import type WatsonTTSSynthesizer from '@/core/tts/synthesizers/watson-tts-synthesizer'
export enum TTSProviders {
Local = 'local',
AmazonPolly = 'amazon-polly',
GoogleCloudTTS = 'google-cloud-tts',
WatsonTTS = 'watson-tts',
@ -11,6 +12,7 @@ export enum TTSProviders {
}
export enum TTSSynthesizers {
Local = 'local-synthesizer',
AmazonPolly = 'amazon-polly-synthesizer',
GoogleCloudTTS = 'google-cloud-tts-synthesizer',
WatsonTTS = 'watson-tts-synthesizer',

View File

@ -41,18 +41,23 @@ class TCPServer:
ckpt_path=TTS_MODEL_PATH
)
text = 'Hello, I am Leon. How can I help you?'
speaker_ids = self.tts.hps.data.spk2id
output_path = os.path.join(TMP_PATH, 'output.wav')
speed = 1.0
self.tts.tts_to_file(text, speaker_ids['EN-Leon-V1'], output_path, speed=speed, quiet=True)
def init(self):
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use
self.tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.tcp_socket.bind((self.host, int(self.port)))
self.tcp_socket.listen()
try:
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use
self.tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.tcp_socket.bind((self.host, int(self.port)))
self.tcp_socket.listen()
except OSError as e:
# If the port is already in use, close the connection and retry
if 'Address already in use' in str(e):
self.log(f'Port {self.port} is already in use. Disconnecting client and retrying...')
if self.conn:
self.conn.close()
# Wait for a moment before retrying
time.sleep(1)
self.init()
else:
raise
while True:
# Flush buffered output to make it IPC friendly (readable on stdout)
@ -93,3 +98,39 @@ class TCPServer:
'spacyEntities': entities
}
}
def tts_synthesize(self, speech: str) -> dict:
    """
    Synthesize the given speech into a temporary WAV file and return the
    IPC payload announcing where the audio was written.

    TODO:
    - Implement one speaker per style (joyful, sad, angry, tired, etc.)
    - Need to train a new model with default voice speaker and other speakers with different styles
    - EN-Leon-Joyful-V1; EN-Leon-Sad-V1; etc.
    """
    speakers = self.tts.hps.data.spk2id

    # Timestamp plus a short random hex suffix keeps file names unique
    audio_id = f'{int(time.time())}_{os.urandom(2).hex()}'
    output_path = os.path.join(TMP_PATH, f'{audio_id}.wav')

    # Turn dash separators into sentence breaks for more natural pauses
    cleaned_speech = speech.replace(' - ', '.')

    # TODO: should not wait to finish for streaming support
    self.tts.tts_to_file(
        cleaned_speech,
        speakers['EN-Leon-V1'],
        output_path=output_path,
        speed=0.88,
        quiet=True,
        format='wav',
        stream=False
    )

    return {
        'topic': 'tts-audio-streaming',
        'data': {
            'outputPath': output_path,
            'audioId': audio_id
        }
    }

View File

@ -5,6 +5,8 @@ import torch.nn as nn
from tqdm import tqdm
import torch
import time
import wave
import os
from . import utils
from .models import SynthesizerTrn
@ -88,11 +90,13 @@ class TTS(nn.Module):
print(" > ===========================")
return texts
def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False, stream=False):
tic = time.perf_counter()
self.log(f"Generating audio for:\n{text}")
language = self.language
texts = self.split_sentences_into_pieces(text, language, quiet)
audio_list = []
if pbar:
tx = pbar(texts)
@ -131,21 +135,50 @@ class TTS(nn.Module):
length_scale=1. / speed,
)[0][0, 0].data.cpu().float().numpy()
del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
#
audio_list.append(audio)
# Save audio data chunk by chunk
if stream:
# Convert audio to 16-bit PCM format
audio = (audio * 32767).astype(np.int16)
if not os.path.exists(output_path):
# If the file doesn't exist, create it and write the audio data to it
with wave.open(output_path, 'wb') as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 2 bytes for 16-bit PCM
wf.setframerate(self.hps.data.sampling_rate)
wf.writeframes(audio.tobytes())
else:
with wave.open(output_path, 'rb') as wf:
params = wf.getparams()
old_audio = np.frombuffer(wf.readframes(params.nframes), dtype=np.int16)
new_audio = np.concatenate([old_audio, audio])
with wave.open(output_path, 'wb') as wf:
wf.setparams(params)
wf.writeframes(new_audio.tobytes())
time.sleep(2)
self.log(f"Audio chunk saved")
else:
audio_list.append(audio)
torch.cuda.empty_cache()
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
toc = time.perf_counter()
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
if output_path is None:
return audio
else:
if format:
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
# Concatenate audio segments and save the entire audio to file
if not stream:
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
if output_path is None:
return audio
else:
if format:
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
@staticmethod
def log(*args, **kwargs):