1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-08-17 06:00:33 +03:00

feat(scripts): download and install TTS BERT model files for the Python TCP server env setup

This commit is contained in:
louistiti 2024-06-06 07:56:03 +08:00
parent 6df34d48ee
commit 7c9487b62b
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
2 changed files with 93 additions and 15 deletions

View File

@ -18,7 +18,9 @@ import {
PYTHON_TCP_SERVER_SRC_ASR_MODEL_PATH_FOR_GPU,
PYTHON_TCP_SERVER_SRC_ASR_MODEL_PATH_FOR_CPU,
PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH
PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH,
PYTHON_TCP_SERVER_TTS_BERT_FRENCH_MODEL_HF_PREFIX_DOWNLOAD_URL,
PYTHON_TCP_SERVER_TTS_BERT_BASE_MODEL_HF_PREFIX_DOWNLOAD_URL
} from '@/constants'
import { CPUArchitectures, OSTypes } from '@/types'
import { LogHelper } from '@/helpers/log-helper'
@ -419,6 +421,62 @@ SPACY_MODELS.set('fr', {
process.exit(1)
}
}
/**
 * Download every file listed in TTS_BERT_FRENCH_MODEL_FILES into
 * PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH.
 *
 * Files are fetched sequentially. On any failure the error is logged and the
 * process exits with code 1.
 *
 * @returns {Promise<void>}
 */
const installTTSBERTFrenchModel = async () => {
  try {
    LogHelper.info('Installing TTS BERT French model...')

    for (const modelFile of TTS_BERT_FRENCH_MODEL_FILES) {
      const modelInstallationFileURL = `${PYTHON_TCP_SERVER_TTS_BERT_FRENCH_MODEL_HF_PREFIX_DOWNLOAD_URL}/${modelFile}?download=true`
      const destPath = path.join(
        PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
        modelFile
      )

      LogHelper.info(`Downloading ${modelFile}...`)
      // Request first and open the destination file only once a response is
      // available, so a failed request does not leave an empty file on disk
      const response = await FileHelper.downloadFile(
        modelInstallationFileURL,
        'stream'
      )

      try {
        // pipeline() rejects when either side fails; a bare pipe() does not
        // forward source errors to the destination stream
        await stream.promises.pipeline(
          response.data,
          fs.createWriteStream(destPath)
        )
      } catch (e) {
        // A truncated file must not remain on disk, otherwise a later
        // existence check would treat the model as already downloaded.
        // Cleanup is best-effort: the download error is the one to report
        await fs.promises.rm(destPath, { force: true }).catch(() => {})
        throw e
      }

      LogHelper.success(`${modelFile} downloaded at ${destPath}`)
    }

    LogHelper.success('TTS BERT French model installed')
  } catch (e) {
    LogHelper.error(`Failed to install TTS BERT French model: ${e}`)
    process.exit(1)
  }
}
/**
 * Download every file listed in TTS_BERT_BASE_MODEL_FILES into
 * PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH.
 *
 * Files are fetched sequentially. On any failure the error is logged and the
 * process exits with code 1.
 *
 * @returns {Promise<void>}
 */
const installTTSBERTBaseModel = async () => {
  try {
    LogHelper.info('Installing TTS BERT base model...')

    for (const modelFile of TTS_BERT_BASE_MODEL_FILES) {
      const modelInstallationFileURL = `${PYTHON_TCP_SERVER_TTS_BERT_BASE_MODEL_HF_PREFIX_DOWNLOAD_URL}/${modelFile}?download=true`
      const destPath = path.join(
        PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH,
        modelFile
      )

      LogHelper.info(`Downloading ${modelFile}...`)
      // Request first and open the destination file only once a response is
      // available, so a failed request does not leave an empty file on disk
      const response = await FileHelper.downloadFile(
        modelInstallationFileURL,
        'stream'
      )

      try {
        // pipeline() rejects when either side fails; a bare pipe() does not
        // forward source errors to the destination stream
        await stream.promises.pipeline(
          response.data,
          fs.createWriteStream(destPath)
        )
      } catch (e) {
        // A truncated file must not remain on disk, otherwise a later
        // existence check would treat the model as already downloaded.
        // Cleanup is best-effort: the download error is the one to report
        await fs.promises.rm(destPath, { force: true }).catch(() => {})
        throw e
      }

      LogHelper.success(`${modelFile} downloaded at ${destPath}`)
    }

    LogHelper.success('TTS BERT base model installed')
  } catch (e) {
    LogHelper.error(`Failed to install TTS BERT base model: ${e}`)
    process.exit(1)
  }
}
LogHelper.info('Checking whether all spaCy models are installed...')
@ -444,28 +502,40 @@ SPACY_MODELS.set('fr', {
await installSpacyModels()
}
// Decide whether the TTS BERT base model files still have to be downloaded,
// and download them when they are missing.
// NOTE(review): this span looks like a rendered diff, so the pre-change and
// post-change forms of the same statements both appear — confirm against the
// real file before relying on it.
LogHelper.info('Checking whether TTS BERT base language model files are downloaded...')
// Only the last entry of TTS_BERT_BASE_MODEL_FILES is probed on disk; its
// presence stands in for the whole set.
const areTTSBERTBaseFilesDownloaded = fs.existsSync(path.join(
PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH,
TTS_BERT_BASE_MODEL_FILES[TTS_BERT_BASE_MODEL_FILES.length - 1]
))
LogHelper.info(
'Checking whether TTS BERT base language model files are downloaded...'
)
// Only the last entry of TTS_BERT_BASE_MODEL_FILES is probed on disk; its
// presence stands in for the whole set.
const areTTSBERTBaseFilesDownloaded = fs.existsSync(
path.join(
PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH,
TTS_BERT_BASE_MODEL_FILES[TTS_BERT_BASE_MODEL_FILES.length - 1]
)
)
if (!areTTSBERTBaseFilesDownloaded) {
LogHelper.info('TTS BERT base language model files not downloaded')
// TODO: download files
await installTTSBERTBaseModel()
} else {
LogHelper.success('TTS BERT base language model files are already downloaded')
LogHelper.success(
'TTS BERT base language model files are already downloaded'
)
}
// Decide whether the TTS BERT French model files still have to be downloaded,
// and download them when they are missing.
// NOTE(review): this span looks like a rendered diff, so the pre-change and
// post-change forms of the same statements both appear — confirm against the
// real file before relying on it.
LogHelper.info('Checking whether TTS BERT French language model files are downloaded...')
// Only the last entry of TTS_BERT_FRENCH_MODEL_FILES is probed on disk; its
// presence stands in for the whole set.
const areTTSBERTFrenchFilesDownloaded = fs.existsSync(path.join(
PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
TTS_BERT_FRENCH_MODEL_FILES[TTS_BERT_FRENCH_MODEL_FILES.length - 1]
))
LogHelper.info(
'Checking whether TTS BERT French language model files are downloaded...'
)
// Only the last entry of TTS_BERT_FRENCH_MODEL_FILES is probed on disk; its
// presence stands in for the whole set.
const areTTSBERTFrenchFilesDownloaded = fs.existsSync(
path.join(
PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
TTS_BERT_FRENCH_MODEL_FILES[TTS_BERT_FRENCH_MODEL_FILES.length - 1]
)
)
if (!areTTSBERTFrenchFilesDownloaded) {
LogHelper.info('TTS BERT French language model files not downloaded')
// TODO: download files
await installTTSBERTFrenchModel()
} else {
LogHelper.success('TTS BERT French language model files are already downloaded')
LogHelper.success(
'TTS BERT French language model files are already downloaded'
)
}
LogHelper.info('Checking whether the TTS model is installed...')

View File

@ -92,6 +92,14 @@ export const PYTHON_TCP_SERVER_ASR_MODEL_CPU_HF_PREFIX_DOWNLOAD_URL =
NetworkHelper.setHuggingFaceURL(
'https://huggingface.co/Systran/faster-whisper-medium/resolve/main'
)
// URL prefix for the TTS BERT French model files; a file name is appended to
// it to build each download URL. The raw Hugging Face URL is passed through
// NetworkHelper.setHuggingFaceURL.
// NOTE(review): presumably that helper can swap in a mirror host — confirm.
export const PYTHON_TCP_SERVER_TTS_BERT_FRENCH_MODEL_HF_PREFIX_DOWNLOAD_URL =
NetworkHelper.setHuggingFaceURL(
'https://huggingface.co/dbmdz/bert-base-french-europeana-cased/resolve/main'
)
// URL prefix for the TTS BERT base model files; a file name is appended to
// it to build each download URL. The raw Hugging Face URL is passed through
// NetworkHelper.setHuggingFaceURL.
// NOTE(review): presumably that helper can swap in a mirror host — confirm.
export const PYTHON_TCP_SERVER_TTS_BERT_BASE_MODEL_HF_PREFIX_DOWNLOAD_URL =
NetworkHelper.setHuggingFaceURL(
'https://huggingface.co/google-bert/bert-base-uncased/resolve/main'
)
const NODEJS_BRIDGE_VERSION_FILE_PATH = path.join(
NODEJS_BRIDGE_SRC_PATH,