model init from config

pull/578/head
Hui Zhang 5 years ago
parent a7244593b9
commit 090e794723

@@ -3,7 +3,7 @@
 {
  "cell_type": "code",
  "execution_count": 1,
- "id": "downtown-invalid",
+ "id": "medieval-monday",
  "metadata": {},
  "outputs": [
   {
@@ -213,27 +213,6 @@
  }
 ],
 "source": [
- "# batch_reader = create_dataloader(\n",
- "# manifest_path=args.infer_manifest,\n",
- "# vocab_filepath=args.vocab_path,\n",
- "# mean_std_filepath=args.mean_std_path,\n",
- "# augmentation_config='{}',\n",
- "# #max_duration=float('inf'),\n",
- "# max_duration=27.0,\n",
- "# min_duration=0.0,\n",
- "# stride_ms=10.0,\n",
- "# window_ms=20.0,\n",
- "# max_freq=None,\n",
- "# specgram_type=args.specgram_type,\n",
- "# use_dB_normalization=True,\n",
- "# random_seed=0,\n",
- "# keep_transcription_text=True,\n",
- "# is_training=False,\n",
- "# batch_size=args.num_samples,\n",
- "# sortagrad=True,\n",
- "# shuffle_method=None,\n",
- "# dist=False)\n",
- "\n",
 "from deepspeech.frontend.utility import read_manifest\n",
 "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
 "from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
@@ -375,7 +354,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 9,
+ "execution_count": 6,
  "id": "minus-modern",
  "metadata": {},
  "outputs": [
@@ -391,8 +370,6 @@
 " [97, 37, 26, 79, 26, 1, 38, 82, 1, 58, 102, 1, 17, 79, 64, 87, 37, 26, 79, 1, 61, 64, 97]])\n",
 "test raw: W%\u001a\u0001Wa\u001a=W&\u001aR\n",
 "test raw: a%\u001aO\u001a\u0001&R\u0001:f\u0001\u0011O@W%\u001aO\u0001=@a\n",
- "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
- " [163, 173, 184, 190, 203])\n",
 "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
 " [12, 13, 11, 22, 23])\n",
 "audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
@@ -434,7 +411,9 @@
 " ...,\n",
 " [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
 " [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
- " [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
+ " [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n",
+ "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
+ " [163, 173, 184, 190, 203])\n"
 ]
 }
 ],
@@ -472,16 +451,16 @@
 " print('test:', text)\n",
 " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
 " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
- " print('audio len:', audio_len)\n",
 " print('test len:', text_len)\n",
 " print('audio:', audio)\n",
+ " print('audio len:', audio_len)\n",
 " break"
 ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "chronic-diagram",
+ "id": "competitive-mounting",
  "metadata": {},
  "outputs": [],
  "source": []

@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 __all__ = ['U2TransformerModel', "U2ConformerModel"]
 
 
-class U2Model(nn.Module):
+class U2BaseModel(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
     def __init__(self,
@@ -635,28 +635,9 @@ class U2Model(nn.Module):
         return decoder_out
 
 
-class U2TransformerModel(U2Model):
+class U2Model(U2BaseModel):
     def __init__(self, configs: dict):
-        if configs['cmvn_file'] is not None:
-            mean, istd = load_cmvn(configs['cmvn_file'],
-                                   configs['cmvn_file_type'])
-            global_cmvn = GlobalCMVN(
-                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
-        else:
-            global_cmvn = None
-
-        input_dim = configs['input_dim']
-        vocab_size = configs['output_dim']
-
-        encoder_type = configs.get('encoder', 'transformer')
-        assert encoder_type == 'transformer'
-        encoder = TransformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-
-        decoder = TransformerDecoder(vocab_size,
-                                     encoder.output_size(),
-                                     **configs['decoder_conf'])
-        ctc = CTCDecoder(vocab_size, encoder.output_size())
+        vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
 
         super().__init__(
             vocab_size=vocab_size,
@@ -665,9 +646,19 @@ class U2TransformerModel(U2Model):
             ctc=ctc,
             **configs['model_conf'])
 
-
-class U2ConformerModel(U2Model):
-    def __init__(self, configs: dict):
+    @classmethod
+    def _init_from_config(cls, configs: dict):
+        """Init submodules from the config.
+
+        Args:
+            configs (dict): config dict.
+
+        Raises:
+            ValueError: raised for an unsupported encoder type.
+
+        Returns:
+            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
+        """
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -679,19 +670,46 @@ class U2ConformerModel(U2Model):
         input_dim = configs['input_dim']
         vocab_size = configs['output_dim']
 
-        encoder_type = configs.get('encoder', 'conformer')
-        assert encoder_type == 'conformer'
-        encoder = ConformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        encoder_type = configs.get('encoder', 'transformer')
+        logger.info(f"U2 Encoder type: {encoder_type}")
+        if encoder_type == 'transformer':
+            encoder = TransformerEncoder(
+                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        elif encoder_type == 'conformer':
+            encoder = ConformerEncoder(
+                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        else:
+            raise ValueError(f"not supported encoder type: {encoder_type}")
 
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
                                      **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
+        return vocab_size, encoder, decoder, ctc
 
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder=encoder,
-            decoder=decoder,
-            ctc=ctc,
-            **configs['model_conf'])
+    @classmethod
+    def from_pretrained(cls, dataset, config, checkpoint_path):
+        """Build a U2Model from a pretrained checkpoint.
+
+        Args:
+            dataset (paddle.io.Dataset): not used when building the model.
+            config (yacs.config.CfgNode): model configs
+            checkpoint_path (Path or str): the path of the pretrained model checkpoint, without extension name
+
+        Returns:
+            U2Model: The model built from the pretrained result.
+        """
+        model = cls(config)
+
+        infos = checkpoint.load_parameters(
+            model, checkpoint_path=checkpoint_path)
+        logger.info(f"checkpoint info: {infos}")
+        layer_tools.summary(model)
+        return model
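
After this change there is a single public model class: U2Model._init_from_config builds the submodules from a plain config dict and dispatches on the 'encoder' key. Below is a minimal construction sketch; the keys are the ones read by _init_from_config above, but the empty *_conf dicts are placeholders rather than the repo's real defaults, so treat it as the shape of the call, not a working configuration.

    # Hedged sketch: placeholder sub-configs; real hyperparameters live in
    # the conf/*.yaml files (see the config diff further down).
    from deepspeech.models.u2 import U2Model

    configs = {
        'cmvn_file': None,         # skip the GlobalCMVN front-end
        'cmvn_file_type': 'npz',
        'input_dim': 80,           # feature dim, e.g. fbank feat_dim
        'output_dim': 200,         # vocab size
        'encoder': 'conformer',    # or 'transformer'; anything else raises ValueError
        'encoder_conf': {},        # placeholder encoder hyperparameters
        'decoder_conf': {},        # placeholder decoder hyperparameters
        'model_conf': {},          # placeholder: ctc_weight, lsm_weight, ...
    }
    model = U2Model(configs)       # dispatches on configs['encoder']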

@@ -1,3 +1,38 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.tiny
+  dev_manifest: data/manifest.tiny
+  test_manifest: data/manifest.tiny
+  vocab_filepath: data/vocab.txt
+  unit_type: 'spm'
+  spm_model_prefix: 'bpe_unigram_200'
+  mean_std_filepath: data/mean_std.npz
+  augmentation_config: conf/augmentation.config
+  batch_size: 4
+  max_input_len: 27.0
+  min_input_len: 0.0
+  max_output_len: .INF
+  min_output_len: 0.0
+  max_output_input_ratio: .INF
+  min_output_input_ratio: 0.0
+  raw_wav: True  # use raw_wav or kaldi feature
+  specgram_type: fbank  # linear, mfcc, fbank
+  feat_dim: 80
+  delta_delta: False
+  target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 20.0
+  use_dB_normalization: True
+  target_dB: -20
+  random_seed: 0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
 # network architecture
 # encoder related
 encoder: conformer
@@ -34,9 +69,6 @@ model_conf:
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
 
-# use raw_wav or kaldi feature
-raw_wav: true
-
 # feature extraction
 collate_conf:
     # waveform level config
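
The yaml.org link added at the top of the config is there for the .INF values: YAML 1.1's float type resolves .INF to infinity, so the loader needs no special-casing; note that the literal None values come back as the string 'None', not Python None. A small loading sketch, with a hypothetical file name:

    # Hedged sketch: plain PyYAML load of the new data section.
    import yaml

    with open('conf/conformer.yaml') as f:   # hypothetical path
        conf = yaml.safe_load(f)

    assert conf['data']['max_output_len'] == float('inf')   # .INF
    assert conf['data']['batch_size'] == 4
    print(type(conf['data']['max_freq']))                   # <class 'str'>: 'None'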

@@ -18,8 +18,7 @@ import unittest
 import numpy as np
 from yacs.config import CfgNode as CN
 
-from deepspeech.models.u2 import U2TransformerModel
-from deepspeech.models.u2 import U2ConformerModel
+from deepspeech.models.u2 import U2Model
 from deepspeech.utils.layer_tools import summary
@@ -84,7 +83,7 @@ class TestU2Model(unittest.TestCase):
         cfg.cmvn_file = None
         cfg.cmvn_file_type = 'npz'
         cfg.freeze()
-        model = U2TransformerModel(cfg)
+        model = U2Model(cfg)
         summary(model, None)
         total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
                                                      self.text, self.text_len)
@@ -136,7 +135,7 @@ class TestU2Model(unittest.TestCase):
         cfg.cmvn_file = None
         cfg.cmvn_file_type = 'npz'
         cfg.freeze()
-        model = U2ConformerModel(cfg)
+        model = U2Model(cfg)
         summary(model, None)
         total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
                                                      self.text, self.text_len)
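
Both test cases now exercise the same U2Model entry point, with the config's encoder field selecting the transformer or conformer path. Restoring trained weights goes through the new from_pretrained classmethod; a minimal sketch, assuming a config node like the tests build and a made-up checkpoint path:

    # Hedged sketch: `dataset` is accepted by the signature but not used
    # when building the model; the checkpoint path is hypothetical.
    model = U2Model.from_pretrained(
        dataset=None,
        config=cfg,                                  # CfgNode as in the tests above
        checkpoint_path='exp/u2/checkpoints/avg_1')  # hypothetical, no extension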
