diff --git a/deepspeech/exps/deepspeech2/bin/export.py b/deepspeech/exps/deepspeech2/bin/export.py index a1607d583..f8764fde3 100644 --- a/deepspeech/exps/deepspeech2/bin/export.py +++ b/deepspeech/exps/deepspeech2/bin/export.py @@ -30,11 +30,15 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument("--model_type") args = parser.parse_args() + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) print_arguments(args) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/bin/test.py b/deepspeech/exps/deepspeech2/bin/test.py index f4edf08a8..376e18e38 100644 --- a/deepspeech/exps/deepspeech2/bin/test.py +++ b/deepspeech/exps/deepspeech2/bin/test.py @@ -30,11 +30,15 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument("--model_type") args = parser.parse_args() print_arguments(args, globals()) + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 5e5c1e2a4..69ff043a0 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -35,11 +35,15 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument("--model_type") args = parser.parse_args() + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) print_arguments(args, globals()) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index a851e1f72..38b7d0e4d 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -18,21 +18,19 @@ from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.ds2 import DeepSpeech2Model - -_C = CfgNode() - -_C.data = ManifestDataset.params() - -_C.collator = SpeechCollator.params() - -_C.model = DeepSpeech2Model.params() - -_C.training = DeepSpeech2Trainer.params() - -_C.decoding = DeepSpeech2Tester.params() - - -def get_cfg_defaults(): +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline + + +def get_cfg_defaults(model_type='offline'): + _C = CfgNode() + _C.data = ManifestDataset.params() + _C.collator = SpeechCollator.params() + _C.training = DeepSpeech2Trainer.params() + _C.decoding = DeepSpeech2Tester.params() + if model_type == 'offline': + _C.model = DeepSpeech2Model.params() + else: + _C.model = DeepSpeech2ModelOnline.params() """Get a yacs CfgNode object with default values for my_project.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 2f84b686c..dfd812419 100644 --- 
a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Contains DeepSpeech2 model.""" +"""Contains DeepSpeech2 and DeepSpeech2Online model.""" import time from collections import defaultdict from pathlib import Path @@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2Model +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate @@ -119,16 +121,22 @@ class DeepSpeech2Trainer(Trainer): return total_loss, num_seen_utts def setup_model(self): - config = self.config - model = DeepSpeech2Model( - feat_size=self.train_loader.collate_fn.feature_size, - dict_size=self.train_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - + config = self.config.clone() + config.defrost() + assert (self.train_loader.collate_fn.feature_size == + self.test_loader.collate_fn.feature_size) + assert (self.train_loader.collate_fn.vocab_size == + self.test_loader.collate_fn.vocab_size) + config.model.feat_size = self.train_loader.collate_fn.feature_size + config.model.dict_size = self.train_loader.collate_fn.vocab_size + config.freeze() + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") if self.parallel: model = paddle.DataParallel(model) @@ -164,6 +172,9 @@ class DeepSpeech2Trainer(Trainer): config.data.manifest = config.data.dev_manifest dev_dataset = ManifestDataset.from_config(config) + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, @@ -187,6 +198,11 @@ class DeepSpeech2Trainer(Trainer): config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) + + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + collate_fn_test = SpeechCollator.from_config(config) + self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -198,7 +214,13 @@ class DeepSpeech2Trainer(Trainer): shuffle=False, drop_last=False, collate_fn=collate_fn_dev) - logger.info("Setup train/valid Dataloader!") + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test) + logger.info("Setup train/valid/test Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): @@ -329,19 +351,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): exit(-1) def export(self): - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) + if self.args.model_type == 'offline': + 
infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + elif self.args.model_type == 'online': + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + raise Exception("wrong model type") + infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size - static_model = paddle.jit.to_static( - infer_model, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, feat_dim], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) + static_model = infer_model.export() logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) @@ -365,46 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.iteration = 0 self.epoch = 0 - def setup_model(self): - config = self.config - model = DeepSpeech2Model( - feat_size=self.test_loader.collate_fn.feature_size, - dict_size=self.test_loader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - self.model = model - logger.info("Setup model!") - - def setup_dataloader(self): - config = self.config.clone() - config.defrost() - # return raw text - - config.data.manifest = config.data.test_manifest - # filter test examples, will cause less examples, but no mismatch with training - # and can use large batch size , save training time, so filter test egs now. - # config.data.min_input_len = 0.0 # second - # config.data.max_input_len = float('inf') # second - # config.data.min_output_len = 0.0 # tokens - # config.data.max_output_len = float('inf') # tokens - # config.data.min_output_input_ratio = 0.00 - # config.data.max_output_input_ratio = float('inf') - test_dataset = ManifestDataset.from_config(config) - - config.collator.keep_transcription_text = True - config.collator.augmentation_config = "" - # return text ord id - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=SpeechCollator.from_config(config)) - logger.info("Setup test Dataloader!") - def setup_output_dir(self): """Create a directory used for output. """ diff --git a/deepspeech/io/sampler.py b/deepspeech/io/sampler.py index 763a3781e..3b2ef757d 100644 --- a/deepspeech/io/sampler.py +++ b/deepspeech/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 8d737e800..1ffd797b4 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer): layer_tools.summary(model) return model + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. 
+ """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights) + return model + class DeepSpeech2InferModel(DeepSpeech2Model): def __init__(self, @@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model): eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) return probs + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/deepspeech/models/ds2_online/__init__.py b/deepspeech/models/ds2_online/__init__.py new file mode 100644 index 000000000..255000eeb --- /dev/null +++ b/deepspeech/models/ds2_online/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .deepspeech2 import DeepSpeech2InferModelOnline +from .deepspeech2 import DeepSpeech2ModelOnline + +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py new file mode 100644 index 000000000..1af69e28c --- /dev/null +++ b/deepspeech/models/ds2_online/conv.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +from paddle import nn + +from deepspeech.modules.embedding import PositionalEncoding +from deepspeech.modules.subsampling import Conv2dSubsampling4 + + +class Conv2dSubsampling4Online(Conv2dSubsampling4): + def __init__(self, idim: int, odim: int, dropout_rate: float): + super().__init__(idim, odim, dropout_rate, None) + self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim + self.receptive_field_length = 2 * ( + 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kernel_size_1 + + def forward(self, x: paddle.Tensor, + x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + # b, c, t, f = paddle.shape(x) # does not work under jit + x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) + x_len = ((x_len - 1) // 2 - 1) // 2 + return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py new file mode 100644 index 000000000..3083e4b2a --- /dev/null +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DeepSpeech2 ASR Online Model""" +from typing import Optional + +import paddle +import paddle.nn.functional as F +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + self.num_rnn_layers = num_rnn_layers + self.num_fc_layers = num_fc_layers + self.rnn_direction = rnn_direction + self.fc_layers_size_list = fc_layers_size_list + self.use_gru = use_gru + self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) + + self.output_dim = self.conv.output_dim + + i_size = self.conv.output_dim + self.rnn = nn.LayerList() + self.layernorm_list = nn.LayerList() + self.fc_layers_list = nn.LayerList() + if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': + layernorm_size = 2 * rnn_size + elif rnn_direction == 'forward': + layernorm_size = rnn_size + else: + raise Exception("Wrong rnn direction") + for i in range(0, num_rnn_layers): + if i == 0: + rnn_input_size = i_size + else: + rnn_input_size = layernorm_size + if use_gru == True: + self.rnn.append( + nn.GRU( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + else: + self.rnn.append( + nn.LSTM( + input_size=rnn_input_size, +
hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) + self.output_dim = layernorm_size + + fc_input_size = layernorm_size + for i in range(self.num_fc_layers): + self.fc_layers_list.append( + nn.Linear(fc_input_size, fc_layers_size_list[i])) + fc_input_size = fc_layers_size_list[i] + self.output_dim = fc_layers_size_list[i] + + @property + def output_size(self): + return self.output_dim + + def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): + """Compute Encoder outputs + + Args: + x (Tensor): [B, feature_size, D] + x_lens (Tensor): [B] + init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + Returns: + x (Tensor): encoder outputs, [B, size, D] + x_lens (Tensor): encoder length, [B] + final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + """ + if init_state_h_box is not None: + init_state_list = None + + if self.use_gru == True: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_list = init_state_h_list + else: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_c_list = paddle.split( + init_state_c_box, self.num_rnn_layers, axis=0) + init_state_list = [(init_state_h_list[i], init_state_c_list[i]) + for i in range(self.num_rnn_layers)] + else: + init_state_list = [None] * self.num_rnn_layers + + x, x_lens = self.conv(x, x_lens) + final_chunk_state_list = [] + for i in range(0, self.num_rnn_layers): + x, final_state = self.rnn[i](x, init_state_list[i], + x_lens) #[B, T, D] + final_chunk_state_list.append(final_state) + x = self.layernorm_list[i](x) + + for i in range(self.num_fc_layers): + x = self.fc_layers_list[i](x) + x = F.relu(x) + + if self.use_gru == True: + final_chunk_state_h_box = paddle.concat( + final_chunk_state_list, axis=0) + final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) + else: + final_chunk_state_h_list = [ + final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_chunk_state_c_list = [ + final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_chunk_state_h_box = paddle.concat( + final_chunk_state_h_list, axis=0) + final_chunk_state_c_box = paddle.concat( + final_chunk_state_c_list, axis=0) + + return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box + + def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): + """Compute Encoder outputs + + Args: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + decoder_chunk_size: The chunk size of decoder + Returns: + eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + """ + subsampling_rate = self.conv.subsampling_rate + receptive_field_length = self.conv.receptive_field_length + chunk_size = (decoder_chunk_size - 1 + ) * 
subsampling_rate + receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + max_len = x.shape[1] + assert (chunk_size <= max_len) + + eouts_chunk_list = [] + eouts_chunk_lens_list = [] + + padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride + padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) + padded_x = paddle.concat([x, padding], axis=1) + num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 + num_chunk = int(num_chunk) + chunk_state_h_box = None + chunk_state_c_box = None + final_state_h_box = None + final_state_c_box = None + for i in range(0, num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + + x_len_left = paddle.where(x_lens - i * chunk_stride < 0, + paddle.zeros_like(x_lens), + x_lens - i * chunk_stride) + x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size + x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, + x_len_left, x_chunk_len_tmp) + + eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( + x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) + + eouts_chunk_list.append(eouts_chunk) + eouts_chunk_lens_list.append(eouts_chunk_lens) + final_state_h_box = chunk_state_h_box + final_state_c_box = chunk_state_c_box + return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box + + +class DeepSpeech2ModelOnline(nn.Layer): + """The DeepSpeech2 network structure for online. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param num_fc_layers: Number of stacking FC layers. + :type num_fc_layers: int + :param fc_layers_size_list: The list of FC layer sizes. + :type fc_layers_size_list: [int,] + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=4, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+ )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): + super().__init__() + self.encoder = CRNNEncoder( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, + rnn_size=rnn_size, + use_gru=use_gru) + + self.decoder = CTCDecoder( + odim=dict_size, # is in vocab + enc_n_units=self.encoder.output_size, + blank_id=0, # first token is + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + + def forward(self, audio, audio_len, text, text_len): + """Compute Model loss + + Args: + audio (Tenosr): [B, T, D] + audio_len (Tensor): [B] + text (Tensor): [B, U] + text_len (Tensor): [B] + + Returns: + loss (Tenosr): [1] + """ + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) + loss = self.decoder(eouts, eouts_len, text, text_len) + return loss + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) + probs = self.decoder.softmax(eouts) + return self.decoder.decode_probs( + probs.numpy(), eouts_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + Parameters + ---------- + dataloader: paddle.io.DataLoader + + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + DeepSpeech2ModelOnline + The model built from pretrained result. + """ + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + rnn_direction=config.model.rnn_direction, + num_fc_layers=config.model.num_fc_layers, + fc_layers_size_list=config.model.fc_layers_size_list, + use_gru=config.model.use_gru) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2ModelOnline from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2ModelOnline + The model built from config. 
+ """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, + use_gru=config.use_gru) + return model + + +class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, + use_gru=use_gru) + + def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, + chunk_state_c_box): + eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( + audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) + probs_chunk = self.decoder.softmax(eouts_chunk) + return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, + self.encoder.feat_size], #[B, chunk_size, feat_dim] + dtype='float32'), + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') + ]) + return static_model diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py index 5aa2fd8ea..40fa7b00a 100644 --- a/deepspeech/modules/subsampling.py +++ b/deepspeech/modules/subsampling.py @@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling): dropout_rate: float, pos_enc_class: nn.Layer=PositionalEncoding): """Construct an Conv2dSubsampling4 object. - + Args: idim (int): Input dimension. odim (int): Output dimension. @@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling): dropout_rate: float, pos_enc_class: nn.Layer=PositionalEncoding): """Construct an Conv2dSubsampling6 object. - + Args: idim (int): Input dimension. odim (int): Output dimension. @@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling): dropout_rate: float, pos_enc_class: nn.Layer=PositionalEncoding): """Construct an Conv2dSubsampling8 object. - + Args: idim (int): Input dimension. odim (int): Output dimension. 
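Note on the chunk geometry used by CRNNEncoder.forward_chunk_by_chunk above -- a minimal sketch of the arithmetic, assuming the stock Conv2dSubsampling4 stride of 4, so that receptive_field_length = 2 * (3 - 1) + 3 = 7 input frames:

    # Illustrative only: how many input frames each streaming chunk consumes.
    subsampling_rate = 4               # assumed Conv2dSubsampling4 stride
    receptive_field_length = 7         # 2 * (3 - 1) + 3, set in Conv2dSubsampling4Online
    decoder_chunk_size = 8             # default in forward_chunk_by_chunk

    chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length  # 35 input frames
    chunk_stride = subsampling_rate * decoder_chunk_size                               # 32 input frames
    # Each chunk therefore yields 8 encoder frames, and consecutive chunks overlap by 3 input frames.
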
diff --git a/env.sh b/env.sh index c5acd0112..9d22259df 100644 --- a/env.sh +++ b/env.sh @@ -4,7 +4,7 @@ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 +export PYTHONIOENCODING=UTF-8 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml new file mode 100644 index 000000000..33030a523 --- /dev/null +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 32 # one gpu + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear #linear, mfcc, fbank + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + rnn_direction: forward # [forward, bidirect] + num_fc_layers: 1 + fc_layers_size_list: 512, + use_gru: True + +training: + n_epoch: 50 + lr: 2e-3 + lr_decay: 0.83 # 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: cer + decoding_method: ctc_beam_search + lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm + alpha: 1.9 + beta: 5.0 + beam_size: 300 + cutoff_prob: 0.99 + cutoff_top_n: 40 + num_proc_bsearch: 10 diff --git a/examples/aishell/s0/local/export.sh b/examples/aishell/s0/local/export.sh index f99a15bad..2e09e5f5e 100755 --- a/examples/aishell/s0/local/export.sh +++ b/examples/aishell/s0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,6 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 +model_type=$4 device=gpu if [ ${ngpu} == 0 ];then @@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" 
diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index fd9cb5661..9fd0bc8d5 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi @@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_ch.sh @@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \ --nproc 1 \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index f6bd2c983..c6a631800 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" exit -1 fi @@ -10,6 +10,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +model_type=$3 device=gpu if [ ${ngpu} == 0 ];then @@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index c9708dcc9..7cd63999c 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -7,6 +7,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=1 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -21,7 +22,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -31,10 +32,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml new file mode 100644 index 000000000..2e4aed40a --- /dev/null +++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev-clean + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 20 + 
mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 20.0 + delta_delta: False + dither: 1.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 2048 + rnn_direction: forward + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: False + +training: + n_epoch: 50 + lr: 1e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 5.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 128 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 1.9 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/librispeech/s0/local/export.sh b/examples/librispeech/s0/local/export.sh index f99a15bad..2e09e5f5e 100755 --- a/examples/librispeech/s0/local/export.sh +++ b/examples/librispeech/s0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,6 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 +model_type=$4 device=gpu if [ ${ngpu} == 0 ];then @@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/librispeech/s0/local/test.sh b/examples/librispeech/s0/local/test.sh index 16a5e9ef0..b5b68c599 100755 --- a/examples/librispeech/s0/local/test.sh +++ b/examples/librispeech/s0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi @@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_en.sh @@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \ --nproc 1 \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index f3eb98daf..039b9cea4 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" exit -1 fi @@ -10,6 +10,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +model_type=$3 device=gpu if [ ${ngpu} == 0 ];then @@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} if [ $? 
-ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index 6553e073d..c7902a56a 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -6,6 +6,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=30 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} @@ -19,7 +20,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -29,10 +30,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml new file mode 100644 index 000000000..333c2b9a9 --- /dev/null +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,69 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + min_input_len: 0.0 + max_input_len: 27.0 + min_output_len: 0.0 + max_output_len: 400.0 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + + +collator: + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + batch_size: 4 + +model: + num_conv_layers: 2 + num_rnn_layers: 4 + rnn_layer_size: 2048 + rnn_direction: forward + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: True + +training: + n_epoch: 10 + lr: 1e-5 + lr_decay: 1.0 + weight_decay: 1e-06 + global_grad_clip: 5.0 + log_interval: 1 + checkpoint: + kbest_n: 3 + latest_n: 2 + + +decoding: + batch_size: 128 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/tiny/s0/local/export.sh b/examples/tiny/s0/local/export.sh index f99a15bad..2e09e5f5e 100755 --- a/examples/tiny/s0/local/export.sh +++ b/examples/tiny/s0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,6 +11,7 @@ echo "using $ngpu gpus..." 
config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 +model_type=$4 device=gpu if [ ${ngpu} == 0 ];then @@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/tiny/s0/local/test.sh b/examples/tiny/s0/local/test.sh index 16a5e9ef0..b5b68c599 100755 --- a/examples/tiny/s0/local/test.sh +++ b/examples/tiny/s0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi @@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_en.sh @@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \ --nproc 1 \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index f6bd2c983..c6a631800 100755 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" exit -1 fi @@ -10,6 +10,7 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +model_type=$3 device=gpu if [ ${ngpu} == 0 ];then @@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in training!" 
diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index d7e153e8d..408b28fd0 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -7,6 +7,7 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=1 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -21,7 +22,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -31,10 +32,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py index 1776736f5..00df8195b 100644 --- a/tests/deepspeech2_model_test.py +++ b/tests/deepspeech2_model_test.py @@ -16,7 +16,7 @@ import unittest import numpy as np import paddle -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model class TestDeepSpeech2Model(unittest.TestCase): diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py new file mode 100644 index 000000000..87f048870 --- /dev/null +++ b/tests/deepspeech2_online_model_test.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np +import paddle + +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline + + +class TestDeepSpeech2ModelOnline(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + + self.batch_size = 2 + self.feat_dim = 161 + max_len = 210 + + # (B, T, D) + audio = np.random.randn(self.batch_size, max_len, self.feat_dim) + audio_len = np.random.randint(max_len, size=self.batch_size) + audio_len[-1] = max_len + # (B, U) + text = np.array([[1, 2], [1, 2]]) + text_len = np.array([2] * self.batch_size) + + self.audio = paddle.to_tensor(audio, dtype='float32') + self.audio_len = paddle.to_tensor(audio_len, dtype='int64') + self.text = paddle.to_tensor(text, dtype='int32') + self.text_len = paddle.to_tensor(text_len, dtype='int64') + + def test_ds2_1(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_2(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_3(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_4(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_5(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_6(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + rnn_direction='bidirect', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_7(self): + use_gru = False + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=1, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=use_gru) + model.eval() + paddle.device.set_device("cpu") + de_ch_size = 8 + + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( + self.audio, self.audio_len, de_ch_size) + eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) + eouts_lens_by_chk = 
paddle.add_n(eouts_lens_by_chk_list) + decode_max_len = eouts.shape[1] + eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + if use_gru == False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) + + def test_ds2_8(self): + use_gru = True + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=1, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=use_gru) + model.eval() + paddle.device.set_device("cpu") + de_ch_size = 8 + + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( + self.audio, self.audio_len, de_ch_size) + eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) + eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) + decode_max_len = eouts.shape[1] + eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + if use_gru == False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) + + +if __name__ == '__main__': + unittest.main()
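
For reference, a minimal sketch of how the streaming encoder is exercised end to end (hypothetical shapes; it mirrors the check in tests/deepspeech2_online_model_test.py that chunk-by-chunk and full-utterance encoder outputs agree):

    import numpy as np
    import paddle

    from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

    paddle.set_device('cpu')

    batch_size, max_len, feat_dim = 2, 210, 161   # hypothetical sizes
    audio = paddle.to_tensor(
        np.random.randn(batch_size, max_len, feat_dim), dtype='float32')
    audio_len = paddle.to_tensor([max_len, max_len], dtype='int64')

    model = DeepSpeech2ModelOnline(
        feat_size=feat_dim, dict_size=10, num_conv_layers=2, num_rnn_layers=1,
        rnn_size=1024, rnn_direction='forward', num_fc_layers=2,
        fc_layers_size_list=[512, 256], use_gru=False)
    model.eval()

    # Full-utterance pass vs. chunk-by-chunk pass through the same encoder.
    eouts, eouts_len, h_box, c_box = model.encoder(audio, audio_len)
    chunks, chunk_lens, h_chk, c_chk = model.encoder.forward_chunk_by_chunk(
        audio, audio_len, decoder_chunk_size=8)
    eouts_chk = paddle.concat(chunks, axis=1)[:, :eouts.shape[1], :]

    assert paddle.allclose(eouts_chk, eouts)   # streaming output matches offline output
    assert paddle.allclose(h_chk, h_box)
    assert paddle.allclose(c_chk, c_box)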