diff --git a/scripts/setup/setup-python-dev-env.js b/scripts/setup/setup-python-dev-env.js index 7a4ee872..e804f196 100644 --- a/scripts/setup/setup-python-dev-env.js +++ b/scripts/setup/setup-python-dev-env.js @@ -130,6 +130,30 @@ SPACY_MODELS.set('fr', { const pipfileMtime = fs.statSync(pipfilePath).mtime const hasDotVenv = fs.existsSync(dotVenvPath) const { type: osType, cpuArchitecture } = SystemHelper.getInformation() + /** + * Install PyTorch nightly to support CUDA 12.4 + * as it is required by the latest NVIDIA drivers for CUDA runtime APIs + * + * @see https://stackoverflow.com/a/76972265/1768162 + */ + const installPytorch = async () => { + LogHelper.info('Installing PyTorch nightly with CUDA support...') + try { + await command( + 'pipenv run pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124', + { + shell: true, + stdio: 'inherit' + } + ) + LogHelper.success('PyTorch nightly with CUDA support installed') + } catch (e) { + LogHelper.error( + `Failed to install PyTorch nightly with CUDA support: ${e}` + ) + process.exit(1) + } + } const installPythonPackages = async () => { LogHelper.info(`Installing Python packages from ${pipfilePath}.lock...`) @@ -177,6 +201,8 @@ SPACY_MODELS.set('fr', { } LogHelper.success('Python packages installed') + + await installPytorch() } catch (e) { LogHelper.error(`Failed to install Python packages: ${e}`) diff --git a/tcp_server/src/Pipfile b/tcp_server/src/Pipfile index 07a46483..0b9a6585 100644 --- a/tcp_server/src/Pipfile +++ b/tcp_server/src/Pipfile @@ -22,7 +22,8 @@ spacy = "==3.5.4" geonamescache = "==1.6.0" # TCP server; TTS -torch = "==1.12.1" +# PyTorch is installed via the setup script +# torch = "*" # TTS transformers = "==4.27.4" diff --git a/tcp_server/src/lib/tts/utils.py b/tcp_server/src/lib/tts/utils.py index fb33f910..f69cbb98 100644 --- a/tcp_server/src/lib/tts/utils.py +++ b/tcp_server/src/lib/tts/utils.py @@ -32,7 +32,6 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None): ja_bert = torch.zeros(768, len(phone)) else: bert = get_bert(norm_text, word2ph, language_str, device) - print('bert', bert) del word2ph assert bert.shape[-1] == len(phone), phone diff --git a/tcp_server/src/lib/tts_to_delete/tts.py b/tcp_server/src/lib/tts_to_delete/tts.py deleted file mode 100644 index 8fb915b5..00000000 --- a/tcp_server/src/lib/tts_to_delete/tts.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -import torch - -from . import utils -from ..constants import TTS_MODEL_CONFIG_PATH, TTS_MODEL_PATH, IS_TTS_ENABLED - -class TTS: - def __init__(self): - self.hyper_params = None - self.device = 'auto' - self.num_languages = None - self.num_tones = None - self.symbols = None - - if not IS_TTS_ENABLED: - self.log('TTS is disabled') - return - - self.log('Loading model...') - - if not self.has_model_config(): - self.log(f'Model config not found at {TTS_MODEL_CONFIG_PATH}') - return - - if not self.is_model_downloaded(): - self.log(f'Model not found at {TTS_MODEL_PATH}') - return - - self.set_device() - self.hyper_params = utils.get_hparams_from_file(TTS_MODEL_CONFIG_PATH) - self.num_languages = self.hyper_params.num_languages - self.num_tones = self.hyper_params.num_tones - self.symbols = self.hyper_params.symbols - - model = SynthesizerTrn( - len(self.symbols), - self.hyper_params.data.filter_length // 2 + 1, - self.hyper_params.train.segment_size // self.hyper_params.data.hop_length, - n_speakers=self.hyper_params.data.n_speakers, - num_tones=self.num_tones, - num_languages=self.num_languages, - **self.hyper_params.model, - ).to(self.device) - model.eval() - self.model = model - - self.log('Model loaded') - - def set_device(self): - if self.device == 'auto': - self.device = 'cpu' - - if torch.cuda.is_available(): - self.device = 'cuda' - else: - self.log('GPU not available. CUDA is not installed?') - - if torch.backends.mps.is_available(): - self.device = 'mps' - if 'cuda' in self.device: - assert torch.cuda.is_available() - - self.log(f'Device: {self.device}') - - @staticmethod - def is_model_downloaded(): - return os.path.exists(TTS_MODEL_PATH) - - @staticmethod - def has_model_config(): - return os.path.exists(TTS_MODEL_CONFIG_PATH) - - @staticmethod - def log(*args, **kwargs): - print('[TTS]', *args, **kwargs) diff --git a/tcp_server/src/lib/tts_to_delete/utils.py b/tcp_server/src/lib/tts_to_delete/utils.py deleted file mode 100644 index 6f8263d8..00000000 --- a/tcp_server/src/lib/tts_to_delete/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import json - -def get_hparams_from_file(config_path): - with open(config_path, "r", encoding="utf-8") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - return hparams - -class HParams: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if type(v) == dict: - v = HParams(**v) - self[k] = v - - def keys(self): - return self.__dict__.keys() - - def items(self): - return self.__dict__.items() - - def values(self): - return self.__dict__.values() - - def __len__(self): - return len(self.__dict__) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - return setattr(self, key, value) - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - return self.__dict__.__repr__()