mirror of
https://github.com/leon-ai/leon.git
synced 2024-11-27 16:16:48 +03:00
fix(python tcp server): use CPU on macOS as a tmp fix against the memory leak for TTS
This commit is contained in:
parent
a8b3f7c9bd
commit
a8ece30ced
@ -11,11 +11,12 @@ import os
|
||||
from . import utils
|
||||
from .models import SynthesizerTrn
|
||||
from .split_utils import split_sentence
|
||||
from ..utils import is_macos
|
||||
|
||||
# torch.backends.cudnn.enabled = False
|
||||
|
||||
class TTS(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
language,
|
||||
device='auto',
|
||||
use_hf=True,
|
||||
@ -40,6 +41,13 @@ class TTS(nn.Module):
|
||||
if 'mps' in device:
|
||||
assert torch.backends.mps.is_available()
|
||||
|
||||
if is_macos():
|
||||
"""
|
||||
Temporary fix.
|
||||
Force CPU device for macOS because of the memory leak where cache does not want to clear up on MPS
|
||||
"""
|
||||
device = 'cpu'
|
||||
|
||||
self.log(f'Device: {device}')
|
||||
|
||||
hps = utils.get_hparams_from_file(config_path)
|
||||
@ -63,11 +71,11 @@ class TTS(nn.Module):
|
||||
self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
self.hps = hps
|
||||
self.device = device
|
||||
|
||||
|
||||
# load state_dict
|
||||
checkpoint_dict = torch.load(ckpt_path, map_location=device)
|
||||
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
|
||||
|
||||
|
||||
language = language.split('_')[0]
|
||||
self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
|
||||
|
||||
@ -172,11 +180,6 @@ class TTS(nn.Module):
|
||||
else:
|
||||
audio_list.append(audio)
|
||||
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
if self.device == 'mps':
|
||||
torch.mps.empty_cache()
|
||||
|
||||
toc = time.perf_counter()
|
||||
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
|
||||
|
||||
@ -184,6 +187,12 @@ class TTS(nn.Module):
|
||||
if not stream:
|
||||
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
|
||||
|
||||
del audio_list
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
if self.device == 'mps':
|
||||
torch.mps.empty_cache()
|
||||
|
||||
if output_path is None:
|
||||
return audio
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user