1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-08-17 06:00:33 +03:00

feat(scripts): TCP server setup add PyTorch nightly install

This commit is contained in:
louistiti 2024-05-17 12:20:34 +08:00
parent 85af31b614
commit e455a9d96b
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
5 changed files with 28 additions and 117 deletions

View File

@ -130,6 +130,30 @@ SPACY_MODELS.set('fr', {
const pipfileMtime = fs.statSync(pipfilePath).mtime
const hasDotVenv = fs.existsSync(dotVenvPath)
const { type: osType, cpuArchitecture } = SystemHelper.getInformation()
/**
* Install PyTorch nightly to support CUDA 12.4
* as it is required by the latest NVIDIA drivers for CUDA runtime APIs
*
* @see https://stackoverflow.com/a/76972265/1768162
*/
const installPytorch = async () => {
LogHelper.info('Installing PyTorch nightly with CUDA support...')
try {
await command(
'pipenv run pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124',
{
shell: true,
stdio: 'inherit'
}
)
LogHelper.success('PyTorch nightly with CUDA support installed')
} catch (e) {
LogHelper.error(
`Failed to install PyTorch nightly with CUDA support: ${e}`
)
process.exit(1)
}
}
const installPythonPackages = async () => {
LogHelper.info(`Installing Python packages from ${pipfilePath}.lock...`)
@ -177,6 +201,8 @@ SPACY_MODELS.set('fr', {
}
LogHelper.success('Python packages installed')
await installPytorch()
} catch (e) {
LogHelper.error(`Failed to install Python packages: ${e}`)

View File

@ -22,7 +22,8 @@ spacy = "==3.5.4"
geonamescache = "==1.6.0"
# TCP server; TTS
torch = "==1.12.1"
# PyTorch is installed via the setup script
# torch = "*"
# TTS
transformers = "==4.27.4"

View File

@ -32,7 +32,6 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
ja_bert = torch.zeros(768, len(phone))
else:
bert = get_bert(norm_text, word2ph, language_str, device)
print('bert', bert)
del word2ph
assert bert.shape[-1] == len(phone), phone

View File

@ -1,75 +0,0 @@
import os
import torch
from . import utils
from ..constants import TTS_MODEL_CONFIG_PATH, TTS_MODEL_PATH, IS_TTS_ENABLED
class TTS:
def __init__(self):
self.hyper_params = None
self.device = 'auto'
self.num_languages = None
self.num_tones = None
self.symbols = None
if not IS_TTS_ENABLED:
self.log('TTS is disabled')
return
self.log('Loading model...')
if not self.has_model_config():
self.log(f'Model config not found at {TTS_MODEL_CONFIG_PATH}')
return
if not self.is_model_downloaded():
self.log(f'Model not found at {TTS_MODEL_PATH}')
return
self.set_device()
self.hyper_params = utils.get_hparams_from_file(TTS_MODEL_CONFIG_PATH)
self.num_languages = self.hyper_params.num_languages
self.num_tones = self.hyper_params.num_tones
self.symbols = self.hyper_params.symbols
model = SynthesizerTrn(
len(self.symbols),
self.hyper_params.data.filter_length // 2 + 1,
self.hyper_params.train.segment_size // self.hyper_params.data.hop_length,
n_speakers=self.hyper_params.data.n_speakers,
num_tones=self.num_tones,
num_languages=self.num_languages,
**self.hyper_params.model,
).to(self.device)
model.eval()
self.model = model
self.log('Model loaded')
def set_device(self):
if self.device == 'auto':
self.device = 'cpu'
if torch.cuda.is_available():
self.device = 'cuda'
else:
self.log('GPU not available. CUDA is not installed?')
if torch.backends.mps.is_available():
self.device = 'mps'
if 'cuda' in self.device:
assert torch.cuda.is_available()
self.log(f'Device: {self.device}')
@staticmethod
def is_model_downloaded():
return os.path.exists(TTS_MODEL_PATH)
@staticmethod
def has_model_config():
return os.path.exists(TTS_MODEL_CONFIG_PATH)
@staticmethod
def log(*args, **kwargs):
print('[TTS]', *args, **kwargs)

View File

@ -1,40 +0,0 @@
import json
def get_hparams_from_file(config_path):
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()