1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-11-23 09:43:19 +03:00

feat(python tcp server): empty cache on MPS PyTorch backend for TTS

This commit is contained in:
louistiti 2024-06-19 08:59:20 +08:00
parent 5b176c90f5
commit 963660db19
2 changed files with 9 additions and 5 deletions

View File

@ -34,7 +34,7 @@ _<p align="center">Your open-source personal assistant.</p>_
> [!IMPORTANT]
> Due to all the new major changes coming to Leon AI, the development branch might be unstable. It is recommended to use the older version under the master branch.
>
>
> Please note that older versions do not make use of any foundation model, which will be introduced in upcoming versions.
**Outdated Documentation**

View File

@ -31,13 +31,14 @@ class TTS(nn.Module):
if torch.cuda.is_available():
device = 'cuda'
else:
self.log('GPU not available. CUDA is not installed?')
self.log('Using CUDA (Compute Unified Device Architecture)')
if torch.backends.mps.is_available():
device = 'mps'
self.log('Using MPS (Metal Performance Shaders)')
if 'cuda' in device:
assert torch.cuda.is_available()
if 'mps' in device:
assert torch.backends.mps.is_available()
self.log(f'Device: {device}')
@ -171,7 +172,10 @@ class TTS(nn.Module):
else:
audio_list.append(audio)
torch.cuda.empty_cache()
if self.device == 'cuda':
torch.cuda.empty_cache()
if self.device == 'mps':
torch.mps.empty_cache()
toc = time.perf_counter()
self.log(f"Time taken to generate audio: {toc - tic:0.4f} seconds")