mirror of
https://github.com/leon-ai/leon.git
synced 2024-07-14 21:40:34 +03:00
feat(server): complete first version of the new TTS engine
This commit is contained in:
parent
9a065175ca
commit
f776913ca3
@ -16,6 +16,7 @@ export default class Client {
|
||||
this.chatbot = new Chatbot()
|
||||
this._recorder = {}
|
||||
this._suggestions = []
|
||||
// this._ttsAudioContextes = {}
|
||||
}
|
||||
|
||||
set input(newInput) {
|
||||
@ -90,6 +91,26 @@ export default class Client {
|
||||
this.chatbot.createBubble('leon', data)
|
||||
})
|
||||
|
||||
/**
 * Only used for the "local" TTS provider as a PoC for now.
 * Target to do a better implementation in the future
 * with streaming support
 */
this.socket.on('tts-stream', (data) => {
  // const { audioId, chunk } = data
  // Despite the event name, "chunk" currently carries the complete audio
  // buffer (the server concatenates before emitting); no incremental playback yet
  const { chunk } = data
  // NOTE(review): a fresh AudioContext is created for every event and never
  // closed — acceptable for a PoC, but contexts accumulate; confirm before shipping
  const ctx = new AudioContext()
  // this._ttsAudioContextes[audioId] = ctx

  const source = ctx.createBufferSource()
  // Decode the received audio data, then play it immediately from the start
  ctx.decodeAudioData(chunk, (buffer) => {
    source.buffer = buffer

    source.connect(ctx.destination)
    source.start(0)
  })
})
|
||||
|
||||
this.socket.on('audio-forwarded', (data, cb) => {
|
||||
const ctx = new AudioContext()
|
||||
const source = ctx.createBufferSource()
|
||||
|
@ -19,6 +19,7 @@ import {
|
||||
} from '@/constants'
|
||||
import { OSTypes } from '@/types'
|
||||
import { LogHelper } from '@/helpers/log-helper'
|
||||
import { LoaderHelper } from '@/helpers/loader-helper'
|
||||
import { SystemHelper } from '@/helpers/system-helper'
|
||||
|
||||
/**
|
||||
@ -56,6 +57,8 @@ BUILD_TARGETS.set('tcp-server', {
|
||||
dotVenvPath: path.join(PYTHON_TCP_SERVER_SRC_PATH, '.venv')
|
||||
})
|
||||
;(async () => {
|
||||
LoaderHelper.start()
|
||||
|
||||
const { argv } = process
|
||||
const givenBuildTarget = argv[2].toLowerCase()
|
||||
|
||||
|
83
server/src/core/tts/synthesizers/local-synthesizer.ts
Normal file
83
server/src/core/tts/synthesizers/local-synthesizer.ts
Normal file
@ -0,0 +1,83 @@
|
||||
import fs from 'node:fs'
|
||||
|
||||
import type { LongLanguageCode } from '@/types'
|
||||
import type { SynthesizeResult } from '@/core/tts/types'
|
||||
import { LANG } from '@/constants'
|
||||
import { PYTHON_TCP_CLIENT, SOCKET_SERVER, TTS } from '@/core'
|
||||
import { TTSSynthesizerBase } from '@/core/tts/tts-synthesizer-base'
|
||||
import { LogHelper } from '@/helpers/log-helper'
|
||||
|
||||
interface ChunkData {
|
||||
outputPath: string
|
||||
audioId: string
|
||||
}
|
||||
|
||||
export default class LocalSynthesizer extends TTSSynthesizerBase {
|
||||
protected readonly name = 'Local TTS Synthesizer'
|
||||
protected readonly lang = LANG as LongLanguageCode
|
||||
|
||||
constructor(lang: LongLanguageCode) {
|
||||
super()
|
||||
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.success('New instance')
|
||||
|
||||
try {
|
||||
this.lang = lang
|
||||
|
||||
LogHelper.success('Synthesizer initialized')
|
||||
} catch (e) {
|
||||
LogHelper.error(`${this.name} - Failed to initialize: ${e}`)
|
||||
}
|
||||
}
|
||||
|
||||
public async synthesize(speech: string): Promise<SynthesizeResult | null> {
|
||||
const eventName = 'tts-receiving-stream'
|
||||
const eventHasListeners = PYTHON_TCP_CLIENT.ee.listenerCount(eventName) > 0
|
||||
|
||||
if (!eventHasListeners) {
|
||||
PYTHON_TCP_CLIENT.ee.on('tts-audio-streaming', (data: ChunkData) => {
|
||||
/**
|
||||
* Send audio stream chunk by chunk to the client as long as
|
||||
* the temporary file is being written from the TCP server
|
||||
*/
|
||||
const { outputPath, audioId } = data
|
||||
const stream = fs.createReadStream(outputPath)
|
||||
const chunks: Buffer[] = []
|
||||
stream.on('data', (chunk: Buffer) => {
|
||||
chunks.push(chunk)
|
||||
// SOCKET_SERVER.socket?.emit('tts-stream', { chunk, audioId })
|
||||
})
|
||||
stream.on('end', async () => {
|
||||
const completeStream = Buffer.concat(chunks)
|
||||
|
||||
SOCKET_SERVER.socket?.emit('tts-stream', {
|
||||
chunk: completeStream,
|
||||
audioId
|
||||
})
|
||||
|
||||
try {
|
||||
const duration = await this.getAudioDuration(outputPath)
|
||||
TTS.em.emit('saved', duration)
|
||||
} catch (e) {
|
||||
LogHelper.title(this.name)
|
||||
LogHelper.warning(`Failed to get audio duration: ${e}`)
|
||||
}
|
||||
try {
|
||||
fs.unlinkSync(outputPath)
|
||||
} catch (e) {
|
||||
LogHelper.warning(`Failed to delete tmp audio file: ${e}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: support mood to control speed and pitch
|
||||
PYTHON_TCP_CLIENT.emit('tts-synthesize', speech)
|
||||
|
||||
return {
|
||||
audioFilePath: '',
|
||||
duration: 500
|
||||
}
|
||||
}
|
||||
}
|
@ -16,6 +16,7 @@ interface Speech {
|
||||
}
|
||||
|
||||
const PROVIDERS_MAP = {
|
||||
[TTSProviders.Local]: TTSSynthesizers.Local,
|
||||
[TTSProviders.GoogleCloudTTS]: TTSSynthesizers.GoogleCloudTTS,
|
||||
[TTSProviders.WatsonTTS]: TTSSynthesizers.WatsonTTS,
|
||||
[TTSProviders.AmazonPolly]: TTSSynthesizers.AmazonPolly,
|
||||
@ -104,6 +105,11 @@ export default class TTS {
|
||||
if (this.synthesizer) {
|
||||
const result = await this.synthesizer.synthesize(speech.text)
|
||||
|
||||
// Support custom TTS providers such as the local synthesizer
|
||||
if (result?.audioFilePath === '') {
|
||||
return
|
||||
}
|
||||
|
||||
if (!result) {
|
||||
LogHelper.error(
|
||||
'The TTS synthesizer failed to synthesize the speech as the result is null'
|
||||
@ -120,7 +126,7 @@ export default class TTS {
|
||||
duration
|
||||
},
|
||||
(confirmation: string) => {
|
||||
if (confirmation === 'audio-received') {
|
||||
if (confirmation === 'audio-received' && audioFilePath !== '') {
|
||||
fs.unlinkSync(audioFilePath)
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import type GoogleCloudTTSSynthesizer from '@/core/tts/synthesizers/google-cloud
|
||||
import type WatsonTTSSynthesizer from '@/core/tts/synthesizers/watson-tts-synthesizer'
|
||||
|
||||
export enum TTSProviders {
|
||||
Local = 'local',
|
||||
AmazonPolly = 'amazon-polly',
|
||||
GoogleCloudTTS = 'google-cloud-tts',
|
||||
WatsonTTS = 'watson-tts',
|
||||
@ -11,6 +12,7 @@ export enum TTSProviders {
|
||||
}
|
||||
|
||||
export enum TTSSynthesizers {
|
||||
Local = 'local-synthesizer',
|
||||
AmazonPolly = 'amazon-polly-synthesizer',
|
||||
GoogleCloudTTS = 'google-cloud-tts-synthesizer',
|
||||
WatsonTTS = 'watson-tts-synthesizer',
|
||||
|
@ -41,18 +41,23 @@ class TCPServer:
|
||||
ckpt_path=TTS_MODEL_PATH
|
||||
)
|
||||
|
||||
text = 'Hello, I am Leon. How can I help you?'
|
||||
speaker_ids = self.tts.hps.data.spk2id
|
||||
output_path = os.path.join(TMP_PATH, 'output.wav')
|
||||
speed = 1.0
|
||||
|
||||
self.tts.tts_to_file(text, speaker_ids['EN-Leon-V1'], output_path, speed=speed, quiet=True)
|
||||
|
||||
def init(self):
|
||||
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use
|
||||
self.tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.tcp_socket.bind((self.host, int(self.port)))
|
||||
self.tcp_socket.listen()
|
||||
try:
|
||||
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use
|
||||
self.tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.tcp_socket.bind((self.host, int(self.port)))
|
||||
self.tcp_socket.listen()
|
||||
except OSError as e:
|
||||
# If the port is already in use, close the connection and retry
|
||||
if 'Address already in use' in str(e):
|
||||
self.log(f'Port {self.port} is already in use. Disconnecting client and retrying...')
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
# Wait for a moment before retrying
|
||||
time.sleep(1)
|
||||
self.init()
|
||||
else:
|
||||
raise
|
||||
|
||||
while True:
|
||||
# Flush buffered output to make it IPC friendly (readable on stdout)
|
||||
@ -93,3 +98,39 @@ class TCPServer:
|
||||
'spacyEntities': entities
|
||||
}
|
||||
}
|
||||
|
||||
def tts_synthesize(self, speech: str) -> dict:
    """
    Synthesize the given speech into a temporary WAV file and return an
    IPC payload telling the Node.js side where the audio was written.

    :param speech: Text to synthesize
    :return: dict with topic 'tts-audio-streaming' and data holding the
             output file path and a unique audio id

    TODO:
    - Implement one speaker per style (joyful, sad, angry, tired, etc.)
    - Need to train a new model with default voice speaker and other speakers with different styles
    - EN-Leon-Joyful-V1; EN-Leon-Sad-V1; etc.
    """
    speaker_ids = self.tts.hps.data.spk2id
    # Random file name to avoid conflicts
    # (timestamp + 2 random bytes in hex keeps concurrent requests apart)
    audio_id = f'{int(time.time())}_{os.urandom(2).hex()}'
    output_file_name = f'{audio_id}.wav'
    output_path = os.path.join(TMP_PATH, output_file_name)
    # Speed < 1.0 stretches the audio (the model uses length_scale = 1/speed),
    # slowing delivery slightly — NOTE(review): presumably tuned by ear; confirm
    speed = 0.88

    # Turn " - " separators into periods so the sentence splitter inserts
    # natural pauses between the pieces
    formatted_speech = speech.replace(' - ', '.')
    # formatted_speech = speech.replace(',', '.').replace('.', '...')

    # TODO: should not wait to finish for streaming support
    self.tts.tts_to_file(
        formatted_speech,
        speaker_ids['EN-Leon-V1'],
        output_path=output_path,
        speed=speed,
        quiet=True,
        format='wav',
        stream=False
    )

    return {
        'topic': 'tts-audio-streaming',
        'data': {
            'outputPath': output_path,
            'audioId': audio_id
        }
    }
|
||||
|
@ -5,6 +5,8 @@ import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import time
|
||||
import wave
|
||||
import os
|
||||
|
||||
from . import utils
|
||||
from .models import SynthesizerTrn
|
||||
@ -88,11 +90,13 @@ class TTS(nn.Module):
|
||||
print(" > ===========================")
|
||||
return texts
|
||||
|
||||
def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
|
||||
def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False, stream=False):
|
||||
tic = time.perf_counter()
|
||||
self.log(f"Generating audio for:\n{text}")
|
||||
language = self.language
|
||||
|
||||
texts = self.split_sentences_into_pieces(text, language, quiet)
|
||||
|
||||
audio_list = []
|
||||
if pbar:
|
||||
tx = pbar(texts)
|
||||
@ -131,21 +135,50 @@ class TTS(nn.Module):
|
||||
length_scale=1. / speed,
|
||||
)[0][0, 0].data.cpu().float().numpy()
|
||||
del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
|
||||
#
|
||||
audio_list.append(audio)
|
||||
|
||||
# Save audio data chunk by chunk
|
||||
if stream:
|
||||
# Convert audio to 16-bit PCM format
|
||||
audio = (audio * 32767).astype(np.int16)
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
# If the file doesn't exist, create it and write the audio data to it
|
||||
with wave.open(output_path, 'wb') as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2) # 2 bytes for 16-bit PCM
|
||||
wf.setframerate(self.hps.data.sampling_rate)
|
||||
wf.writeframes(audio.tobytes())
|
||||
else:
|
||||
with wave.open(output_path, 'rb') as wf:
|
||||
params = wf.getparams()
|
||||
old_audio = np.frombuffer(wf.readframes(params.nframes), dtype=np.int16)
|
||||
new_audio = np.concatenate([old_audio, audio])
|
||||
|
||||
with wave.open(output_path, 'wb') as wf:
|
||||
wf.setparams(params)
|
||||
wf.writeframes(new_audio.tobytes())
|
||||
|
||||
time.sleep(2)
|
||||
self.log(f"Audio chunk saved")
|
||||
else:
|
||||
audio_list.append(audio)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
|
||||
|
||||
toc = time.perf_counter()
|
||||
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
|
||||
|
||||
if output_path is None:
|
||||
return audio
|
||||
else:
|
||||
if format:
|
||||
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
|
||||
else:
|
||||
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
|
||||
# Concatenate audio segments and save the entire audio to file
|
||||
if not stream:
|
||||
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
|
||||
|
||||
if output_path is None:
|
||||
return audio
|
||||
else:
|
||||
if format:
|
||||
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
|
||||
else:
|
||||
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
|
||||
|
||||
@staticmethod
|
||||
def log(*args, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user