diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index e4364f70a..7dc01c40a 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -76,7 +76,7 @@ class TextFeaturizer(): Args: text (str): Text. - + Returns: List[int]: List of token indices. """ @@ -89,7 +89,7 @@ class TextFeaturizer(): def defeaturize(self, idxs): """Convert a list of token indices to text string, - ignore index after eos_id. + ignore index after eos_id. Args: idxs (List[int]): List of token indices. @@ -196,7 +196,12 @@ class TextFeaturizer(): [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - - unk_id = vocab_list.index(UNK) - eos_id = vocab_list.index(EOS) + if UNK in vocab_list: + unk_id = vocab_list.index(UNK) + else: + unk_id = -1 + if EOS in vocab_list: + eos_id = vocab_list.index(EOS) + else: + eos_id = -1 return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 73b3a4ba6..6ace4fc6d 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -130,7 +130,8 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" - mean, istd = load_cmvn(filepath, filetype='json') + filetype = filepath.split(".")[-1] + mean, istd = load_cmvn(filepath, filetype=filetype) self._mean = np.expand_dims(mean, axis=0) self._istd = np.expand_dims(istd, axis=0) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 72dfc98dd..c6781cd4e 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -69,19 +69,19 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, - in seconds for raw wav, in frame numbers for feature data. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, - in seconds for raw wav, in frame numbers for feature data. + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, + max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, + min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): + max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. - min_output_input_ratio (float, optional): + min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. Raises: @@ -131,7 +131,7 @@ def rms_to_dbfs(rms: float): """Root Mean Square to dBFS. https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. - + dB = dBFS + 3.0103 dBFS = db - 3.0103 e.g. 
0 dB = -3.0103 dBFS @@ -146,26 +146,26 @@ def rms_to_dbfs(rms: float): def max_dbfs(sample_data: np.ndarray): - """Peak dBFS based on the maximum energy sample. + """Peak dBFS based on the maximum energy sample. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) def mean_dbfs(sample_data): - """Peak dBFS based on the RMS energy. + """Peak dBFS based on the RMS energy. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ return rms_to_dbfs( math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) @@ -185,7 +185,7 @@ def gain_db_to_ratio(gain_db: float): def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): """Nomalize audio to dBFS. - + Args: sample_data (np.ndarray): input wave samples, [-1, 1]. dbfs (float, optional): target dBFS. Defaults to -3.0103. @@ -284,6 +284,13 @@ def load_cmvn(cmvn_file: str, filetype: str): cmvn = _load_json_cmvn(cmvn_file) elif filetype == "kaldi": cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] else: raise ValueError(f"cmvn file type no support: {filetype}") return cmvn[0], cmvn[1] diff --git a/examples/v18_to_v2x/.gitignore b/examples/v18_to_v2x/.gitignore new file mode 100644 index 000000000..a9a5aecf4 --- /dev/null +++ b/examples/v18_to_v2x/.gitignore @@ -0,0 +1 @@ +tmp diff --git a/examples/v18_to_v2x/deepspeech2x/__init__.py b/examples/v18_to_v2x/deepspeech2x/__init__.py new file mode 100644 index 000000000..d85a3dde7 --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/__init__.py @@ -0,0 +1,370 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any +from typing import List +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn +from paddle.fluid import core +from paddle.nn import functional as F + +from deepspeech.utils.log import Log + +#TODO(Hui Zhang): remove fluid import +logger = Log(__name__).getlog() + +########### hcak logging ############# +logger.warn = logger.warning + +########### hcak paddle ############# +paddle.half = 'float16' +paddle.float = 'float32' +paddle.double = 'float64' +paddle.short = 'int16' +paddle.int = 'int32' +paddle.long = 'int64' +paddle.uint16 = 'uint16' +paddle.cdouble = 'complex128' + + +def convert_dtype_to_string(tensor_dtype): + """ + Convert the data type in numpy to the data type in Paddle + Args: + tensor_dtype(core.VarDesc.VarType): the data type in numpy. + Returns: + core.VarDesc.VarType: the data type in Paddle. 
+ """ + dtype = tensor_dtype + if dtype == core.VarDesc.VarType.FP32: + return paddle.float32 + elif dtype == core.VarDesc.VarType.FP64: + return paddle.float64 + elif dtype == core.VarDesc.VarType.FP16: + return paddle.float16 + elif dtype == core.VarDesc.VarType.INT32: + return paddle.int32 + elif dtype == core.VarDesc.VarType.INT16: + return paddle.int16 + elif dtype == core.VarDesc.VarType.INT64: + return paddle.int64 + elif dtype == core.VarDesc.VarType.BOOL: + return paddle.bool + elif dtype == core.VarDesc.VarType.BF16: + # since there is still no support for bfloat16 in NumPy, + # uint16 is used for casting bfloat16 + return paddle.uint16 + elif dtype == core.VarDesc.VarType.UINT8: + return paddle.uint8 + elif dtype == core.VarDesc.VarType.INT8: + return paddle.int8 + elif dtype == core.VarDesc.VarType.COMPLEX64: + return paddle.complex64 + elif dtype == core.VarDesc.VarType.COMPLEX128: + return paddle.complex128 + else: + raise ValueError("Not supported tensor dtype %s" % dtype) + + +if not hasattr(paddle, 'softmax'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle, 'log_softmax'): + logger.warn("register user log_softmax to paddle, remove this when fixed!") + setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) + +if not hasattr(paddle, 'sigmoid'): + logger.warn("register user sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle, 'log_sigmoid'): + logger.warn("register user log_sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) + +if not hasattr(paddle, 'relu'): + logger.warn("register user relu to paddle, remove this when fixed!") + setattr(paddle, 'relu', paddle.nn.functional.relu) + + +def cat(xs, dim=0): + return paddle.concat(xs, axis=dim) + + +if not hasattr(paddle, 'cat'): + logger.warn( + "override cat of paddle if exists or register, remove this when fixed!") + paddle.cat = cat + + +########### hcak paddle.Tensor ############# +def item(x: paddle.Tensor): + return x.numpy().item() + + +if not hasattr(paddle.Tensor, 'item'): + logger.warn( + "override item of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.item = item + + +def func_long(x: paddle.Tensor): + return paddle.cast(x, paddle.long) + + +if not hasattr(paddle.Tensor, 'long'): + logger.warn( + "override long of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.long = func_long + +if not hasattr(paddle.Tensor, 'numel'): + logger.warn( + "override numel of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.numel = paddle.numel + + +def new_full(x: paddle.Tensor, + size: Union[List[int], Tuple[int], paddle.Tensor], + fill_value: Union[float, int, bool, paddle.Tensor], + dtype=None): + return paddle.full(size, fill_value, dtype=x.dtype) + + +if not hasattr(paddle.Tensor, 'new_full'): + logger.warn( + "override new_full of paddle.Tensor if exists or register, remove this when fixed!" 
+ ) + paddle.Tensor.new_full = new_full + + +def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: + if convert_dtype_to_string(xs.dtype) == paddle.bool: + xs = xs.astype(paddle.int) + return xs.equal( + paddle.to_tensor( + ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)) + + +if not hasattr(paddle.Tensor, 'eq'): + logger.warn( + "override eq of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.eq = eq + +if not hasattr(paddle, 'eq'): + logger.warn( + "override eq of paddle if exists or register, remove this when fixed!") + paddle.eq = eq + + +def contiguous(xs: paddle.Tensor) -> paddle.Tensor: + return xs + + +if not hasattr(paddle.Tensor, 'contiguous'): + logger.warn( + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.contiguous = contiguous + + +def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + nargs = len(args) + assert (nargs <= 1) + s = paddle.shape(xs) + if nargs == 1: + return s[args[0]] + else: + return s + + +#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. +logger.warn( + "override size of paddle.Tensor " + "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" +) +paddle.Tensor.size = size + + +def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + return xs.reshape(args) + + +if not hasattr(paddle.Tensor, 'view'): + logger.warn("register user view to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view = view + + +def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: + return xs.reshape(ys.size()) + + +if not hasattr(paddle.Tensor, 'view_as'): + logger.warn( + "register user view_as to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view_as = view_as + + +def is_broadcastable(shp1, shp2): + for a, b in zip(shp1[::-1], shp2[::-1]): + if a == 1 or b == 1 or a == b: + pass + else: + return False + return True + + +def masked_fill(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]): + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + xs = paddle.where(mask, trues, xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill'): + logger.warn( + "register user masked_fill to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill = masked_fill + + +def masked_fill_(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]) -> paddle.Tensor: + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + ret = paddle.where(mask, trues, xs) + paddle.assign(ret.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill_'): + logger.warn( + "register user masked_fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill_ = masked_fill_ + + +def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: + val = paddle.full_like(xs, value) + paddle.assign(val.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'fill_'): + logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.fill_ = fill_ + + +def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: + return paddle.tile(xs, size) + + +if not hasattr(paddle.Tensor, 
'repeat'): + logger.warn( + "register user repeat to paddle.Tensor, remove this when fixed!") + paddle.Tensor.repeat = repeat + +if not hasattr(paddle.Tensor, 'softmax'): + logger.warn( + "register user softmax to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle.Tensor, 'sigmoid'): + logger.warn( + "register user sigmoid to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle.Tensor, 'relu'): + logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) + + +def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: + return x.astype(other.dtype) + + +if not hasattr(paddle.Tensor, 'type_as'): + logger.warn( + "register user type_as to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'type_as', type_as) + + +def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: + assert len(args) == 1 + if isinstance(args[0], str): # dtype + return x.astype(args[0]) + elif isinstance(args[0], paddle.Tensor): #Tensor + return x.astype(args[0].dtype) + else: # Device + return x + + +if not hasattr(paddle.Tensor, 'to'): + logger.warn("register user to to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'to', to) + + +def func_float(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.float) + + +if not hasattr(paddle.Tensor, 'float'): + logger.warn("register user float to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'float', func_float) + + +def func_int(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.int) + + +if not hasattr(paddle.Tensor, 'int'): + logger.warn("register user int to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'int', func_int) + + +def tolist(x: paddle.Tensor) -> List[Any]: + return x.numpy().tolist() + + +if not hasattr(paddle.Tensor, 'tolist'): + logger.warn( + "register user tolist to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'tolist', tolist) + + +########### hcak paddle.nn ############# +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return F.glu(xs, axis=self.dim) + + +if not hasattr(paddle.nn, 'GLU'): + logger.warn("register user GLU to paddle.nn, remove this when fixed!") + setattr(paddle.nn, 'GLU', GLU) diff --git a/examples/v18_to_v2x/deepspeech2x/bin/test.py b/examples/v18_to_v2x/deepspeech2x/bin/test.py new file mode 100644 index 000000000..3fa0a61de --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/bin/test.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
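Before the evaluation entry point below, it is worth isolating the pattern that deepspeech2x/__init__.py above repeats for every shim: define a Torch-style helper, then register it on paddle or paddle.Tensor only when the attribute is missing. A minimal self-contained sketch of that pattern, reusing the tolist helper from the file (the tensor values are illustrative, not part of the PR):

import paddle


def tolist(x: paddle.Tensor):
    # Torch-style Tensor.tolist(), implemented by a numpy round-trip.
    return x.numpy().tolist()


# The hasattr guard keeps the patch idempotent and lets it fall away
# silently once a paddle release ships the method natively.
if not hasattr(paddle.Tensor, 'tolist'):
    paddle.Tensor.tolist = tolist

t = paddle.to_tensor([[1, 2], [3, 4]])
assert t.tolist() == [[1, 2], [3, 4]]

One caveat the file itself flags: size is overridden unconditionally rather than guarded, because to_static does not handle the size property — the hasattr guard only helps for attributes paddle genuinely lacks.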
+"""Evaluation for DeepSpeech2 model.""" +from deepspeech2x.model import DeepSpeech2Tester as Tester + +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--model_type") + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + args = parser.parse_args() + print_arguments(args, globals()) + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/examples/v18_to_v2x/deepspeech2x/model.py b/examples/v18_to_v2x/deepspeech2x/model.py new file mode 100644 index 000000000..1fe1e2d68 --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/model.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains DeepSpeech2 and DeepSpeech2Online model.""" +import time +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +from deepspeech2x.models.ds2 import DeepSpeech2InferModel +from deepspeech2x.models.ds2 import DeepSpeech2Model +from paddle import distributed as dist +from paddle.io import DataLoader +from yacs.config import CfgNode + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.trainer import Trainer +from deepspeech.utils import error_rate +from deepspeech.utils import layer_tools +from deepspeech.utils import mp_tools +from deepspeech.utils.log import Log +#from deepspeech.utils.log import Autolog + +logger = Log(__name__).getlog() + + +class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def train_batch(self, batch_index, batch_data, msg): + train_conf = self.config.training + start = time.time() + + # forward + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) + losses_np = { + 'train_loss': float(loss), + } + + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step + if (batch_index + 1) % train_conf.accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.iteration += 1 + + iteration_time = time.time() - start + + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config.collator.batch_size) + msg += "accum: {}, ".format(train_conf.accum_grad) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + logger.info(msg) + + if dist.get_rank() == 0 and self.visualizer: + for k, v in losses_np.items(): + # `step -1` since we update `step` after optimizer.step(). 
+ self.visualizer.add_scalar("train/{}".format(k), v, + self.iteration - 1) + + @paddle.no_grad() + def valid(self): + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + self.model.eval() + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(loss) * num_utts + valid_losses['val_loss'].append(float(loss)) + + if (i + 1) % self.config.training.log_interval == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump['val_history_loss'] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_dump.items()) + logger.info(msg) + + logger.info('Rank {} Val info val_loss {}'.format( + dist.get_rank(), total_loss / num_seen_utts)) + return total_loss, num_seen_utts + + def setup_model(self): + config = self.config.clone() + config.defrost() + config.model.feat_size = self.train_loader.collate_fn.feature_size + #config.model.dict_size = self.train_loader.collate_fn.vocab_size + config.model.dict_size = len(self.train_loader.collate_fn.vocab_list) + config.freeze() + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") + if self.parallel: + model = paddle.DataParallel(model) + + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + grad_clip = ClipGradByGlobalNormWithLog( + config.training.global_grad_clip) + lr_scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=config.training.lr, + gamma=config.training.lr_decay, + verbose=True) + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=paddle.regularizer.L2Decay( + config.training.weight_decay), + grad_clip=grad_clip) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + 
collate_fn_dev = SpeechCollator.from_config(config) + + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + collate_fn_test = SpeechCollator.from_config(config) + + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test) + if "" in self.test_loader.collate_fn.vocab_list: + self.test_loader.collate_fn.vocab_list.remove("") + if "" in self.valid_loader.collate_fn.vocab_list: + self.valid_loader.collate_fn.vocab_list.remove("") + if "" in self.train_loader.collate_fn.vocab_list: + self.train_loader.collate_fn.vocab_list.remove("") + logger.info("Setup train/valid/test Dataloader!") + + +class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + + vocab_list = self.test_loader.collate_fn.vocab_list + if "" in vocab_list: + space_id = vocab_list.index("") + vocab_list[space_id] = " " + + target_transcripts = self.ordid2token(texts, texts_len) + + result_transcripts = self.compute_result_transcripts(audio, audio_len, + vocab_list, cfg) + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("Current error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) + + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + result_transcripts = self.model.decode( + 
audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + return result_transcripts + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + self.model.eval() + cfg = self.config + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + + # self.autolog.report() + + def run_test(self): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + + def export(self): + if self.args.model_type == 'offline': + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + elif self.args.model_type == 'online': + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + raise Exception("wrong model type") + + infer_model.eval() + feat_dim = self.test_loader.collate_fn.feature_size + static_model = infer_model.export() + logger.info(f"Export code: {static_model.forward.code}") + paddle.jit.save(static_model, self.args.export_path) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device(self.args.device) + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir diff --git a/examples/v18_to_v2x/deepspeech2x/models/__init__.py b/examples/v18_to_v2x/deepspeech2x/models/__init__.py new file mode 100644 index 000000000..185a92b8d --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/v18_to_v2x/deepspeech2x/models/ds2/__init__.py b/examples/v18_to_v2x/deepspeech2x/models/ds2/__init__.py new file mode 100644 index 000000000..39bea5bf9 --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/models/ds2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .deepspeech2 import DeepSpeech2InferModel +from .deepspeech2 import DeepSpeech2Model + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] diff --git a/examples/v18_to_v2x/deepspeech2x/models/ds2/deepspeech2.py b/examples/v18_to_v2x/deepspeech2x/models/ds2/deepspeech2.py new file mode 100644 index 000000000..f154ddb54 --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/models/ds2/deepspeech2.py @@ -0,0 +1,314 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
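The export()/run_export() path in model.py above is the crux of the v1.8-to-v2.x story: a dynamic-graph model is traced into a static graph with paddle.jit.to_static, using InputSpec placeholders whose None dimensions mark the variable batch and time axes, and the result is serialized with paddle.jit.save. A toy sketch of the same recipe (the layer, dimensions, and output path are made up for illustration, not the PR's model):

import os

import paddle
from paddle import nn


class ToyAcousticModel(nn.Layer):
    # Stand-in for DeepSpeech2InferModel: [B, T, 161] features in,
    # per-frame posteriors out.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(161, 29)

    def forward(self, feats):
        return paddle.nn.functional.softmax(self.proj(feats), axis=-1)


model = ToyAcousticModel()
model.eval()  # export in inference mode, as DeepSpeech2Tester.export() does
static_model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(shape=[None, None, 161], dtype='float32'),
    ])
os.makedirs("tmp", exist_ok=True)
paddle.jit.save(static_model, "tmp/toy_infer")  # writes tmp/toy_infer.pdmodel / .pdiparams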
+"""Deepspeech2 ASR Model""" +from typing import Optional + +import paddle +from deepspeech2x.models.ds2.rnn import RNNStack +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2.conv import ConvStack +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + + @property + def output_size(self): + return self.rnn_size * 2 + + def forward(self, audio, audio_len): + """Compute Encoder outputs + + Args: + audio (Tensor): [B, Tmax, D] + text (Tensor): [B, Umax] + audio_len (Tensor): [B] + text_len (Tensor): [B] + Returns: + x (Tensor): encoder outputs, [B, T, D] + x_lens (Tensor): encoder length, [B] + """ + # [B, T, D] -> [B, D, T] + audio = audio.transpose([0, 2, 1]) + # [B, D, T] -> [B, C=1, D, T] + x = audio.unsqueeze(1) + x_lens = audio_len + + # convolution group + x, x_lens = self.conv(x, x_lens) + x_val = x.numpy() + + # convert data from convolution feature map to sequence of vectors + #B, C, D, T = paddle.shape(x) # not work under jit + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit + x = x.reshape([0, 0, -1]) #[B, T, C*D] + + # remove padding part + x, x_lens = self.rnn(x, x_lens) #[B, T, D] + return x, x_lens + + +class DeepSpeech2Model(nn.Layer): + """The DeepSpeech2 network structure. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param masks: Masks data layer to reset padding. + :type masks: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward direction RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+ share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__() + self.encoder = CRNNEncoder( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + assert (self.encoder.output_size == rnn_size * 2) + + self.decoder = CTCDecoder( + odim=dict_size, # is in vocab + enc_n_units=self.encoder.output_size, + blank_id=blank_id, # first token is + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + + def forward(self, audio, audio_len, text, text_len): + """Compute Model loss + + Args: + audio (Tenosr): [B, T, D] + audio_len (Tensor): [B] + text (Tensor): [B, U] + text_len (Tensor): [B] + + Returns: + loss (Tenosr): [1] + """ + eouts, eouts_len = self.encoder(audio, audio_len) + loss = self.decoder(eouts, eouts_len, text, text_len) + return loss + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + print("probs.shape", probs.shape) + return self.decoder.decode_probs( + probs.numpy(), eouts_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + def decode_probs_split(self, probs_split, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + return self.decoder.decode_probs_split( + probs_split, vocab_list, decoding_method, lang_model_path, + beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, + num_processes) + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + Parameters + ---------- + dataloader: paddle.io.DataLoader + + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + DeepSpeech2Model + The model built from pretrained result. 
+ """ + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=len(dataloader.collate_fn.vocab_list), + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. + """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) + return model + + +class DeepSpeech2InferModel(DeepSpeech2Model): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) + + def forward(self, audio, audio_len): + """export model function + + Args: + audio (Tensor): [B, T, D] + audio_len (Tensor): [B] + + Returns: + probs: probs after softmax + """ + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/examples/v18_to_v2x/deepspeech2x/models/ds2/rnn.py b/examples/v18_to_v2x/deepspeech2x/models/ds2/rnn.py new file mode 100644 index 000000000..e45db7c05 --- /dev/null +++ b/examples/v18_to_v2x/deepspeech2x/models/ds2/rnn.py @@ -0,0 +1,334 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['RNNStack'] + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. 
math:: + h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation`. + """ + + def __init__(self, + hidden_size: int, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + if activation not in ["tanh", "relu", "brelu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + if activation == 'brelu': + self._activation_fn = brelu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = inputs + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class GRUCell(nn.RNNCellBase): + r""" + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + The formula for GRU used is as follows: + .. math:: + r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) + z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) + \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) + h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + multiplication operator. 
+ """ + + def __init__(self, + input_size: int, + hidden_size: int, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.relu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states # shape [batch_size, hidden_size] + + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + bias_u, bias_r, bias_c = paddle.split( + self.bias_hh, num_or_sections=3, axis=0) + + weight_hh = paddle.transpose( + self.weight_hh, + perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size] + w_u_r_c = paddle.flatten(weight_hh) + size_u_r = self.hidden_size * 2 * self.hidden_size + w_u_r = paddle.reshape(w_u_r_c[:size_u_r], + (self.hidden_size, self.hidden_size * 2)) + w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1) + w_c = paddle.reshape(w_u_r_c[size_u_r:], + (self.hidden_size, self.hidden_size)) + + h_u = paddle.matmul( + pre_hidden, w_u, + transpose_y=False) + bias_u #shape [batch_size, hidden_size] + h_r = paddle.matmul( + pre_hidden, w_r, + transpose_y=False) + bias_r #shape [batch_size, hidden_size] + + x_u, x_r, x_c = paddle.split( + x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size] + + u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size] + r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size] + c = self._activation( + x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) + + bias_c) # [batch_size, hidden_size] + + h = (1 - u) * pre_hidden + u * c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int, share_weights: bool): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. 
+ self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + # batch norm is only performed on input-state projection + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = self.fw_fc + self.bw_bn = self.fw_bn + else: + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + + self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class BiGRUWithBN(nn.Layer): + """Bidirectonal gru layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer. + :type name: string + :param input: Input layer. + :type input: Variable + :param size: Dimension of GRU cells. + :type size: int + :param act: Activation type. + :type act: string + :return: Bidirectional GRU layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int): + super().__init__() + hidden_size = h_size * 3 + + self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + + self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x, x_len): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class RNNStack(nn.Layer): + """RNN group with stacked bidirectional simple RNN or GRU layers. + + :param input: Input layer. + :type input: Variable + :param size: Dimension of RNN cells in each layer. + :type size: int + :param num_stacks: Number of stacked rnn layers. + :type num_stacks: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: Output layer of the RNN group. 
+ :rtype: Variable + """ + + def __init__(self, + i_size: int, + h_size: int, + num_stacks: int, + use_gru: bool, + share_rnn_weights: bool): + super().__init__() + rnn_stacks = [] + for i in range(num_stacks): + if use_gru: + #default:GRU using tanh + rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) + else: + rnn_stacks.append( + BiRNNWithBN( + i_size=i_size, + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 + + self.rnn_stacks = nn.LayerList(rnn_stacks) + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + """ + x: shape [B, T, D] + x_len: shpae [B] + """ + for i, rnn in enumerate(self.rnn_stacks): + x, x_len = rnn(x, x_len) + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(-1) # [B, T, 1] + # TODO(Hui Zhang): not support bool multiply + masks = masks.astype(x.dtype) + x = x.multiply(masks) + return x, x_len diff --git a/examples/v18_to_v2x/exp_aishell/.gitignore b/examples/v18_to_v2x/exp_aishell/.gitignore new file mode 100644 index 000000000..7024e0e95 --- /dev/null +++ b/examples/v18_to_v2x/exp_aishell/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/v18_to_v2x/exp_aishell/conf/augmentation.json b/examples/v18_to_v2x/exp_aishell/conf/augmentation.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/examples/v18_to_v2x/exp_aishell/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/v18_to_v2x/exp_aishell/conf/deepspeech2.yaml b/examples/v18_to_v2x/exp_aishell/conf/deepspeech2.yaml new file mode 100644 index 000000000..6e745e9d1 --- /dev/null +++ b/examples/v18_to_v2x/exp_aishell/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + use_gru: True + share_rnn_weights: False + blank_id: 4333 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: cer + decoding_method: ctc_beam_search + lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm + alpha: 2.6 + beta: 5.0 + beam_size: 300 + cutoff_prob: 0.99 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/v18_to_v2x/exp_aishell/local/data.sh b/examples/v18_to_v2x/exp_aishell/local/data.sh new file mode 100755 index 000000000..1cde0c6ea --- /dev/null +++ b/examples/v18_to_v2x/exp_aishell/local/data.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + +bash local/download_model.sh +if [ $? 
-ne 0 ]; then
+    exit 1
+fi
+
+tar xzvf aishell_model_v1.8_to_v2.x.tar.gz
+mv aishell_v1.8.pdparams exp/deepspeech2/checkpoints/
+mv README.md exp/deepspeech2/
+mv mean_std.npz data/
+mv vocab.txt data/
+rm aishell_model_v1.8_to_v2.x.tar.gz -f
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/aishell/aishell.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/aishell"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare Aishell failed. Terminated."
+        exit 1
+    fi
+
+    for dataset in train dev test; do
+        mv data/manifest.${dataset} data/manifest.${dataset}.raw
+    done
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    num_workers=$(nproc)
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --specgram_type="linear" \
+    --delta_delta=false \
+    --stride_ms=10.0 \
+    --window_ms=20.0 \
+    --sample_rate=16000 \
+    --use_dB_normalization=True \
+    --num_samples=2000 \
+    --num_workers=${num_workers} \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for dataset in train dev test; do
+    {
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "char" \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${dataset}.raw" \
+        --output_path="data/manifest.${dataset}"
+
+        if [ $? -ne 0 ]; then
+            echo "Format manifest failed. Terminated."
+            exit 1
+        fi
+    } &
+    done
+    wait
+fi
+
+echo "Aishell data preparation done."
+exit 0
diff --git a/examples/v18_to_v2x/exp_aishell/local/download_lm_ch.sh b/examples/v18_to_v2x/exp_aishell/local/download_lm_ch.sh
new file mode 100755
index 000000000..ac27a9076
--- /dev/null
+++ b/examples/v18_to_v2x/exp_aishell/local/download_lm_ch.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+DIR=data/lm
+mkdir -p ${DIR}
+
+URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
+MD5="29e02312deb2e59b3c8686c7966d4fe3"
+TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
+
+
+echo "Download language model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Failed to download the language model!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_aishell/local/download_model.sh b/examples/v18_to_v2x/exp_aishell/local/download_model.sh
new file mode 100644
index 000000000..2e4873ef6
--- /dev/null
+++ b/examples/v18_to_v2x/exp_aishell/local/download_model.sh
@@ -0,0 +1,19 @@
+#! /usr/bin/env bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
+MD5=4ade113c69ea291b8ce5ec6a03296659
+TARGET=./aishell_model_v1.8_to_v2.x.tar.gz
+
+
+echo "Download Aishell model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Failed to download the Aishell model!"
diff --git a/examples/v18_to_v2x/exp_aishell/local/test.sh b/examples/v18_to_v2x/exp_aishell/local/test.sh
new file mode 100755
index 000000000..9fd0bc8d5
--- /dev/null
+++ b/examples/v18_to_v2x/exp_aishell/local/test.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+model_type=$3
+
+# download language model
+bash local/download_lm_ch.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+python3 -u ${BIN_DIR}/test.py \
+--device ${device} \
+--nproc 1 \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_aishell/path.sh b/examples/v18_to_v2x/exp_aishell/path.sh
new file mode 100644
index 000000000..080ab1f79
--- /dev/null
+++ b/examples/v18_to_v2x/exp_aishell/path.sh
@@ -0,0 +1,16 @@
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=deepspeech2
+export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
+echo "BIN_DIR "${BIN_DIR}
diff --git a/examples/v18_to_v2x/exp_aishell/run.sh b/examples/v18_to_v2x/exp_aishell/run.sh
new file mode 100755
index 000000000..482ab2a09
--- /dev/null
+++ b/examples/v18_to_v2x/exp_aishell/run.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -e
+source path.sh
+
+stage=0
+stop_stage=100
+conf_path=conf/deepspeech2.yaml
+avg_num=1
+model_type=offline
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+v18_ckpt=aishell_v1.8
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    mkdir -p exp/${ckpt}/checkpoints
+    bash ./local/data.sh || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
+fi
+
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/.gitignore b/examples/v18_to_v2x/exp_baidu_en8k/.gitignore
new file mode 100644
index 000000000..7024e0e95
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/.gitignore
@@ -0,0 +1,4 @@
+exp
+data
+*log
+tmp
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/conf/augmentation.json b/examples/v18_to_v2x/exp_baidu_en8k/conf/augmentation.json
new file mode 100644
index 000000000..fe51488c7
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/conf/augmentation.json
@@ -0,0 +1 @@
+[]
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/conf/deepspeech2.yaml b/examples/v18_to_v2x/exp_baidu_en8k/conf/deepspeech2.yaml
new file mode 100644
index 000000000..fbc7466f2
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/conf/deepspeech2.yaml
@@ -0,0 +1,67 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.train
+  dev_manifest: data/manifest.dev
+  test_manifest: data/manifest.test-clean
+  min_input_len: 0.0
+  max_input_len: .inf # second
+  min_output_len: 0.0
+  max_output_len: .inf
+  min_output_input_ratio: 0.00
+  max_output_input_ratio: .inf
+
+collator:
+  batch_size: 64 # one gpu
+  mean_std_filepath: data/mean_std.npz
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
+  specgram_type: linear
+  feat_dim:
+  delta_delta: False
+  stride_ms: 10.0
+  window_ms: 20.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
+  use_dB_normalization: True
+  target_dB: -20
+  dither: 1.0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 2
+
+model:
+  num_conv_layers: 2
+  num_rnn_layers: 3
+  rnn_layer_size: 1024
+  use_gru: True
+  share_rnn_weights: False
+  blank_id: 28
+
+training:
+  n_epoch: 80
+  accum_grad: 1
+  lr: 2e-3
+  lr_decay: 0.83
+  weight_decay: 1e-06
+  global_grad_clip: 3.0
+  log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5
+
+decoding:
+  batch_size: 32
+  error_rate_type: wer
+  decoding_method: ctc_beam_search
+  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
+  alpha: 1.4
+  beta: 0.35
+  beam_size: 500
+  cutoff_prob: 1.0
+  cutoff_top_n: 40
+  num_proc_bsearch: 8
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/local/data.sh b/examples/v18_to_v2x/exp_baidu_en8k/local/data.sh
new file mode 100755
index 000000000..8f9468b13
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/local/data.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+
+stage=-1
+stop_stage=100
+
+source ${MAIN_ROOT}/utils/parse_options.sh
+
+mkdir -p data
+TARGET_DIR=${MAIN_ROOT}/examples/dataset
+mkdir -p ${TARGET_DIR}
+
+
+bash local/download_model.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz
+mv baidu_en8k_v1.8.pdparams exp/deepspeech2/checkpoints/
+mv README.md exp/deepspeech2/
+mv mean_std.npz data/
+mv vocab.txt data/
+rm baidu_en8k_v1.8_to_v2.x.tar.gz -f
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/librispeech/librispeech.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/librispeech" \
+    --full_download="True"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare LibriSpeech failed. Terminated."
+        exit 1
+    fi
+
+    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+        mv data/manifest.${set} data/manifest.${set}.raw
+    done
+
+    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
+    for set in train-clean-100 train-clean-360 train-other-500; do
+        cat data/manifest.${set}.raw >> data/manifest.train.raw
+    done
+
+    for set in dev-clean dev-other; do
+        cat data/manifest.${set}.raw >> data/manifest.dev.raw
+    done
+
+    for set in test-clean test-other; do
+        cat data/manifest.${set}.raw >> data/manifest.test.raw
+    done
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    num_workers=$(nproc)
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --num_samples=2000 \
+    --specgram_type="linear" \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=20.0 \
+    --use_dB_normalization=True \
+    --num_workers=${num_workers} \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for set in train dev test dev-clean dev-other test-clean test-other; do
+    {
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "char" \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${set}.raw" \
+        --output_path="data/manifest.${set}"
+
+        if [ $? -ne 0 ]; then
+            echo "Format manifest.${set} failed. Terminated."
+            exit 1
+        fi
+    } &
+    done
+    wait
+fi
+
+echo "LibriSpeech data preparation done."
+exit 0
+
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/local/download_lm_en.sh b/examples/v18_to_v2x/exp_baidu_en8k/local/download_lm_en.sh
new file mode 100755
index 000000000..dc1bdf665
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/local/download_lm_en.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+DIR=data/lm
+mkdir -p ${DIR}
+
+URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
+MD5="099a601759d467cd0a8523ff939819c5"
+TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
+
+echo "Download language model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Fail to download the language model!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/local/download_model.sh b/examples/v18_to_v2x/exp_baidu_en8k/local/download_model.sh
new file mode 100644
index 000000000..6d06e3d6f
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/local/download_model.sh
@@ -0,0 +1,19 @@
+#! /usr/bin/env bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
+MD5=fdabeb6c96963ac85d9188f0275c6a1b
+TARGET=./baidu_en8k_v1.8_to_v2.x.tar.gz
+
+
+echo "Download BaiduEn8k model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Fail to download BaiduEn8k model!"
+    exit 1
+fi
+tar -zxvf $TARGET
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/local/test.sh b/examples/v18_to_v2x/exp_baidu_en8k/local/test.sh
new file mode 100755
index 000000000..b5b68c599
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/local/test.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+model_type=$3
+
+# download language model
+bash local/download_lm_en.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+python3 -u ${BIN_DIR}/test.py \
+--device ${device} \
+--nproc 1 \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/path.sh b/examples/v18_to_v2x/exp_baidu_en8k/path.sh
new file mode 100644
index 000000000..080ab1f79
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/path.sh
@@ -0,0 +1,16 @@
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=deepspeech2
+export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
+echo "BIN_DIR "${BIN_DIR}
diff --git a/examples/v18_to_v2x/exp_baidu_en8k/run.sh b/examples/v18_to_v2x/exp_baidu_en8k/run.sh
new file mode 100755
index 000000000..c590312d1
--- /dev/null
+++ b/examples/v18_to_v2x/exp_baidu_en8k/run.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -e
+source path.sh
+
+stage=0
+stop_stage=100
+conf_path=conf/deepspeech2.yaml
+avg_num=1
+model_type=offline
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+v18_ckpt=baidu_en8k_v1.8
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    mkdir -p exp/${ckpt}/checkpoints
+    bash ./local/data.sh || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
+fi
+
diff --git a/examples/v18_to_v2x/exp_librispeech/.gitignore b/examples/v18_to_v2x/exp_librispeech/.gitignore
new file mode 100644
index 000000000..7024e0e95
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/.gitignore
@@ -0,0 +1,4 @@
+exp
+data
+*log
+tmp
diff --git a/examples/v18_to_v2x/exp_librispeech/conf/augmentation.json b/examples/v18_to_v2x/exp_librispeech/conf/augmentation.json
new file mode 100644
index 000000000..fe51488c7
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/conf/augmentation.json
@@ -0,0 +1 @@
+[]
diff --git a/examples/v18_to_v2x/exp_librispeech/conf/deepspeech2.yaml b/examples/v18_to_v2x/exp_librispeech/conf/deepspeech2.yaml
new file mode 100644
index 000000000..edef07972
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/conf/deepspeech2.yaml
@@ -0,0 +1,67 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.train
+  dev_manifest: data/manifest.dev
+  test_manifest: data/manifest.test-clean
+  min_input_len: 0.0
+  max_input_len: 1000.0 # second
+  min_output_len: 0.0
+  max_output_len: .inf
+  min_output_input_ratio: 0.00
+  max_output_input_ratio: .inf
+
+collator:
+  batch_size: 64 # one gpu
+  mean_std_filepath: data/mean_std.npz
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
+  specgram_type: linear
+  feat_dim:
+  delta_delta: False
+  stride_ms: 10.0
+  window_ms: 20.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
+  use_dB_normalization: True
+  target_dB: -20
+  dither: 1.0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 2
+
+model:
+  num_conv_layers: 2
+  num_rnn_layers: 3
+  rnn_layer_size: 2048
+  use_gru: False
+  share_rnn_weights: True
+  blank_id: 28
+
+training:
+  n_epoch: 80
+  accum_grad: 1
+  lr: 2e-3
+  lr_decay: 0.83
+  weight_decay: 1e-06
+  global_grad_clip: 3.0
+  log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5
+
+decoding:
+  batch_size: 32
+  error_rate_type: wer
+  decoding_method: ctc_beam_search
+  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
+  alpha: 2.5
+  beta: 0.3
+  beam_size: 500
+  cutoff_prob: 1.0
+  cutoff_top_n: 40
+  num_proc_bsearch: 8
diff --git a/examples/v18_to_v2x/exp_librispeech/local/data.sh b/examples/v18_to_v2x/exp_librispeech/local/data.sh
new file mode 100755
index 000000000..22a86bb2e
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/local/data.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+
+stage=-1
+stop_stage=100
+
+source ${MAIN_ROOT}/utils/parse_options.sh
+
+mkdir -p data
+TARGET_DIR=${MAIN_ROOT}/examples/dataset
+mkdir -p ${TARGET_DIR}
+
+
+bash local/download_model.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+tar xzvf librispeech_v1.8_to_v2.x.tar.gz
+mv librispeech_v1.8.pdparams exp/deepspeech2/checkpoints/
+mv README.md exp/deepspeech2/
+mv mean_std.npz data/
+mv vocab.txt data/
+rm librispeech_v1.8_to_v2.x.tar.gz -f
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/librispeech/librispeech.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/librispeech" \
+    --full_download="True"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare LibriSpeech failed. Terminated."
+        exit 1
+    fi
+
+    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+        mv data/manifest.${set} data/manifest.${set}.raw
+    done
+
+    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
+    for set in train-clean-100 train-clean-360 train-other-500; do
+        cat data/manifest.${set}.raw >> data/manifest.train.raw
+    done
+
+    for set in dev-clean dev-other; do
+        cat data/manifest.${set}.raw >> data/manifest.dev.raw
+    done
+
+    for set in test-clean test-other; do
+        cat data/manifest.${set}.raw >> data/manifest.test.raw
+    done
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    num_workers=$(nproc)
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --num_samples=2000 \
+    --specgram_type="linear" \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=20.0 \
+    --use_dB_normalization=True \
+    --num_workers=${num_workers} \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for set in train dev test dev-clean dev-other test-clean test-other; do
+    {
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "char" \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${set}.raw" \
+        --output_path="data/manifest.${set}"
+
+        if [ $? -ne 0 ]; then
+            echo "Format manifest.${set} failed. Terminated."
+            exit 1
+        fi
+    } &
+    done
+    wait
+fi
+
+echo "LibriSpeech data preparation done."
+exit 0
+
diff --git a/examples/v18_to_v2x/exp_librispeech/local/download_lm_en.sh b/examples/v18_to_v2x/exp_librispeech/local/download_lm_en.sh
new file mode 100755
index 000000000..dc1bdf665
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/local/download_lm_en.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+DIR=data/lm
+mkdir -p ${DIR}
+
+URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
+MD5="099a601759d467cd0a8523ff939819c5"
+TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
+
+echo "Download language model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Fail to download the language model!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_librispeech/local/download_model.sh b/examples/v18_to_v2x/exp_librispeech/local/download_model.sh
new file mode 100644
index 000000000..cc6a9ec75
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/local/download_model.sh
@@ -0,0 +1,19 @@
+#! /usr/bin/env bash
+
+. ${MAIN_ROOT}/utils/utility.sh
+
+URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
+MD5=7b0f582fe2f5a840b840e7ee52246bc5
+TARGET=./librispeech_v1.8_to_v2.x.tar.gz
+
+
+echo "Download LibriSpeech model ..."
+download $URL $MD5 $TARGET
+if [ $? -ne 0 ]; then
+    echo "Fail to download LibriSpeech model!"
+    exit 1
+fi
+tar -zxvf $TARGET
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_librispeech/local/test.sh b/examples/v18_to_v2x/exp_librispeech/local/test.sh
new file mode 100755
index 000000000..b5b68c599
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/local/test.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+model_type=$3
+
+# download language model
+bash local/download_lm_en.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+python3 -u ${BIN_DIR}/test.py \
+--device ${device} \
+--nproc 1 \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/v18_to_v2x/exp_librispeech/path.sh b/examples/v18_to_v2x/exp_librispeech/path.sh
new file mode 100644
index 000000000..080ab1f79
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/path.sh
@@ -0,0 +1,16 @@
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=deepspeech2
+export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin
+echo "BIN_DIR "${BIN_DIR}
diff --git a/examples/v18_to_v2x/exp_librispeech/run.sh b/examples/v18_to_v2x/exp_librispeech/run.sh
new file mode 100755
index 000000000..05706a428
--- /dev/null
+++ b/examples/v18_to_v2x/exp_librispeech/run.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+set -e
+source path.sh
+
+stage=0
+stop_stage=100
+conf_path=conf/deepspeech2.yaml
+avg_num=1
+model_type=offline
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+v18_ckpt=librispeech_v1.8
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    mkdir -p exp/${ckpt}/checkpoints
+    bash ./local/data.sh || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
+fi
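Note: all three recipes share the same two-stage run.sh (stage 0: data prep and model download; stage 1: evaluation of the converted v1.8 checkpoint). A typical end-to-end invocation, assuming MAIN_ROOT resolves and a GPU is visible, with stage selection following the Kaldi-style parse_options.sh convention used here:

    cd examples/v18_to_v2x/exp_librispeech
    bash run.sh                            # run data prep, then evaluation
    bash run.sh --stage 1 --stop_stage 1   # re-run only the evaluation stage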