You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/modules/encoder.py

498 lines
22 KiB

E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
"""Encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.s2t.modules.activation import get_activation
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
from paddlespeech.s2t.modules.embedding import RelPositionalEncoding
from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer
from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer
from paddlespeech.s2t.modules.mask import add_optional_chunk_mask
from paddlespeech.s2t.modules.mask import make_non_pad_mask
from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling
from paddlespeech.s2t.utils.log import Log
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
logger = Log(__name__).getlog()
__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"]
class BaseEncoder(nn.Layer):
def __init__(self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="abs_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: paddle.nn.Layer=None,
use_dynamic_left_chunk: bool=False,
max_len: int=5000):
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
"""
Args:
input_size (int): input dim, d_feature
output_size (int): dimension of attention, d_model
attention_heads (int): the number of heads of multi head attention
linear_units (int): the hidden units number of position-wise feed
forward
num_blocks (int): the number of encoder blocks
dropout_rate (float): dropout rate
attention_dropout_rate (float): dropout rate in attention
positional_dropout_rate (float): dropout rate after adding
positional encoding
input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8]
3 years ago
pos_enc_layer_type (str): Encoder positional encoding layer type.
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after (bool): whether to concat attention layer's input
and output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
static_chunk_size (int): chunk size for static chunk training and
decoding
use_dynamic_chunk (bool): whether use dynamic chunk size for
training or not, You can only use fixed chunk(chunk_size > 0)
or dyanmic chunk size(use_dynamic_chunk = True)
global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
dynamic chunk training
"""
assert check_argument_types()
super().__init__()
self._output_size = output_size
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
3 years ago
elif pos_enc_layer_type == "no_pos":
3 years ago
pos_enc_class = NoPositionalEncoding
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
subsampling_class = LinearNoSubsampling
elif input_layer == "conv2d":
subsampling_class = Conv2dSubsampling4
elif input_layer == "conv2d6":
subsampling_class = Conv2dSubsampling6
elif input_layer == "conv2d8":
subsampling_class = Conv2dSubsampling8
else:
raise ValueError("unknown input_layer: " + input_layer)
self.global_cmvn = global_cmvn
self.embed = subsampling_class(
idim=input_size,
odim=output_size,
dropout_rate=dropout_rate,
pos_enc_class=pos_enc_class(
d_model=output_size,
dropout_rate=positional_dropout_rate,
max_len=max_len), )
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
self.normalize_before = normalize_before
self.after_norm = LayerNorm(output_size, epsilon=1e-12)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: paddle.Tensor,
xs_lens: paddle.Tensor,
decoding_chunk_size: int=0,
num_decoding_left_chunks: int=-1,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, L, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor, lens and mask
"""
masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
def forward_chunk(
self,
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
""" Forward just one chunk
Args:
xs (paddle.Tensor): chunk input, [B=1, T, D]
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
Returns:
paddle.Tensor: output of current input xs
paddle.Tensor: subsampling cache required for next chunk computation
List[paddle.Tensor]: encoder layers output cache required for next
chunk computation
List[paddle.Tensor]: conformer cnn cache
"""
assert xs.shape[0] == 1 # batch size must be one
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
# tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, _ = self.embed(
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
if subsampling_cache is not None:
cache_size = subsampling_cache.shape[1] #T
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.shape[1])
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = xs.shape[1]
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
else:
next_cache_start = xs.shape[1] - required_cache_size
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
r_elayers_output_cache = []
r_conformer_cnn_cache = []
for i, layer in enumerate(self.encoders):
attn_cache = None if elayers_output_cache is None else elayers_output_cache[
i]
cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
i]
xs, _, new_cnn_cache = layer(
xs,
masks,
pos_emb,
output_cache=attn_cache,
cnn_cache=cnn_cache)
r_elayers_output_cache.append(xs[:, next_cache_start:, :])
r_conformer_cnn_cache.append(new_cnn_cache)
if self.normalize_before:
xs = self.after_norm(xs)
return (xs[:, cache_size:, :], r_subsampling_cache,
r_elayers_output_cache, r_conformer_cnn_cache)
def forward_chunk_by_chunk(
self,
xs: paddle.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int=-1,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (paddle.Tensor): (1, max_len, dim)
chunk_size (int): decoding chunk size.
num_left_chunks (int): decoding with num left chunks.
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
# feature stride and window for `subsampling` module
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
subsampling_cache: Optional[paddle.Tensor] = None
elayers_output_cache: Optional[List[paddle.Tensor]] = None
conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
outputs = []
offset = 0
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, subsampling_cache, elayers_output_cache,
conformer_cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
outputs.append(y)
offset += y.shape[1]
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
ys = paddle.cat(outputs, 1)
# fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
masks = masks.unsqueeze(1)
return ys, masks
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
class TransformerEncoder(BaseEncoder):
"""Transformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="abs_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: nn.Layer=None,
use_dynamic_left_chunk: bool=False, ):
""" Construct TransformerEncoder
See Encoder for the meaning of each parameter.
"""
assert check_argument_types()
super().__init__(input_size, output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
attention_dropout_rate, input_layer,
pos_enc_layer_type, normalize_before, concat_after,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk)
3 years ago
self.encoders = nn.LayerList([
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
TransformerEncoderLayer(
size=output_size,
self_attn=MultiHeadedAttention(attention_heads, output_size,
attention_dropout_rate),
feed_forward=PositionwiseFeedForward(output_size, linear_units,
dropout_rate),
dropout_rate=dropout_rate,
normalize_before=normalize_before,
concat_after=concat_after) for _ in range(num_blocks)
])
def forward_one_step(
self,
xs: paddle.Tensor,
masks: paddle.Tensor,
cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encode input frame.
Args:
xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D)
3 years ago
masks (paddle.Tensor): Mask tensor. (B, T, T)
cache (List[paddle.Tensor]): List of cache tensors.
Returns:
paddle.Tensor: Output tensor.
paddle.Tensor: Mask tensor.
List[paddle.Tensor]: List of new cache tensors.
"""
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks, _ = e(xs, masks, output_cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="rel_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: nn.Layer=None,
use_dynamic_left_chunk: bool=False,
positionwise_conv_kernel_size: int=1,
macaron_style: bool=True,
selfattention_layer_type: str="rel_selfattn",
activation_type: str="swish",
use_cnn_module: bool=True,
cnn_module_kernel: int=15,
causal: bool=False,
cnn_module_norm: str="batch_norm",
max_len: int=5000):
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
positionwise_conv_kernel_size (int): Kernel size of positionwise
conv1d layer.
macaron_style (bool): Whether to use macaron style for
positionwise layer.
selfattention_layer_type (str): Encoder attention layer type,
the parameter has no effect now, it's just for configure
compatibility.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']
"""
assert check_argument_types()
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
super().__init__(input_size, output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
attention_dropout_rate, input_layer,
pos_enc_layer_type, normalize_before, concat_after,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, max_len)
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
activation = get_activation(activation_type)
# self-attention module definition
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (output_size, linear_units, dropout_rate,
activation)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
3 years ago
self.encoders = nn.LayerList([
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
4 years ago
ConformerEncoderLayer(
size=output_size,
self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
feed_forward=positionwise_layer(*positionwise_layer_args),
feed_forward_macaron=positionwise_layer(
*positionwise_layer_args) if macaron_style else None,
conv_module=convolution_layer(*convolution_layer_args)
if use_cnn_module else None,
dropout_rate=dropout_rate,
normalize_before=normalize_before,
concat_after=concat_after) for _ in range(num_blocks)
])