1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-11-23 20:12:08 +03:00

feat(python tcp server): TTS tmp inference

This commit is contained in:
louistiti 2024-05-16 12:03:24 +08:00
parent 0e959412d5
commit 85af31b614
No known key found for this signature in database
GPG Key ID: 92CD6A2E497E1669
47 changed files with 464 additions and 422 deletions

2
.gitignore vendored
View File

@ -24,6 +24,7 @@ debug.log
leon.json
bridges/python/src/Pipfile.lock
tcp_server/src/Pipfile.lock
tcp_server/src/lib/tts/models/*.pth
!tcp_server/**/.gitkeep
!bridges/python/**/.gitkeep
!bridges/nodejs/**/.gitkeep
@ -32,7 +33,6 @@ tcp_server/src/Pipfile.lock
skills/**/src/settings.json
skills/**/memory/*.json
core/data/models/*.nlp
core/data/models/tts/*.pth
core/data/models/llm/*
package.json.backup
.python-version

View File

@ -1,16 +1,13 @@
import os
MODELS_PATH = os.path.join(
os.getcwd(),
'core',
'data',
'models'
)
SRC_PATH = os.path.join(os.getcwd(), 'tcp_server', 'src')
# TTS
TTS_MODEL_VERSION = 'V1'
TTS_MODEL_NAME = f'EN-Leon-{TTS_MODEL_VERSION}'
TTS_MODEL_FILE_NAME = f'{TTS_MODEL_NAME}.pth'
TTS_MODEL_FOLDER_PATH = os.path.join(MODELS_PATH, 'tts')
TTS_MODEL_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, TTS_MODEL_FILE_NAME)
TTS_LIB_PATH = os.path.join(SRC_PATH, 'lib', 'tts')
TTS_MODEL_FOLDER_PATH = os.path.join(TTS_LIB_PATH, 'models')
TTS_MODEL_CONFIG_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, 'config.json')
TTS_MODEL_PATH = os.path.join(TTS_MODEL_FOLDER_PATH, TTS_MODEL_FILE_NAME)
IS_TTS_ENABLED = os.environ.get('LEON_TTS', 'true') == 'true'

View File

@ -1,9 +1,11 @@
import socket
import json
import os
from typing import Union
import lib.nlp as nlp
from .tts.tts import TTS
from .tts.api import TTS
from .constants import TTS_MODEL_CONFIG_PATH, TTS_MODEL_PATH, IS_TTS_ENABLED
class TCPServer:
@ -13,12 +15,40 @@ class TCPServer:
self.tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn = None
self.addr = None
self.tts = TTS()
self.tts = None
@staticmethod
def log(*args, **kwargs):
print('[TCP Server]', *args, **kwargs)
def init_tts(self):
print('IS_TTS_ENABLED', IS_TTS_ENABLED)
# TODO: FIX IT
if not IS_TTS_ENABLED:
self.log('TTS is disabled')
return
if not os.path.exists(TTS_MODEL_CONFIG_PATH):
self.log(f'TTS model config not found at {TTS_MODEL_CONFIG_PATH}')
return
if not os.path.exists(TTS_MODEL_PATH):
self.log(f'TTS model not found at {TTS_MODEL_PATH}')
return
self.tts = TTS(language='EN',
device='auto',
config_path=TTS_MODEL_CONFIG_PATH,
ckpt_path=TTS_MODEL_PATH
)
text = 'Hello, I am Leon. How can I help you?'
speaker_ids = self.tts.hps.data.spk2id
output_path = 'output.wav'
speed = 1.0
self.tts.tts_to_file(text, speaker_ids['EN-Leon-V1'], output_path, speed=speed)
def init(self):
# Make sure to establish TCP connection by reusing the address so it does not conflict with port already in use
self.tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

View File

@ -17,14 +17,21 @@ class TTS(nn.Module):
config_path=None,
ckpt_path=None):
super().__init__()
self.log('Loading model...')
if device == 'auto':
device = 'cpu'
if torch.cuda.is_available(): device = 'cuda'
else: self.log('GPU not available. CUDA is not installed?')
if torch.backends.mps.is_available(): device = 'mps'
if 'cuda' in device:
assert torch.cuda.is_available()
# config_path =
self.log(f'Device: {device}')
hps = utils.get_hparams_from_file(config_path)
num_languages = hps.num_languages
@ -54,6 +61,8 @@ class TTS(nn.Module):
language = language.split('_')[0]
self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
self.log('Model loaded')
@staticmethod
def audio_numpy_concat(segment_data_list, sr, speed=1.):
audio_segments = []
@ -125,3 +134,8 @@ class TTS(nn.Module):
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
@staticmethod
def log(*args, **kwargs):
print('[TTS]', *args, **kwargs)

View File

@ -3,15 +3,15 @@ import torch
from torch import nn
from torch.nn import functional as F
from melo import commons
from melo import modules
from melo import attentions
from lib.tts import commons
from lib.tts import modules
from lib.tts import attentions
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from melo.commons import init_weights, get_padding
import melo.monotonic_align as monotonic_align
from lib.tts.commons import init_weights, get_padding
import lib.tts.monotonic_align as monotonic_align
class DurationDiscriminator(nn.Module): # vits2

View File

@ -1,4 +1,328 @@
import os
import glob
import argparse
import logging
import json
import subprocess
import torch
from lib.tts.text import cleaned_text_to_sequence, get_bert
from lib.tts.text.cleaner import clean_text
from lib.tts import commons
MATPLOTLIB_FLAG = False
logger = logging.getLogger(__name__)
def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
if hps.data.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
if getattr(hps.data, "disable_bert", False):
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
else:
bert = get_bert(norm_text, word2ph, language_str, device)
print('bert', bert)
del word2ph
assert bert.shape[-1] == len(phone), phone
if language_str == "ZH":
bert = bert
ja_bert = torch.zeros(768, len(phone))
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
raise NotImplementedError()
assert bert.shape[-1] == len(
phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, ja_bert, phone, tone, language
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict.get("iteration", 0)
learning_rate = checkpoint_dict.get("learning_rate", 0.)
if (
optimizer is not None
and not skip_optimizer
and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
elif optimizer is None and not skip_optimizer:
# else: Disable this line if Infer and resume checkpoint,then enable the line upper
new_opt_dict = optimizer.state_dict()
new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
optimizer.load_state_dict(new_opt_dict)
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
try:
# assert "emb_g" not in k
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except Exception as e:
print(e)
# For upgrading from the old version
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
logger.warn(
f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
)
else:
logger.error(f"{k} is not in the checkpoint")
new_state_dict[k] = v
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
logger.info(
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
x = f_list[-1]
return x
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--config",
type=str,
default="./configs/base.json",
help="JSON file for configuration",
)
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--world-size', type=int, default=1)
parser.add_argument('--port', type=int, default=10000)
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
parser.add_argument('--pretrain_G', type=str, default=None,
help='pretrain model')
parser.add_argument('--pretrain_D', type=str, default=None,
help='pretrain model D')
parser.add_argument('--pretrain_dur', type=str, default=None,
help='pretrain model duration')
args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)
os.makedirs(model_dir, exist_ok=True)
config_path = args.config
config_save_path = os.path.join(model_dir, "config.json")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
hparams.pretrain_G = args.pretrain_G
hparams.pretrain_D = args.pretrain_D
hparams.pretrain_dur = args.pretrain_dur
hparams.port = args.port
return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
import re
ckpts_files = [
f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
def name_key(_f):
return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
def time_key(_f):
return os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
def x_sorted(_x):
return sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn)
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
def del_info(fn):
return logger.info(f".. Free up space by deleting ckpt {fn}")
def del_routine(x):
return [os.remove(x), del_info(x)]
[del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
with open(config_path, "r", encoding="utf-8") as f:
@ -8,6 +332,47 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
h = logging.FileHandler(os.path.join(model_dir, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():

View File

@ -1,405 +0,0 @@
import os
import glob
import argparse
import logging
import json
import subprocess
import torch
from melo.text import cleaned_text_to_sequence, get_bert
from melo.text.cleaner import clean_text
from melo import commons
MATPLOTLIB_FLAG = False
logger = logging.getLogger(__name__)
def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
if hps.data.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
if getattr(hps.data, "disable_bert", False):
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
else:
bert = get_bert(norm_text, word2ph, language_str, device)
print('bert', bert)
del word2ph
assert bert.shape[-1] == len(phone), phone
if language_str == "ZH":
bert = bert
ja_bert = torch.zeros(768, len(phone))
elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
raise NotImplementedError()
assert bert.shape[-1] == len(
phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, ja_bert, phone, tone, language
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict.get("iteration", 0)
learning_rate = checkpoint_dict.get("learning_rate", 0.)
if (
optimizer is not None
and not skip_optimizer
and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
elif optimizer is None and not skip_optimizer:
# else: Disable this line if Infer and resume checkpoint,then enable the line upper
new_opt_dict = optimizer.state_dict()
new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
optimizer.load_state_dict(new_opt_dict)
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
try:
# assert "emb_g" not in k
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except Exception as e:
print(e)
# For upgrading from the old version
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
logger.warn(
f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
)
else:
logger.error(f"{k} is not in the checkpoint")
new_state_dict[k] = v
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
logger.info(
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
x = f_list[-1]
return x
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--config",
type=str,
default="./configs/base.json",
help="JSON file for configuration",
)
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--world-size', type=int, default=1)
parser.add_argument('--port', type=int, default=10000)
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
parser.add_argument('--pretrain_G', type=str, default=None,
help='pretrain model')
parser.add_argument('--pretrain_D', type=str, default=None,
help='pretrain model D')
parser.add_argument('--pretrain_dur', type=str, default=None,
help='pretrain model duration')
args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)
os.makedirs(model_dir, exist_ok=True)
config_path = args.config
config_save_path = os.path.join(model_dir, "config.json")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
hparams.pretrain_G = args.pretrain_G
hparams.pretrain_D = args.pretrain_D
hparams.pretrain_dur = args.pretrain_dur
hparams.port = args.port
return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
import re
ckpts_files = [
f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
def name_key(_f):
return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
def time_key(_f):
return os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
def x_sorted(_x):
return sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn)
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
def del_info(fn):
return logger.info(f".. Free up space by deleting ckpt {fn}")
def del_routine(x):
return [os.remove(x), del_info(x)]
[del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
h = logging.FileHandler(os.path.join(model_dir, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()

View File

@ -0,0 +1,40 @@
import json
def get_hparams_from_file(config_path):
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()

View File

@ -14,4 +14,5 @@ tcp_server_host = os.environ.get('LEON_PY_TCP_SERVER_HOST', '0.0.0.0')
tcp_server_port = os.environ.get('LEON_PY_TCP_SERVER_PORT', 1342)
tcp_server = TCPServer(tcp_server_host, tcp_server_port)
tcp_server.init_tts()
tcp_server.init()