model init from config

pull/578/head
Hui Zhang 5 years ago
parent a7244593b9
commit 090e794723

@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "downtown-invalid",
"id": "medieval-monday",
"metadata": {},
"outputs": [
{
@@ -213,27 +213,6 @@
}
],
"source": [
"# batch_reader = create_dataloader(\n",
"# manifest_path=args.infer_manifest,\n",
"# vocab_filepath=args.vocab_path,\n",
"# mean_std_filepath=args.mean_std_path,\n",
"# augmentation_config='{}',\n",
"# #max_duration=float('inf'),\n",
"# max_duration=27.0,\n",
"# min_duration=0.0,\n",
"# stride_ms=10.0,\n",
"# window_ms=20.0,\n",
"# max_freq=None,\n",
"# specgram_type=args.specgram_type,\n",
"# use_dB_normalization=True,\n",
"# random_seed=0,\n",
"# keep_transcription_text=True,\n",
"# is_training=False,\n",
"# batch_size=args.num_samples,\n",
"# sortagrad=True,\n",
"# shuffle_method=None,\n",
"# dist=False)\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
@@ -375,7 +354,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "minus-modern",
"metadata": {},
"outputs": [
@@ -391,8 +370,6 @@
" [97, 37, 26, 79, 26, 1, 38, 82, 1, 58, 102, 1, 17, 79, 64, 87, 37, 26, 79, 1, 61, 64, 97]])\n",
"test raw: W%\u001a\u0001Wa\u001a=W&\u001aR\n",
"test raw: a%\u001aO\u001a\u0001&R\u0001:f\u0001\u0011O@W%\u001aO\u0001=@a\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [12, 13, 11, 22, 23])\n",
"audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
@@ -434,7 +411,9 @@
" ...,\n",
" [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
" [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n"
]
}
],
@@ -472,16 +451,16 @@
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" print('audio len:', audio_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "chronic-diagram",
"id": "competitive-mounting",
"metadata": {},
"outputs": [],
"source": []

@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
__all__ = ['U2TransformerModel', "U2ConformerModel"]

class U2Model(nn.Module):
class U2BaseModel(nn.Module):
    """CTC-Attention hybrid Encoder-Decoder model"""
    def __init__(self,
@@ -635,28 +635,9 @@ class U2Model(nn.Module):
        return decoder_out

class U2TransformerModel(U2Model):
class U2Model(U2BaseModel):
    def __init__(self, configs: dict):
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
            global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
        else:
            global_cmvn = None
        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        encoder_type = configs.get('encoder', 'transformer')
        assert encoder_type == 'transformer'
        encoder = TransformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        decoder = TransformerDecoder(vocab_size,
                                     encoder.output_size(),
                                     **configs['decoder_conf'])
        ctc = CTCDecoder(vocab_size, encoder.output_size())
        vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
        super().__init__(
            vocab_size=vocab_size,
@@ -665,9 +646,19 @@ class U2TransformerModel(U2Model):
            ctc=ctc,
            **configs['model_conf'])

class U2ConformerModel(U2Model):
    def __init__(self, configs: dict):
    @classmethod
    def _init_from_config(cls, configs: dict):
        """Init submodules from the config.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raised when the encoder type is not supported.

        Returns:
            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
@@ -679,19 +670,46 @@ class U2ConformerModel(U2Model):
        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        encoder_type = configs.get('encoder', 'conformer')
        assert encoder_type == 'conformer'
        encoder = ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        encoder_type = configs.get('encoder', 'transformer')
        logger.info(f"U2 Encoder type: {encoder_type}")
        if encoder_type == 'transformer':
            encoder = TransformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"unsupported encoder type: {encoder_type}")
        decoder = TransformerDecoder(vocab_size,
                                     encoder.output_size(),
                                     **configs['decoder_conf'])
        ctc = CTCDecoder(vocab_size, encoder.output_size())
        return vocab_size, encoder, decoder, ctc

        super().__init__(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            **configs['model_conf'])

    @classmethod
    def from_pretrained(cls, dataset, config, checkpoint_path):
        """Build a U2Model from a pretrained checkpoint.

        Args:
            dataset (paddle.io.Dataset): dataset (unused here).
            config (yacs.config.CfgNode): model configs
            checkpoint_path (Path or str): the path of the pretrained model checkpoint, without extension name

        Returns:
            U2Model: The model built from the pretrained checkpoint.
        """
        vocab_size, encoder, decoder, ctc = U2Model._init_from_config(config)
        model = cls(vocab_size=vocab_size,
                    encoder=encoder,
                    decoder=decoder,
                    ctc=ctc,
                    **config['model_conf'])

        infos = checkpoint.load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model
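With this change both encoder variants are built through one config-driven entry point. A minimal sketch of the new construction path follows; the values are illustrative, and the empty *_conf dicts assume the constructors supply usable defaults.

from deepspeech.models.u2 import U2Model

# illustrative config; the keys mirror those read by _init_from_config
configs = {
    'cmvn_file': None,       # no global CMVN stats
    'cmvn_file_type': 'npz',
    'input_dim': 80,         # e.g. 80-dim fbank features
    'output_dim': 201,       # vocabulary size (hypothetical)
    'encoder': 'conformer',  # or 'transformer'; anything else raises ValueError
    'encoder_conf': {},      # assumed: constructor defaults suffice
    'decoder_conf': {},
    'model_conf': {},
}
model = U2Model(configs)

# restoring weights goes through the new classmethod; the checkpoint path
# is given without extension name (the path below is hypothetical):
# model = U2Model.from_pretrained(dataset, config, 'exp/checkpoints/step_100')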

@@ -1,3 +1,38 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.tiny
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  vocab_filepath: data/vocab.txt
  unit_type: 'spm'
  spm_model_prefix: 'bpe_unigram_200'
  mean_std_filepath: data/mean_std.npz
  augmentation_config: conf/augmentation.config
  batch_size: 4
  max_input_len: 27.0
  min_input_len: 0.0
  max_output_len: .INF
  min_output_len: 0.0
  max_output_input_ratio: .INF
  min_output_input_ratio: 0.0
  raw_wav: True  # use raw_wav or kaldi feature
  specgram_type: fbank  # one of: linear, mfcc, fbank
  feat_dim: 80
  delta_delta: False
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 20.0
  use_dB_normalization: True
  target_dB: -20
  random_seed: 0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0
# network architecture
# encoder related
encoder: conformer
@@ -34,9 +69,6 @@ model_conf:
  lsm_weight: 0.1  # label smoothing option
  length_normalized_loss: false

# use raw_wav or kaldi feature
raw_wav: true

# feature extraction
collate_conf:
  # waveform level config
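A minimal sketch of consuming the new data block, using plain PyYAML rather than the project's config loader; the file name is a hypothetical path for this recipe.

import yaml

with open('conf/conformer.yaml') as f:  # hypothetical config path
    config = yaml.safe_load(f)

data_conf = config['data']
print(data_conf['specgram_type'])   # 'fbank'
print(data_conf['batch_size'])      # 4
print(data_conf['max_output_len'])  # inf (.INF parses as a YAML float)
print(config['encoder'])            # 'conformer'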

@@ -18,8 +18,7 @@ import unittest
import numpy as np
from yacs.config import CfgNode as CN
from deepspeech.models.u2 import U2TransformerModel
from deepspeech.models.u2 import U2ConformerModel
from deepspeech.models.u2 import U2Model
from deepspeech.utils.layer_tools import summary
@@ -84,7 +83,7 @@ class TestU2Model(unittest.TestCase):
        cfg.cmvn_file = None
        cfg.cmvn_file_type = 'npz'
        cfg.freeze()
        model = U2TransformerModel(cfg)
        model = U2Model(cfg)
        summary(model, None)
        total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
                                                     self.text, self.text_len)
@@ -136,7 +135,7 @@ class TestU2Model(unittest.TestCase):
        cfg.cmvn_file = None
        cfg.cmvn_file_type = 'npz'
        cfg.freeze()
        model = U2ConformerModel(cfg)
        model = U2Model(cfg)
        summary(model, None)
        total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
                                                     self.text, self.text_len)
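Since both tests now construct the same class, the two encoder variants can be exercised in one loop; a sketch, assuming a helper that returns the unfrozen CfgNode these tests build.

# a sketch: drive both encoder types through the unified U2Model
for encoder_type in ('transformer', 'conformer'):
    cfg = make_cfg()             # hypothetical helper building the test CfgNode
    cfg.encoder = encoder_type   # the only switch needed after this change
    cfg.cmvn_file = None
    cfg.cmvn_file_type = 'npz'
    cfg.freeze()
    model = U2Model(cfg)
    summary(model, None)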
