mirror of
https://github.com/leon-ai/leon.git
synced 2024-10-26 18:18:46 +03:00
feat(python tcp server): force TTS BERT local files
This commit is contained in:
parent
7c9487b62b
commit
4ce470d3e7
3
.gitignore
vendored
3
.gitignore
vendored
@ -25,6 +25,9 @@ leon.json
|
||||
bridges/python/src/Pipfile.lock
|
||||
tcp_server/src/Pipfile.lock
|
||||
tcp_server/src/lib/tts/models/*.pth
|
||||
tcp_server/src/lib/tts/models/**/*.bin
|
||||
tcp_server/src/lib/tts/models/**/*.json
|
||||
tcp_server/src/lib/tts/models/**/*.txt
|
||||
tcp_server/src/lib/asr/models/**/*.bin
|
||||
!tcp_server/**/.gitkeep
|
||||
!bridges/python/**/.gitkeep
|
||||
|
@ -17,9 +17,9 @@ import {
|
||||
PYTHON_TCP_SERVER_ASR_MODEL_GPU_HF_PREFIX_DOWNLOAD_URL,
|
||||
PYTHON_TCP_SERVER_SRC_ASR_MODEL_PATH_FOR_GPU,
|
||||
PYTHON_TCP_SERVER_SRC_ASR_MODEL_PATH_FOR_CPU,
|
||||
PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
|
||||
// PYTHON_TCP_SERVER_SRC_TTS_BERT_FRENCH_DIR_PATH,
|
||||
PYTHON_TCP_SERVER_SRC_TTS_BERT_BASE_DIR_PATH,
|
||||
PYTHON_TCP_SERVER_TTS_BERT_FRENCH_MODEL_HF_PREFIX_DOWNLOAD_URL,
|
||||
// PYTHON_TCP_SERVER_TTS_BERT_FRENCH_MODEL_HF_PREFIX_DOWNLOAD_URL,
|
||||
PYTHON_TCP_SERVER_TTS_BERT_BASE_MODEL_HF_PREFIX_DOWNLOAD_URL
|
||||
} from '@/constants'
|
||||
import { CPUArchitectures, OSTypes } from '@/types'
|
||||
@ -67,12 +67,12 @@ const ASR_CPU_MODEL_FILES = [
|
||||
'tokenizer.json',
|
||||
'vocabulary.txt'
|
||||
]
|
||||
const TTS_BERT_FRENCH_MODEL_FILES = [
|
||||
'pytorch_model.bin',
|
||||
/*const TTS_BERT_FRENCH_MODEL_FILES = [
|
||||
'pytorch_model.bin', // Not needed? Compare with HF auto download in ~/.cache/huggingface/hub...
|
||||
'config.json',
|
||||
'vocab.txt',
|
||||
'tokenizer_config.json'
|
||||
]
|
||||
]*/
|
||||
const TTS_BERT_BASE_MODEL_FILES = [
|
||||
'pytorch_model.bin',
|
||||
'config.json',
|
||||
@ -421,7 +421,7 @@ SPACY_MODELS.set('fr', {
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
const installTTSBERTFrenchModel = async () => {
|
||||
/*const installTTSBERTFrenchModel = async () => {
|
||||
try {
|
||||
LogHelper.info('Installing TTS BERT French model...')
|
||||
|
||||
@ -448,7 +448,7 @@ SPACY_MODELS.set('fr', {
|
||||
LogHelper.error(`Failed to install TTS BERT French model: ${e}`)
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
}*/
|
||||
const installTTSBERTBaseModel = async () => {
|
||||
try {
|
||||
LogHelper.info('Installing TTS BERT base model...')
|
||||
@ -520,7 +520,8 @@ SPACY_MODELS.set('fr', {
|
||||
)
|
||||
}
|
||||
|
||||
LogHelper.info(
|
||||
// TODO: later when multiple languages are supported
|
||||
/*LogHelper.info(
|
||||
'Checking whether TTS BERT French language model files are downloaded...'
|
||||
)
|
||||
const areTTSBERTFrenchFilesDownloaded = fs.existsSync(
|
||||
@ -536,7 +537,7 @@ SPACY_MODELS.set('fr', {
|
||||
LogHelper.success(
|
||||
'TTS BERT French language model files are already downloaded'
|
||||
)
|
||||
}
|
||||
}*/
|
||||
|
||||
LogHelper.info('Checking whether the TTS model is installed...')
|
||||
const isTTSModelInstalled = fs.existsSync(
|
||||
|
@ -18,6 +18,8 @@ TTS_MODEL_NAME = f'EN-Leon-{TTS_MODEL_VERSION}-G_{TTS_MODEL_ITERATION}'
|
||||
TTS_MODEL_FILE_NAME = f'{TTS_MODEL_NAME}.pth'
|
||||
TTS_LIB_PATH = os.path.join(LIB_PATH, 'tts')
|
||||
TTS_MODEL_FOLDER_PATH = os.path.join(TTS_LIB_PATH, 'models')
|
||||
TTS_BERT_FRENCH_MODEL_DIR_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, 'bert-case-french-europeana-cased')
|
||||
TTS_BERT_BASE_MODEL_DIR_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, 'bert-base-uncased')
|
||||
TTS_MODEL_CONFIG_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, 'config.json')
|
||||
TTS_MODEL_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, TTS_MODEL_FILE_NAME)
|
||||
IS_TTS_ENABLED = os.environ.get('LEON_TTS', 'true') == 'true'
|
||||
|
@ -22,9 +22,8 @@ def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
|
||||
|
||||
def get_bert(norm_text, word2ph, language, device):
|
||||
from .english_bert import get_bert_feature as en_bert
|
||||
from .french_bert import get_bert_feature as fr_bert
|
||||
# from .french_bert import get_bert_feature as fr_bert
|
||||
|
||||
lang_bert_func_map = {"EN": en_bert,
|
||||
'FR': fr_bert}
|
||||
lang_bert_func_map = {"EN": en_bert}
|
||||
bert = lang_bert_func_map[language](norm_text, word2ph, device)
|
||||
return bert
|
||||
|
@ -1,9 +1,10 @@
|
||||
from . import english, french
|
||||
from . import english
|
||||
from . import cleaned_text_to_sequence
|
||||
import copy
|
||||
|
||||
language_module_map = {"EN": english,
|
||||
'FR': french}
|
||||
# language_module_map = {"EN": english,
|
||||
# 'FR': french}
|
||||
language_module_map = {"EN": english}
|
||||
|
||||
|
||||
def clean_text(text, language):
|
||||
|
@ -4,6 +4,9 @@ import re
|
||||
from g2p_en import G2p
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
from lib.constants import TTS_BERT_BASE_MODEL_DIR_PATH
|
||||
|
||||
from . import symbols
|
||||
|
||||
from .english_utils.abbreviations import expand_abbreviations
|
||||
@ -192,8 +195,12 @@ def text_normalize(text):
|
||||
text = expand_abbreviations(text)
|
||||
return text
|
||||
|
||||
model_id = 'bert-base-uncased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
load_model_params = {
|
||||
"pretrained_model_name_or_path": TTS_BERT_BASE_MODEL_DIR_PATH,
|
||||
"local_files_only": True
|
||||
}
|
||||
tokenizer = AutoTokenizer.from_pretrained(**load_model_params)
|
||||
|
||||
def g2p_old(text):
|
||||
tokenized = tokenizer.tokenize(text)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
@ -2,9 +2,13 @@ import torch
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
||||
import sys
|
||||
|
||||
from lib.constants import TTS_BERT_BASE_MODEL_DIR_PATH
|
||||
|
||||
model_id = 'bert-base-uncased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
load_model_params = {
|
||||
"pretrained_model_name_or_path": TTS_BERT_BASE_MODEL_DIR_PATH,
|
||||
"local_files_only": True
|
||||
}
|
||||
tokenizer = AutoTokenizer.from_pretrained(**load_model_params)
|
||||
model = None
|
||||
|
||||
def get_bert_feature(text, word2ph, device=None):
|
||||
@ -18,7 +22,7 @@ def get_bert_feature(text, word2ph, device=None):
|
||||
if not device:
|
||||
device = "cuda"
|
||||
if model is None:
|
||||
model = AutoModelForMaskedLM.from_pretrained(model_id).to(
|
||||
model = AutoModelForMaskedLM.from_pretrained(**load_model_params).to(
|
||||
device
|
||||
)
|
||||
with torch.no_grad():
|
||||
|
@ -1,5 +1,7 @@
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lib.constants import TTS_BERT_FRENCH_MODEL_DIR_PATH
|
||||
|
||||
from .fr_phonemizer import cleaner as fr_cleaner
|
||||
from .fr_phonemizer import fr_to_ipa
|
||||
|
||||
@ -16,8 +18,11 @@ def text_normalize(text):
|
||||
text = fr_cleaner.french_cleaners(text)
|
||||
return text
|
||||
|
||||
model_id = 'dbmdz/bert-base-french-europeana-cased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
load_model_params = {
|
||||
"pretrained_model_name_or_path": 'dbmdz/bert-base-french-europeana-cased',
|
||||
"local_files_only": True
|
||||
}
|
||||
tokenizer = AutoTokenizer.from_pretrained(**load_model_params)
|
||||
|
||||
def g2p(text, pad_start_end=True, tokenized=None):
|
||||
if tokenized is None:
|
||||
|
@ -2,8 +2,13 @@ import torch
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
||||
import sys
|
||||
|
||||
model_id = 'dbmdz/bert-base-french-europeana-cased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
from lib.constants import TTS_BERT_FRENCH_MODEL_DIR_PATH
|
||||
|
||||
load_model_params = {
|
||||
"pretrained_model_name_or_path": TTS_BERT_FRENCH_MODEL_DIR_PATH,
|
||||
"local_files_only": True
|
||||
}
|
||||
tokenizer = AutoTokenizer.from_pretrained(**load_model_params)
|
||||
model = None
|
||||
|
||||
def get_bert_feature(text, word2ph, device=None):
|
||||
@ -17,7 +22,7 @@ def get_bert_feature(text, word2ph, device=None):
|
||||
if not device:
|
||||
device = "cuda"
|
||||
if model is None:
|
||||
model = AutoModelForMaskedLM.from_pretrained(model_id).to(
|
||||
model = AutoModelForMaskedLM.from_pretrained(**load_model_params).to(
|
||||
device
|
||||
)
|
||||
with torch.no_grad():
|
||||
|
Loading…
Reference in New Issue
Block a user