From c09466ebbe32e61888a98f6de52144e527998005 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 11 Jan 2022 14:46:57 +0800 Subject: [PATCH 1/6] Add ECAPA_TDNN. (#1295) --- paddlespeech/vector/models/ecapa_tdnn.py | 417 +++++++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 paddlespeech/vector/models/ecapa_tdnn.py diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py new file mode 100644 index 00000000..5512f509 --- /dev/null +++ b/paddlespeech/vector/models/ecapa_tdnn.py @@ -0,0 +1,417 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def length_to_mask(length, max_len=None, dtype=None): + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().astype( + 'int').item() # using arange to generate mask + mask = paddle.arange( + max_len, dtype=length.dtype).expand( + (len(length), max_len)) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + mask = paddle.to_tensor(mask, dtype=dtype) + return mask + + +class Conv1d(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding="same", + dilation=1, + groups=1, + bias=True, + padding_mode="reflect", ): + super(Conv1d, self).__init__() + + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + + self.conv = nn.Conv1D( + in_channels, + out_channels, + self.kernel_size, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=groups, + bias_attr=bias, ) + + def forward(self, x): + if self.padding == "same": + x = self._manage_padding(x, self.kernel_size, self.dilation, + self.stride) + else: + raise ValueError("Padding must be 'same'. 
Got {self.padding}") + + return self.conv(x) + + def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + L_in = x.shape[-1] # Detecting input shape + padding = self._get_padding_elem(L_in, stride, kernel_size, + dilation) # Time padding + x = F.pad( + x, padding, mode=self.padding_mode, + data_format="NCL") # Applying padding + return x + + def _get_padding_elem(self, + L_in: int, + stride: int, + kernel_size: int, + dilation: int): + if stride > 1: + n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) + L_out = stride * (n_steps - 1) + kernel_size * dilation + padding = [kernel_size // 2, kernel_size // 2] + else: + L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1 + + padding = [(L_in - L_out) // 2, (L_in - L_out) // 2] + + return padding + + +class BatchNorm1d(nn.Layer): + def __init__( + self, + input_size, + eps=1e-05, + momentum=0.9, + weight_attr=None, + bias_attr=None, + data_format='NCL', + use_global_stats=None, ): + super(BatchNorm1d, self).__init__() + + self.norm = nn.BatchNorm1D( + input_size, + epsilon=eps, + momentum=momentum, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format, + use_global_stats=use_global_stats, ) + + def forward(self, x): + x_n = self.norm(x) + return x_n + + +class TDNNBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + activation=nn.ReLU, ): + super(TDNNBlock, self).__init__() + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, ) + self.activation = activation() + self.norm = BatchNorm1d(input_size=out_channels) + + def forward(self, x): + return self.norm(self.activation(self.conv(x))) + + +class Res2NetBlock(nn.Layer): + def __init__(self, in_channels, out_channels, scale=8, dilation=1): + super(Res2NetBlock, self).__init__() + assert in_channels % scale == 0 + assert out_channels % scale == 0 + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.LayerList([ + TDNNBlock( + in_channel, hidden_channel, kernel_size=3, dilation=dilation) + for i in range(scale - 1) + ]) + self.scale = scale + + def forward(self, x): + y = [] + for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)): + if i == 0: + y_i = x_i + elif i == 1: + y_i = self.blocks[i - 1](x_i) + else: + y_i = self.blocks[i - 1](x_i + y_i) + y.append(y_i) + y = paddle.concat(y, axis=1) + return y + + +class SEBlock(nn.Layer): + def __init__(self, in_channels, se_channels, out_channels): + super(SEBlock, self).__init__() + + self.conv1 = Conv1d( + in_channels=in_channels, out_channels=se_channels, kernel_size=1) + self.relu = paddle.nn.ReLU() + self.conv2 = Conv1d( + in_channels=se_channels, out_channels=out_channels, kernel_size=1) + self.sigmoid = paddle.nn.Sigmoid() + + def forward(self, x, lengths=None): + L = x.shape[-1] + if lengths is not None: + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + total = mask.sum(axis=2, keepdim=True) + s = (x * mask).sum(axis=2, keepdim=True) / total + else: + s = x.mean(axis=2, keepdim=True) + + s = self.relu(self.conv1(s)) + s = self.sigmoid(self.conv2(s)) + + return s * x + + +class AttentiveStatisticsPooling(nn.Layer): + def __init__(self, channels, attention_channels=128, global_context=True): + super().__init__() + + self.eps = 1e-12 + self.global_context = global_context + if global_context: + self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) + else: + self.tdnn = 
TDNNBlock(channels, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1) + + def forward(self, x, lengths=None): + C, L = x.shape[1], x.shape[2] # KP: (N, C, L) + + def _compute_statistics(x, m, axis=2, eps=self.eps): + mean = (m * x).sum(axis) + std = paddle.sqrt( + (m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps)) + return mean, std + + if lengths is None: + lengths = paddle.ones([x.shape[0]]) + + # Make binary mask of shape [N, 1, L] + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + if self.global_context: + total = mask.sum(axis=2, keepdim=True).astype('float32') + mean, std = _compute_statistics(x, mask / total) + mean = mean.unsqueeze(2).tile((1, 1, L)) + std = std.unsqueeze(2).tile((1, 1, L)) + attn = paddle.concat([x, mean, std], axis=1) + else: + attn = x + + # Apply layers + attn = self.conv(self.tanh(self.tdnn(attn))) + + # Filter out zero-paddings + attn = paddle.where( + mask.tile((1, C, 1)) == 0, + paddle.ones_like(attn) * float("-inf"), attn) + + attn = F.softmax(attn, axis=2) + mean, std = _compute_statistics(x, attn) + + # Append mean and std of the batch + pooled_stats = paddle.concat((mean, std), axis=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SERes2NetBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + activation=nn.ReLU, ): + super(SERes2NetBlock, self).__init__() + self.out_channels = out_channels + self.tdnn1 = TDNNBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, + res2net_scale, dilation) + self.tdnn2 = TDNNBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, ) + self.se_block = SEBlock(out_channels, se_channels, out_channels) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, ) + + def forward(self, x, lengths=None): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.tdnn1(x) + x = self.res2net_block(x) + x = self.tdnn2(x) + x = self.se_block(x, lengths) + + return x + residual + + +class ECAPA_TDNN(nn.Layer): + def __init__( + self, + input_size, + lin_neurons=192, + activation=nn.ReLU, + channels=[512, 512, 512, 512, 1536], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=128, + res2net_scale=8, + se_channels=128, + global_context=True, ): + + super(ECAPA_TDNN, self).__init__() + assert len(channels) == len(kernel_sizes) + assert len(channels) == len(dilations) + self.channels = channels + self.blocks = nn.LayerList() + self.emb_size = lin_neurons + + # The initial TDNN layer + self.blocks.append( + TDNNBlock( + input_size, + channels[0], + kernel_sizes[0], + dilations[0], + activation, )) + + # SE-Res2Net layers + for i in range(1, len(channels) - 1): + self.blocks.append( + SERes2NetBlock( + channels[i - 1], + channels[i], + res2net_scale=res2net_scale, + se_channels=se_channels, + kernel_size=kernel_sizes[i], + dilation=dilations[i], + activation=activation, )) + + # Multi-layer feature aggregation + self.mfa = TDNNBlock( + channels[-1], + channels[-1], + 
kernel_sizes[-1],
+            dilations[-1],
+            activation, )
+
+        # Attentive Statistics Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context, )
+        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=self.emb_size,
+            kernel_size=1, )
+
+    def forward(self, x, lengths=None):
+        xl = []
+        for layer in self.blocks:
+            try:
+                x = layer(x, lengths=lengths)
+            except TypeError:
+                x = layer(x)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = paddle.concat(xl[1:], axis=1)
+        x = self.mfa(x)
+
+        # Attentive Statistics Pooling
+        x = self.asp(x, lengths=lengths)
+        x = self.asp_bn(x)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        return x
+
+
+class Classifier(nn.Layer):
+    def __init__(self, backbone, num_class, dtype=paddle.float32):
+        super(Classifier, self).__init__()
+        self.backbone = backbone
+        self.params = nn.ParameterList([
+            paddle.create_parameter(
+                shape=[num_class, self.backbone.emb_size], dtype=dtype)
+        ])
+
+    def forward(self, x):
+        emb = self.backbone(x.transpose([0, 2, 1])).transpose([0, 2, 1])
+        logits = F.linear(
+            F.normalize(emb.squeeze(1)),
+            F.normalize(self.params[0]).transpose([1, 0]))
+
+        return logits

From 010aa65b2b59f2578ea73114c38f26ba5c37a36c Mon Sep 17 00:00:00 2001
From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com>
Date: Tue, 11 Jan 2022 14:48:36 +0800
Subject: [PATCH 2/6] [cli] asr - support English, decode_method and unified config (#1297)

* fix config, test=asr
* fix config, test=doc_fix
* add en and decode_method for cli/asr, test=asr
* test=asr
* fix, test=doc_fix
---
 CHANGELOG.md                            |   9 ++
 paddlespeech/cli/asr/infer.py           |  93 ++++++------
 paddlespeech/s2t/frontend/normalizer.py |  13 +-
 utils/generate_infer_yaml.py            | 178 ++++++++++++++++++++++++
 4 files changed, 239 insertions(+), 54 deletions(-)
 create mode 100644 utils/generate_infer_yaml.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4dc68c6f..5ffe8098 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,2 +1,11 @@
 # Changelog
+
+Date: 2022-1-10, Author: Jackwaterveg.
+Add features to: CLI:
+   - Support English (librispeech/asr1/transformer).
+   - Support choosing `decode_method` for conformer and transformer models.
+   - Refactor the config, using the unified config.
+ - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1297 + +*** diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 8de96476..aa4e31d9 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -46,19 +46,29 @@ pretrained_models = { # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" "conformer_wenetspeech-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', 'md5': - '54e7a558a6e020c2f5fb224874943f97', + '76cb19ed857e6623856b7cd7ebbfeda4', 'cfg_path': - 'conf/conformer.yaml', + 'model.yaml', 'ckpt_path': 'exp/conformer/checkpoints/wenetspeech', }, + "transformer_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '2c667da24922aad391eacafe37bc1660', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, } model_alias = { - "ds2_offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", - "ds2_online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", + "deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", + "deepspeech2online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": "paddlespeech.s2t.models.u2:U2Model", @@ -85,7 +95,7 @@ class ASRExecutor(BaseExecutor): '--lang', type=str, default='zh', - help='Choose model language. zh or en') + help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]') self.parser.add_argument( "--sample_rate", type=int, @@ -97,6 +107,12 @@ class ASRExecutor(BaseExecutor): type=str, default=None, help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='attention_rescoring', + choices=['ctc_greedy_search', 'ctc_prefix_beam_search', 'attention', 'attention_rescoring'], + help='only support transformer and conformer model') self.parser.add_argument( '--ckpt_path', type=str, @@ -136,6 +152,7 @@ class ASRExecutor(BaseExecutor): lang: str='zh', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, + decode_method: str='attention_rescoring', ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. @@ -159,51 +176,36 @@ class ASRExecutor(BaseExecutor): else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") - res_path = os.path.dirname( + self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) #Init body. 
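A minimal sketch of the yacs pattern this block relies on, assuming only that
`model.yaml` is the unified config unpacked from the model archive (the path
is illustrative):

    from yacs.config import CfgNode
    config = CfgNode(new_allowed=True)    # accept keys not declared in advance
    config.merge_from_file("model.yaml")  # overlay the on-disk values
    with UpdateConfig(config):            # PaddleSpeech helper: defrost, edit, refreeze
        config.decode.decoding_method = "attention_rescoring"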
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) - self.config.decoding.decoding_method = "attention_rescoring" with UpdateConfig(self.config): - if "ds2_online" in model_type or "ds2_offline" in model_type: + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: from paddlespeech.s2t.io.collator import SpeechCollator - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.mean_std_filepath = os.path.join( - res_path, self.config.collator.cmvn_path) + self.vocab = self.config.vocab_filepath + self.config.decode.lang_model_path = os.path.join(res_path, self.config.decode.lang_model_path) self.collate_fn_test = SpeechCollator.from_config(self.config) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.collate_fn_test.feature_size - self.config.model.output_dim = self.text_feature.vocab_size + unit_type=self.config.unit_type, + vocab=self.vocab) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.augmentation_config = os.path.join( - res_path, self.config.collator.augmentation_config) - self.config.collator.spm_model_prefix = os.path.join( - res_path, self.config.collator.spm_model_prefix) + self.config.spm_model_prefix = os.path.join(self.res_path, self.config.spm_model_prefix) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.config.collator.feat_dim - self.config.model.output_dim = self.text_feature.vocab_size + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + self.config.decode.decoding_method = decode_method else: raise Exception("wrong type") - # Enter the path of model root - model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} model_class = dynamic_import(model_name, model_alias) - model_conf = self.config.model - logger.info(model_conf) + model_conf = self.config model = model_class.from_config(model_conf) self.model = model self.model.eval() @@ -222,7 +224,7 @@ class ASRExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction - if "ds2_online" in model_type or "ds2_offline" in model_type: + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] @@ -236,18 +238,7 @@ class ASRExecutor(BaseExecutor): elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: logger.info("get the preprocess conf") - preprocess_conf_file = self.config.collator.augmentation_config - # redirect the cmvn path - with io.open(preprocess_conf_file, encoding="utf-8") as f: - preprocess_conf = yaml.safe_load(f) - for idx, process in enumerate(preprocess_conf["process"]): - if process['type'] == "cmvn_json": - preprocess_conf["process"][idx][ - "cmvn_path"] = os.path.join( - self.res_path, - preprocess_conf["process"][idx]["cmvn_path"]) - break - 
logger.info(preprocess_conf)
+            preprocess_conf = self.config.preprocess_config
             preprocess_args = {"train": False}
             preprocessing = Transformation(preprocess_conf)
             logger.info("read the audio file")
@@ -289,10 +280,10 @@ class ASRExecutor(BaseExecutor):
         Model inference and result stored in self.output.
         """

-        cfg = self.config.decoding
+        cfg = self.config.decode
         audio = self._inputs["audio"]
         audio_len = self._inputs["audio_len"]
-        if "ds2_online" in model_type or "ds2_offline" in model_type:
+        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
@@ -414,12 +405,13 @@ class ASRExecutor(BaseExecutor):
         config = parser_args.config
         ckpt_path = parser_args.ckpt_path
         audio_file = parser_args.input
+        decode_method = parser_args.decode_method
         force_yes = parser_args.yes
         device = parser_args.device

         try:
             res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
-                       force_yes, device)
+                       decode_method, force_yes, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
@@ -434,6 +426,7 @@ class ASRExecutor(BaseExecutor):
                  sample_rate: int=16000,
                  config: os.PathLike=None,
                  ckpt_path: os.PathLike=None,
+                 decode_method: str='attention_rescoring',
                  force_yes: bool=False,
                  device=paddle.get_device()):
         """
@@ -442,7 +435,7 @@ class ASRExecutor(BaseExecutor):
         audio_file = os.path.abspath(audio_file)
         self._check(audio_file, sample_rate, force_yes)
         paddle.set_device(device)
-        self._init_from_path(model, lang, sample_rate, config, ckpt_path)
+        self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path)
         self.preprocess(model, audio_file)
         self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
diff --git a/paddlespeech/s2t/frontend/normalizer.py b/paddlespeech/s2t/frontend/normalizer.py
index 017851e6..b596b2ab 100644
--- a/paddlespeech/s2t/frontend/normalizer.py
+++ b/paddlespeech/s2t/frontend/normalizer.py
@@ -117,7 +117,8 @@ class FeatureNormalizer(object):
             self._compute_mean_std(manifest_path, featurize_func, num_samples,
                                    num_workers)
         else:
-            self._read_mean_std_from_file(mean_std_filepath)
+            mean_std = mean_std_filepath
+            self._read_mean_std_from_file(mean_std)

     def apply(self, features):
         """Normalize features to be of zero mean and unit stddev.
@@ -131,10 +132,14 @@ class FeatureNormalizer(object):
         """
         return (features - self._mean) * self._istd

-    def _read_mean_std_from_file(self, filepath, eps=1e-20):
+    def _read_mean_std_from_file(self, mean_std, eps=1e-20):
         """Load mean and std from file."""
-        filetype = filepath.split(".")[-1]
-        mean, istd = load_cmvn(filepath, filetype=filetype)
+        if isinstance(mean_std, list):
+            mean = mean_std[0]['cmvn_stats']['mean']
+            istd = mean_std[0]['cmvn_stats']['istd']
+        else:
+            filetype = mean_std.split(".")[-1]
+            mean, istd = load_cmvn(mean_std, filetype=filetype)
         self._mean = np.expand_dims(mean, axis=0)
         self._istd = np.expand_dims(istd, axis=0)
diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py
new file mode 100644
index 00000000..d2a6777c
--- /dev/null
+++ b/utils/generate_infer_yaml.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+'''
+    Merge training configs into a single inference config.
+    The single inference config is for the CLI, which takes only one config to run inference.
+    The training configs include: the model config, preprocess config, decode config, vocab file and cmvn file.
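A usage sketch for this script. The flags mirror the argparse definition
further down; the paths are illustrative (note the script's own default for
`--dcd_pth` is spelled `conf/tuninig/decode.yaml`):

    python utils/generate_infer_yaml.py \
        --cfg_pth conf/transformer.yaml \
        --pre_pth conf/preprocess.yaml \
        --dcd_pth conf/tuning/decode.yaml \
        --vb_pth data/lang_char/vocab.txt \
        --cmvn_pth data/mean_std.json \
        --save_pth conf/transformer_infer.yaml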
+''' + +import yaml +import json +import os +import argparse +import math +from yacs.config import CfgNode + +from paddlespeech.s2t.frontend.utility import load_dict +from contextlib import redirect_stdout + + +def save(save_path, config): + with open(save_path, 'w') as fp: + with redirect_stdout(fp): + print(config.dump()) + + +def load(save_path): + config = CfgNode(new_allowed=True) + config.merge_from_file(save_path) + return config + +def load_json(json_path): + with open(json_path) as f: + json_content = json.load(f) + return json_content + +def remove_config_part(config, key_list): + if len(key_list) == 0: + return + for i in range(len(key_list) -1): + config = config[key_list[i]] + config.pop(key_list[-1]) + +def load_cmvn_from_json(cmvn_stats): + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn_stats = {"mean":means, "istd":variance} + return cmvn_stats + +def merge_configs( + conf_path = "conf/conformer.yaml", + preprocess_path = "conf/preprocess.yaml", + decode_path = "conf/tuning/decode.yaml", + vocab_path = "data/vocab.txt", + cmvn_path = "data/mean_std.json", + save_path = "conf/conformer_infer.yaml", + ): + + # Load the configs + config = load(conf_path) + decode_config = load(decode_path) + vocab_list = load_dict(vocab_path) + cmvn_stats = load_json(cmvn_path) + if os.path.exists(preprocess_path): + preprocess_config = load(preprocess_path) + for idx, process in enumerate(preprocess_config["process"]): + if process['type'] == "cmvn_json": + preprocess_config["process"][idx][ + "cmvn_path"] = cmvn_stats + break + + config.preprocess_config = preprocess_config + else: + cmvn_stats = load_cmvn_from_json(cmvn_stats) + config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] + config.augmentation_config = '' + + # Updata the config + config.vocab_filepath = vocab_list + config.input_dim = config.feat_dim + config.output_dim = len(config.vocab_filepath) + config.decode = decode_config + # Remove some parts of the config + + if os.path.exists(preprocess_path): + remove_train_list = ["train_manifest", + "dev_manifest", + "test_manifest", + "n_epoch", + "accum_grad", + "global_grad_clip", + "optim", + "optim_conf", + "scheduler", + "scheduler_conf", + "log_interval", + "checkpoint", + "shuffle_method", + "weight_decay", + "ctc_grad_norm_type", + "minibatches", + "subsampling_factor", + "batch_bins", + "batch_count", + "batch_frames_in", + "batch_frames_inout", + "batch_frames_out", + "sortagrad", + "feat_dim", + "stride_ms", + "window_ms", + "batch_size", + "maxlen_in", + "maxlen_out", + ] + else: + remove_train_list = ["train_manifest", + "dev_manifest", + "test_manifest", + "n_epoch", + "accum_grad", + "global_grad_clip", + "log_interval", + "checkpoint", + "lr", + "lr_decay", + "batch_size", + "shuffle_method", + "weight_decay", + "sortagrad", + "num_workers", + ] + + for item in remove_train_list: + try: + remove_config_part(config, [item]) + except: + print ( item + " " +"can not be removed") + + # Save the config + save(save_path, config) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='Config merge', add_help=True) + parser.add_argument( + '--cfg_pth', type=str, default = 'conf/transformer.yaml', help='origin config file') + parser.add_argument( + '--pre_pth', type=str, default= 
"conf/preprocess.yaml", help='') + parser.add_argument( + '--dcd_pth', type=str, default= "conf/tuninig/decode.yaml", help='') + parser.add_argument( + '--vb_pth', type=str, default= "data/lang_char/vocab.txt", help='') + parser.add_argument( + '--cmvn_pth', type=str, default= "data/mean_std.json", help='') + parser.add_argument( + '--save_pth', type=str, default= "conf/transformer_infer.yaml", help='') + parser_args = parser.parse_args() + + merge_configs( + conf_path = parser_args.cfg_pth, + decode_path = parser_args.dcd_pth, + preprocess_path = parser_args.pre_pth, + vocab_path = parser_args.vb_pth, + cmvn_path = parser_args.cmvn_pth, + save_path = parser_args.save_pth, + ) + + From 52a8b2f3209b9bd5e6809f9d38348962e2627c75 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 11 Jan 2022 15:04:23 +0800 Subject: [PATCH 3/6] Add ECAPA_TDNN. (#1301) --- paddlespeech/vector/models/ecapa_tdnn.py | 44 ++++++++++-------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py index 5512f509..e493b800 100644 --- a/paddlespeech/vector/models/ecapa_tdnn.py +++ b/paddlespeech/vector/models/ecapa_tdnn.py @@ -47,7 +47,7 @@ class Conv1d(nn.Layer): groups=1, bias=True, padding_mode="reflect", ): - super(Conv1d, self).__init__() + super().__init__() self.kernel_size = kernel_size self.stride = stride @@ -110,7 +110,7 @@ class BatchNorm1d(nn.Layer): bias_attr=None, data_format='NCL', use_global_stats=None, ): - super(BatchNorm1d, self).__init__() + super().__init__() self.norm = nn.BatchNorm1D( input_size, @@ -134,7 +134,7 @@ class TDNNBlock(nn.Layer): kernel_size, dilation, activation=nn.ReLU, ): - super(TDNNBlock, self).__init__() + super().__init__() self.conv = Conv1d( in_channels=in_channels, out_channels=out_channels, @@ -149,7 +149,7 @@ class TDNNBlock(nn.Layer): class Res2NetBlock(nn.Layer): def __init__(self, in_channels, out_channels, scale=8, dilation=1): - super(Res2NetBlock, self).__init__() + super().__init__() assert in_channels % scale == 0 assert out_channels % scale == 0 @@ -179,7 +179,7 @@ class Res2NetBlock(nn.Layer): class SEBlock(nn.Layer): def __init__(self, in_channels, se_channels, out_channels): - super(SEBlock, self).__init__() + super().__init__() self.conv1 = Conv1d( in_channels=in_channels, out_channels=se_channels, kernel_size=1) @@ -275,7 +275,7 @@ class SERes2NetBlock(nn.Layer): kernel_size=1, dilation=1, activation=nn.ReLU, ): - super(SERes2NetBlock, self).__init__() + super().__init__() self.out_channels = out_channels self.tdnn1 = TDNNBlock( in_channels, @@ -313,7 +313,7 @@ class SERes2NetBlock(nn.Layer): return x + residual -class ECAPA_TDNN(nn.Layer): +class EcapaTdnn(nn.Layer): def __init__( self, input_size, @@ -327,7 +327,7 @@ class ECAPA_TDNN(nn.Layer): se_channels=128, global_context=True, ): - super(ECAPA_TDNN, self).__init__() + super().__init__() assert len(channels) == len(kernel_sizes) assert len(channels) == len(dilations) self.channels = channels @@ -377,6 +377,16 @@ class ECAPA_TDNN(nn.Layer): kernel_size=1, ) def forward(self, x, lengths=None): + """ + Compute embeddings. + + Args: + x (paddle.Tensor): Input log-fbanks with shape (N, n_mels, T). + lengths (paddle.Tensor, optional): Length proportions of batch length with shape (N). Defaults to None. 
+ + Returns: + paddle.Tensor: Output embeddings with shape (N, self.emb_size, 1) + """ xl = [] for layer in self.blocks: try: @@ -397,21 +407,3 @@ class ECAPA_TDNN(nn.Layer): x = self.fc(x) return x - - -class Classifier(nn.Layer): - def __init__(self, backbone, num_class, dtype=paddle.float32): - super(Classifier, self).__init__() - self.backbone = backbone - self.params = nn.ParameterList([ - paddle.create_parameter( - shape=[num_class, self.backbone.emb_size], dtype=dtype) - ]) - - def forward(self, x): - emb = self.backbone(x.transpose([0, 2, 1])).transpose([0, 2, 1]) - logits = F.linear( - F.normalize(emb.squeeze(1)), - F.normalize(self.params[0]).transpose([1, 0])) - - return logits From 9c1efde2e3220ad40c299aa588bad00a8e71aba4 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 11 Jan 2022 15:27:50 +0800 Subject: [PATCH 4/6] more mergify label (#1308) --- .mergify.yml | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/.mergify.yml b/.mergify.yml index 2c30721f..3347c6dc 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -32,6 +32,12 @@ pull_request_rules: actions: label: remove: ["conflicts"] + - name: "auto add label=Dataset" + conditions: + - files~=^dataset/ + actions: + label: + add: ["Dataset"] - name: "auto add label=S2T" conditions: - files~=^paddlespeech/s2t/ @@ -50,18 +56,30 @@ pull_request_rules: actions: label: add: ["Audio"] - - name: "auto add label=TextProcess" + - name: "auto add label=Vector" + conditions: + - files~=^paddlespeech/vector/ + actions: + label: + add: ["Vector"] + - name: "auto add label=Text" conditions: - files~=^paddlespeech/text/ actions: label: - add: ["TextProcess"] + add: ["Text"] - name: "auto add label=Example" conditions: - files~=^examples/ actions: label: add: ["Example"] + - name: "auto add label=CLI" + conditions: + - files~=^paddlespeech/cli + actions: + label: + add: ["CLI"] - name: "auto add label=Demo" conditions: - files~=^demos/ @@ -70,13 +88,13 @@ pull_request_rules: add: ["Demo"] - name: "auto add label=README" conditions: - - files~=README.md + - files~=(README.md|READEME_cn.md) actions: label: add: ["README"] - name: "auto add label=Documentation" conditions: - - files~=^docs/ + - files~=^(docs/|CHANGELOG.md|paddleaudio/CHANGELOG.md) actions: label: add: ["Documentation"] @@ -88,10 +106,16 @@ pull_request_rules: add: ["CI"] - name: "auto add label=Installation" conditions: - - files~=^(tools/|setup.py|setup.sh) + - files~=^(tools/|setup.py|setup.cfg|setup_audio.py) actions: label: add: ["Installation"] + - name: "auto add label=Test" + conditions: + - files~=^(tests/) + actions: + label: + add: ["Test"] - name: "auto add label=mergify" conditions: - files~=^.mergify.yml From 27bb76bdb92330ccbed1d70f47ac38dae2a3213e Mon Sep 17 00:00:00 2001 From: TianYuan Date: Tue, 11 Jan 2022 08:53:25 +0000 Subject: [PATCH 5/6] fix tone_sandhi of yi, test=tts --- paddlespeech/t2s/frontend/tone_sandhi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddlespeech/t2s/frontend/tone_sandhi.py b/paddlespeech/t2s/frontend/tone_sandhi.py index 6ba567bb..5264e068 100644 --- a/paddlespeech/t2s/frontend/tone_sandhi.py +++ b/paddlespeech/t2s/frontend/tone_sandhi.py @@ -65,6 +65,7 @@ class ToneSandhi(): self.must_not_neural_tone_words = { "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子" } + self.punc = ":,;。?!“”‘’':,;.?!" # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041 # e.g. 
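The hunk below adjusts the tone-sandhi rule for "一" (yi). As a condensed,
simplified sketch of the whole rule (`word`, `finals`, and `self.punc` are
the names used in the real method; the reduplication branch is paraphrased
from the surrounding code):

    # 一 between reduplicated verbs (e.g. 看一看) -> neutral tone yi5
    if i > 0 and i + 1 < len(word) and word[i - 1] == word[i + 1]:
        finals[i] = finals[i][:-1] + "5"
    # 一 before a tone-4 syllable (e.g. 一段) -> yi2
    elif finals[i + 1][-1] == "4":
        finals[i] = finals[i][:-1] + "2"
    # otherwise yi4, except before punctuation, where it stays yi1
    elif word[i + 1] not in self.punc:
        finals[i] = finals[i][:-1] + "4"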
@@ -147,7 +148,9 @@
                 finals[i] = finals[i][:-1] + "2"
             # "一" before non-tone4 should be yi4, e.g. 一天
             else:
-                finals[i] = finals[i][:-1] + "4"
+                # "一" followed by punctuation is still read with tone 1
+                if word[i + 1] not in self.punc:
+                    finals[i] = finals[i][:-1] + "4"
         return finals

     def _split_word(self, word: str) -> List[str]:

From c893d0d8287f674a6fd32a83662b6cb529cbee22 Mon Sep 17 00:00:00 2001
From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com>
Date: Tue, 11 Jan 2022 19:35:27 +0800
Subject: [PATCH 6/6] update the required version of yacs, test=doc_fix (#1311)

---
 requirements.txt | 2 +-
 setup.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index e567dfa7..c6889318 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,5 +43,5 @@ typeguard
 unidecode
 visualdl
 webrtcvad
-yacs
+yacs~=0.1.8
 yq
diff --git a/setup.py b/setup.py
index 75b3fe5c..5d4ff80f 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,7 @@ requirements = {
         "typeguard",
         "visualdl",
         "webrtcvad",
-        "yacs",
+        "yacs~=0.1.8",
     ],
     "develop": [
         "ConfigArgParse",
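For reference, `yacs~=0.1.8` is a PEP 440 "compatible release" pin: it admits
any 0.1.x at or above 0.1.8 and excludes 0.2.0. A quick way to confirm the
semantics, assuming the `packaging` library is available:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=0.1.8")  # equivalent to: >=0.1.8, ==0.1.*
    print("0.1.9" in spec)          # True
    print("0.2.0" in spec)          # False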