Merge branch 'develop' into csmsc-tts0

pull/3905/head
张春乔 9 months ago committed by GitHub
commit e137dbdbc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,96 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 8
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: 'data/lang_char/bpe_bpe_11297'
unit_type: 'spm'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 20.0
window_ms: 30.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 100
accum_grad: 4
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -15,12 +15,15 @@ import argparse
import os
import numpy as np
import paddle
from paddle import inference
from paddle.audio.datasets import ESC50
from paddle.audio.features import LogMelSpectrogram
from paddleaudio.backends import soundfile_load as load_audio
from scipy.special import softmax
import paddlespeech.utils
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
@ -56,7 +59,6 @@ def extract_features(files: str, **kwargs):
feature_extractor = LogMelSpectrogram(sr, **kwargs)
feat = feature_extractor(paddle.to_tensor(waveforms[i]))
feat = paddle.transpose(feat, perm=[1, 0]).unsqueeze(0)
feats.append(feat)
return np.stack(feats, axis=0)
@ -73,13 +75,18 @@ class Predictor(object):
enable_mkldnn=False):
self.batch_size = batch_size
model_file = os.path.join(model_dir, "inference.pdmodel")
params_file = os.path.join(model_dir, "inference.pdiparams")
if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
config = inference.Config(model_dir, 'inference')
config.disable_mkldnn()
else:
model_file = os.path.join(model_dir, 'inference.pdmodel')
params_file = os.path.join(model_dir, "inference.pdiparams")
assert os.path.isfile(model_file) and os.path.isfile(
params_file), 'Please check model and parameter files.'
assert os.path.isfile(model_file) and os.path.isfile(
params_file), 'Please check model and parameter files.'
config = inference.Config(model_file, params_file)
config = inference.Config(model_file, params_file)
if device == "gpu":
# set GPU configs accordingly
# such as intialize the gpu memory, enable tensorrt

@ -39,7 +39,8 @@ if __name__ == '__main__':
input_spec=[
paddle.static.InputSpec(
shape=[None, None, 64], dtype=paddle.float32)
])
],
full_graph=True)
# Save in static graph model.
paddle.jit.save(model, os.path.join(args.output_dir, "inference"))

@ -37,8 +37,6 @@ if __name__ == "__main__":
# save asr result to
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -1267,7 +1267,7 @@ class TransposeLast(nn.Layer):
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
trans_dim = paddle.arange(x.dim())
trans_dim = np.arange(x.dim())
trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
return x.transpose(trans_dim)

@ -18,6 +18,7 @@ from pathlib import Path
import soundfile as sf
from paddle import inference
import paddlespeech.utils
from paddlespeech.t2s.frontend.zh_frontend import Frontend
@ -48,16 +49,27 @@ def main():
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
print("frontend done!")
speedyspeech_config = inference.Config(
str(Path(args.inference_dir) / "speedyspeech.pdmodel"),
str(Path(args.inference_dir) / "speedyspeech.pdiparams"))
# after paddle 3.0, support new inference interface
if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
speedyspeech_config = inference.Config(
str(Path(args.inference_dir)), "speedyspeech")
else:
speedyspeech_config = inference.Config(
str(Path(args.inference_dir) / "speedyspeech.pdmodel"),
str(Path(args.inference_dir) / "speedyspeech.pdiparams"))
speedyspeech_config.enable_use_gpu(100, 0)
speedyspeech_config.enable_memory_optim()
speedyspeech_predictor = inference.create_predictor(speedyspeech_config)
pwg_config = inference.Config(
str(Path(args.inference_dir) / "pwg.pdmodel"),
str(Path(args.inference_dir) / "pwg.pdiparams"))
# after paddle 3.0, support new inference interface
if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
pwg_config = inference.Config(str(Path(args.inference_dir)), "pwg")
else:
pwg_config = inference.Config(
str(Path(args.inference_dir) / "pwg.pdmodel"),
str(Path(args.inference_dir) / "pwg.pdiparams"))
pwg_config.enable_use_gpu(100, 0)
pwg_config.enable_memory_optim()
pwg_predictor = inference.create_predictor(pwg_config)

@ -120,7 +120,11 @@ class SinusoidalPosEmb(nn.Layer):
self.dim = dim
def forward(self, x: paddle.Tensor):
x = paddle.cast(x, 'float32')
# check if x is 0-dim tensor, if so, add a dimension
if x.ndim == 0:
x = paddle.cast(x.unsqueeze(0), 'float32')
else:
x = paddle.cast(x, 'float32')
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = paddle.exp(paddle.arange(half_dim) * -emb)

@ -181,6 +181,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
# check if lengths is 0-dim tensor, if so, add a dimension
if lengths.ndim == 0:
bs = paddle.shape(lengths.unsqueeze(0))
else:

@ -11,3 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from packaging.version import Version
def satisfy_version(source: str, target: str, dev_allowed: bool=True) -> bool:
if dev_allowed and source.startswith('0.0.0'):
target_version = Version('0.0.0')
else:
target_version = Version(target)
source_version = Version(source)
return source_version >= target_version
def satisfy_paddle_version(target: str, dev_allowed: bool=True) -> bool:
import paddle
return satisfy_version(paddle.__version__, target, dev_allowed)

Loading…
Cancel
Save