1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-09-11 10:25:40 +03:00

fix(python tcp server): use CPU on macOS as a tmp fix against the memory leak for TTS

This commit is contained in:
louistiti 2024-06-20 22:24:25 +08:00
parent a8b3f7c9bd
commit a8ece30ced

View File

@ -11,11 +11,12 @@ import os
from . import utils
from .models import SynthesizerTrn
from .split_utils import split_sentence
from ..utils import is_macos
# torch.backends.cudnn.enabled = False
class TTS(nn.Module):
def __init__(self,
def __init__(self,
language,
device='auto',
use_hf=True,
@ -40,6 +41,13 @@ class TTS(nn.Module):
if 'mps' in device:
assert torch.backends.mps.is_available()
if is_macos():
"""
Temporary fix.
Force CPU device for macOS because of the memory leak where cache does not want to clear up on MPS
"""
device = 'cpu'
self.log(f'Device: {device}')
hps = utils.get_hparams_from_file(config_path)
@ -63,11 +71,11 @@ class TTS(nn.Module):
self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
self.hps = hps
self.device = device
# load state_dict
checkpoint_dict = torch.load(ckpt_path, map_location=device)
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
language = language.split('_')[0]
self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
@ -172,11 +180,6 @@ class TTS(nn.Module):
else:
audio_list.append(audio)
if self.device == 'cuda':
torch.cuda.empty_cache()
if self.device == 'mps':
torch.mps.empty_cache()
toc = time.perf_counter()
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")
@ -184,6 +187,12 @@ class TTS(nn.Module):
if not stream:
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
del audio_list
if self.device == 'cuda':
torch.cuda.empty_cache()
if self.device == 'mps':
torch.mps.empty_cache()
if output_path is None:
return audio
else: