diff --git a/.notebook/dataloader_with_tokens_tokenids.ipynb b/.notebook/dataloader_with_tokens_tokenids.ipynb
index 4f4c51b71..30d492eba 100644
--- a/.notebook/dataloader_with_tokens_tokenids.ipynb
+++ b/.notebook/dataloader_with_tokens_tokenids.ipynb
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "downtown-invalid",
+   "id": "medieval-monday",
    "metadata": {},
    "outputs": [
     {
@@ -213,27 +213,6 @@
     }
    ],
    "source": [
-    "# batch_reader = create_dataloader(\n",
-    "#     manifest_path=args.infer_manifest,\n",
-    "#     vocab_filepath=args.vocab_path,\n",
-    "#     mean_std_filepath=args.mean_std_path,\n",
-    "#     augmentation_config='{}',\n",
-    "#     #max_duration=float('inf'),\n",
-    "#     max_duration=27.0,\n",
-    "#     min_duration=0.0,\n",
-    "#     stride_ms=10.0,\n",
-    "#     window_ms=20.0,\n",
-    "#     max_freq=None,\n",
-    "#     specgram_type=args.specgram_type,\n",
-    "#     use_dB_normalization=True,\n",
-    "#     random_seed=0,\n",
-    "#     keep_transcription_text=True,\n",
-    "#     is_training=False,\n",
-    "#     batch_size=args.num_samples,\n",
-    "#     sortagrad=True,\n",
-    "#     shuffle_method=None,\n",
-    "#     dist=False)\n",
-    "\n",
     "from deepspeech.frontend.utility import read_manifest\n",
     "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
     "from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
@@ -375,7 +354,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 6,
    "id": "minus-modern",
    "metadata": {},
    "outputs": [
@@ -391,8 +370,6 @@
      "        [97, 37, 26, 79, 26, 1, 38, 82, 1, 58, 102, 1, 17, 79, 64, 87, 37, 26, 79, 1, 61, 64, 97]])\n",
      "test raw: W%\u001a\u0001Wa\u001a=W&\u001aR\n",
      "test raw: a%\u001aO\u001a\u0001&R\u0001:f\u0001\u0011O@W%\u001aO\u0001=@a\n",
-     "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
-     "       [163, 173, 184, 190, 203])\n",
      "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [12, 13, 11, 22, 23])\n",
      "audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
@@ -434,7 +411,9 @@
      "        ...,\n",
      "        [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
      "        [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
-     "        [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
+     "        [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n",
+     "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
+     "       [163, 173, 184, 190, 203])\n"
     ]
    }
   ],
@@ -472,16 +451,16 @@
     "    print('test:', text)\n",
     "    print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
     "    print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
-    "    print('audio len:', audio_len)\n",
     "    print('test len:', text_len)\n",
     "    print('audio:', audio)\n",
+    "    print('audio len:', audio_len)\n",
     "    break"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "chronic-diagram",
+   "id": "competitive-mounting",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 8fcc9fca6..abeabb76c 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 
-__all__ = ['U2TransformerModel', "U2ConformerModel"]
+__all__ = ["U2Model", "U2BaseModel"]
 
-class U2Model(nn.Module):
+class U2BaseModel(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
     def __init__(self,
@@ -635,28 +635,9 @@ class U2Model(nn.Module):
         return decoder_out
 
 
-class U2TransformerModel(U2Model):
+class U2Model(U2BaseModel):
     def __init__(self, configs: dict):
-        if configs['cmvn_file'] is not None:
-            mean, istd = load_cmvn(configs['cmvn_file'],
-                                   configs['cmvn_file_type'])
-            global_cmvn = GlobalCMVN(
-                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
-        else:
-            global_cmvn = None
-
-        input_dim = configs['input_dim']
-        vocab_size = configs['output_dim']
-
-        encoder_type = configs.get('encoder', 'transformer')
-        assert encoder_type == 'transformer'
-        encoder = TransformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-
-        decoder = TransformerDecoder(vocab_size,
-                                     encoder.output_size(),
-                                     **configs['decoder_conf'])
-        ctc = CTCDecoder(vocab_size, encoder.output_size())
+        vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
 
         super().__init__(
             vocab_size=vocab_size,
@@ -665,9 +646,19 @@ class U2TransformerModel(U2Model):
             ctc=ctc,
             **configs['model_conf'])
 
+    @classmethod
+    def _init_from_config(cls, configs: dict):
+        """Init the encoder, decoder and CTC submodules from the config.
 
-class U2ConformerModel(U2Model):
-    def __init__(self, configs: dict):
+        Args:
+            configs (dict): config dict.
+
+        Raises:
+            ValueError: raised when the encoder type is not supported.
+
+        Returns:
+            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
+        """
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -679,19 +670,46 @@ class U2ConformerModel(U2Model):
         input_dim = configs['input_dim']
         vocab_size = configs['output_dim']
 
-        encoder_type = configs.get('encoder', 'conformer')
-        assert encoder_type == 'conformer'
-        encoder = ConformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        encoder_type = configs.get('encoder', 'transformer')
+        logger.info(f"U2 Encoder type: {encoder_type}")
+        if encoder_type == 'transformer':
+            encoder = TransformerEncoder(
+                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        elif encoder_type == 'conformer':
+            encoder = ConformerEncoder(
+                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        else:
+            raise ValueError(f"unsupported encoder type: {encoder_type}")
 
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
                                      **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
+        return vocab_size, encoder, decoder, ctc
 
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder=encoder,
-            decoder=decoder,
-            ctc=ctc,
-            **configs['model_conf'])
+    @classmethod
+    def from_pretrained(cls, dataset, config, checkpoint_path):
+        """Build a U2Model from a pretrained checkpoint.
+
+        Args:
+            dataset (paddle.io.Dataset): not used here; kept for a uniform
+                from_pretrained interface across models.
+            config (yacs.config.CfgNode): model configs
+            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
+
+        Returns:
+            U2Model: The model built from the pretrained result.
+        """
+        vocab_size, encoder, decoder, ctc = cls._init_from_config(config)
+
+        model = cls(vocab_size=vocab_size,
+                    encoder=encoder,
+                    decoder=decoder,
+                    ctc=ctc,
+                    **config['model_conf'])
+
+        infos = checkpoint.load_parameters(
+            model, checkpoint_path=checkpoint_path)
+        logger.info(f"checkpoint info: {infos}")
+        layer_tools.summary(model)
+        return model
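Reviewer note: a minimal usage sketch of the unified class after this refactor. The config skeleton mirrors tests/u2_model_test.py; the dimension values are illustrative, and it assumes empty encoder/decoder/model sub-configs fall back to the submodules' defaults.

from yacs.config import CfgNode as CN

from deepspeech.models.u2 import U2Model

# Illustrative config; values are not tuned for any real dataset.
cfg = CN()
cfg.cmvn_file = None
cfg.cmvn_file_type = 'npz'
cfg.input_dim = 80           # e.g. 80-dim fbank features
cfg.output_dim = 200         # vocab size
cfg.encoder = 'conformer'    # 'transformer' also accepted; anything else raises ValueError
cfg.encoder_conf = CN(new_allowed=True)
cfg.decoder_conf = CN(new_allowed=True)
cfg.model_conf = CN(new_allowed=True)
cfg.freeze()

model = U2Model(cfg)         # dispatches on cfg.encoder via _init_from_config

# Restoring weights with the new classmethod (the checkpoint path is hypothetical):
# model = U2Model.from_pretrained(dataset, cfg, 'exp/conformer/checkpoints/avg_10')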
+ """ + + vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) + + model = cls(vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf']) + + infos = checkpoint.load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index aac997938..9582219fd 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -1,3 +1,38 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'bpe_unigram_200' + mean_std_filepath: data/mean_std.npz + augmentation_config: conf/augmentation.config + batch_size: 4 + max_input_len: 27.0 + min_input_len: 0.0 + max_output_len: .INF + min_output_len: 0.0 + max_output_input_ratio: .INF + min_output_input_ratio: 0.0 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 20.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + + # network architecture # encoder related encoder: conformer @@ -34,9 +69,6 @@ model_conf: lsm_weight: 0.1 # label smoothing option length_normalized_loss: false -# use raw_wav or kaldi feature -raw_wav: true - # feature extraction collate_conf: # waveform level config diff --git a/tests/u2_model_test.py b/tests/u2_model_test.py index 10f413e3a..a10cdc93b 100644 --- a/tests/u2_model_test.py +++ b/tests/u2_model_test.py @@ -18,8 +18,7 @@ import unittest import numpy as np from yacs.config import CfgNode as CN -from deepspeech.models.u2 import U2TransformerModel -from deepspeech.models.u2 import U2ConformerModel +from deepspeech.models.u2 import U2Model from deepspeech.utils.layer_tools import summary @@ -84,7 +83,7 @@ class TestU2Model(unittest.TestCase): cfg.cmvn_file = None cfg.cmvn_file_type = 'npz' cfg.freeze() - model = U2TransformerModel(cfg) + model = U2Model(cfg) summary(model, None) total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len, self.text, self.text_len) @@ -136,7 +135,7 @@ class TestU2Model(unittest.TestCase): cfg.cmvn_file = None cfg.cmvn_file_type = 'npz' cfg.freeze() - model = U2ConformerModel(cfg) + model = U2Model(cfg) summary(model, None) total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len, self.text, self.text_len)