Merge pull request #751 from PaddlePaddle/opt

refactor optimizer and scheduler instance
Hui Zhang 3 years ago committed by GitHub
commit 25c07e3f3d

@@ -31,8 +31,8 @@ from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
-from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
-from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.optimizer import OptimizerFactory
+from deepspeech.training.scheduler import LRSchedulerFactory
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -312,30 +312,38 @@ class U2Trainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
-        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
-
-        if scheduler_type == 'expdecaylr':
-            lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-                learning_rate=optim_conf.lr,
-                gamma=scheduler_conf.lr_decay,
-                verbose=False)
-        elif scheduler_type == 'warmuplr':
-            lr_scheduler = WarmupLR(
-                learning_rate=optim_conf.lr,
-                warmup_steps=scheduler_conf.warmup_steps,
-                verbose=False)
-        else:
-            raise ValueError(f"Not support scheduler: {scheduler_type}")
-
-        if optim_type == 'adam':
-            optimizer = paddle.optimizer.Adam(
-                learning_rate=lr_scheduler,
-                parameters=model.parameters(),
-                weight_decay=weight_decay,
-                grad_clip=grad_clip)
-        else:
-            raise ValueError(f"Not support optim: {optim_type}")
+        scheduler_args = {
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
+                                                    scheduler_args)
+
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            train_config = config.training
+            optim_type = train_config.optim
+            optim_conf = train_config.optim_conf
+            scheduler_type = train_config.scheduler
+            scheduler_conf = train_config.scheduler_conf
+            return {
+                "grad_clip": train_config.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler
+                if lr_scheduler else optim_conf.lr,
+                "parameters": parameters,
+                "epsilon": 1e-9 if optim_type == 'noam' else None,
+                "beta1": 0.9 if optim_type == 'noam' else None,
+                "beta2": 0.98 if optim_type == 'noam' else None,
+            }
+
+        optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(optim_type, optim_args)

         self.model = model
         self.optimizer = optimizer
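
For reference, a minimal sketch of the new call path with illustrative values (WarmupLR's new **kwargs and the factories' argument filtering make one superset dict safe to pass to any scheduler or optimizer):

    scheduler_args = {"learning_rate": 0.002, "verbose": False,
                      "warmup_steps": 25000, "gamma": 0.9, "d_model": 256}
    lr_scheduler = LRSchedulerFactory.from_args("warmuplr", scheduler_args)
    # "warmuplr" swallows "gamma"/"d_model" via **kwargs; the "noam" alias uses "d_model".
    optimizer = OptimizerFactory.from_args("adam", {
        "learning_rate": lr_scheduler,
        "parameters": model.parameters(),
        "grad_clip": 5.0,       # wrapped into ClipGradByGlobalNormWithLog
        "weight_decay": 1e-6})  # wrapped into L2Decay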

@@ -345,9 +345,6 @@ class U2STTrainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
-        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
-
         if scheduler_type == 'expdecaylr':
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=optim_conf.lr,
@@ -367,6 +364,8 @@ class U2STTrainer(Trainer):
         else:
             raise ValueError(f"Not support scheduler: {scheduler_type}")

+        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
+        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
         if optim_type == 'adam':
             optimizer = paddle.optimizer.Adam(
                 learning_rate=lr_scheduler,

@@ -1,262 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Deepspeech2 ASR Model"""
-from typing import Optional
-
-import paddle
-from paddle import nn
-from yacs.config import CfgNode
-
-from deepspeech.modules.conv import ConvStack
-from deepspeech.modules.ctc import CTCDecoder
-from deepspeech.modules.rnn import RNNStack
-from deepspeech.utils import layer_tools
-from deepspeech.utils.checkpoint import Checkpoint
-from deepspeech.utils.log import Log
-
-logger = Log(__name__).getlog()
-
-__all__ = ['DeepSpeech2Model']
-
-
-class CRNNEncoder(nn.Layer):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=3,
-                 rnn_size=1024,
-                 use_gru=False,
-                 share_rnn_weights=True):
-        super().__init__()
-        self.rnn_size = rnn_size
-        self.feat_size = feat_size  # 161 for linear
-        self.dict_size = dict_size
-
-        self.conv = ConvStack(feat_size, num_conv_layers)
-
-        i_size = self.conv.output_height  # H after conv stack
-        self.rnn = RNNStack(
-            i_size=i_size,
-            h_size=rnn_size,
-            num_stacks=num_rnn_layers,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
-
-    @property
-    def output_size(self):
-        return self.rnn_size * 2
-
-    def forward(self, audio, audio_len):
-        """Compute Encoder outputs
-
-        Args:
-            audio (Tensor): [B, Tmax, D]
-            audio_len (Tensor): [B]
-        Returns:
-            x (Tensor): encoder outputs, [B, T, D]
-            x_lens (Tensor): encoder length, [B]
-        """
-        # [B, T, D] -> [B, D, T]
-        audio = audio.transpose([0, 2, 1])
-        # [B, D, T] -> [B, C=1, D, T]
-        x = audio.unsqueeze(1)
-        x_lens = audio_len
-
-        # convolution group
-        x, x_lens = self.conv(x, x_lens)
-
-        # convert data from convolution feature map to sequence of vectors
-        #B, C, D, T = paddle.shape(x)  # not work under jit
-        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
-        #x = x.reshape([B, T, C * D])  #[B, T, C*D] # not work under jit
-        x = x.reshape([0, 0, -1])  #[B, T, C*D]
-
-        # remove padding part
-        x, x_lens = self.rnn(x, x_lens)  #[B, T, D]
-        return x, x_lens
-
-
-class DeepSpeech2Model(nn.Layer):
-    """The DeepSpeech2 network structure.
-
-    :param audio_data: Audio spectrogram data layer.
-    :type audio_data: Variable
-    :param text_data: Transcription text data layer.
-    :type text_data: Variable
-    :param audio_len: Valid sequence length data layer.
-    :type audio_len: Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
-    :param dict_size: Dictionary size for tokenized transcription.
-    :type dict_size: int
-    :param num_conv_layers: Number of stacking convolution layers.
-    :type num_conv_layers: int
-    :param num_rnn_layers: Number of stacking RNN layers.
-    :type num_rnn_layers: int
-    :param rnn_size: RNN layer size (dimension of RNN cells).
-    :type rnn_size: int
-    :param use_gru: Use gru if set True. Use simple rnn if set False.
-    :type use_gru: bool
-    :param share_rnn_weights: Whether to share input-hidden weights between
-                              forward and backward direction RNNs.
-                              It is only available when use_gru=False.
-    :type share_rnn_weights: bool
-    :return: A tuple of an output unnormalized log probability layer (
-             before softmax) and a ctc cost layer.
-    :rtype: tuple of LayerOutput
-    """
-
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                num_conv_layers=2,  #Number of stacking convolution layers.
-                num_rnn_layers=3,  #Number of stacking RNN layers.
-                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
-                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
-            ))
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=3,
-                 rnn_size=1024,
-                 use_gru=False,
-                 share_rnn_weights=True):
-        super().__init__()
-        self.encoder = CRNNEncoder(
-            feat_size=feat_size,
-            dict_size=dict_size,
-            num_conv_layers=num_conv_layers,
-            num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_size,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
-        assert (self.encoder.output_size == rnn_size * 2)
-
-        self.decoder = CTCDecoder(
-            odim=dict_size,  # <blank> is in vocab
-            enc_n_units=self.encoder.output_size,
-            blank_id=0,  # first token is <blank>
-            dropout_rate=0.0,
-            reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
-
-    def forward(self, audio, audio_len, text, text_len):
-        """Compute Model loss
-
-        Args:
-            audio (Tensor): [B, T, D]
-            audio_len (Tensor): [B]
-            text (Tensor): [B, U]
-            text_len (Tensor): [B]
-        Returns:
-            loss (Tensor): [1]
-        """
-        eouts, eouts_len = self.encoder(audio, audio_len)
-        loss = self.decoder(eouts, eouts_len, text, text_len)
-        return loss
-
-    @paddle.no_grad()
-    def decode(self, audio, audio_len, vocab_list, decoding_method,
-               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
-               cutoff_top_n, num_processes):
-        # init once
-        # decoders only accept string encoded in utf-8
-        self.decoder.init_decode(
-            beam_alpha=beam_alpha,
-            beam_beta=beam_beta,
-            lang_model_path=lang_model_path,
-            vocab_list=vocab_list,
-            decoding_method=decoding_method)
-
-        eouts, eouts_len = self.encoder(audio, audio_len)
-        probs = self.decoder.softmax(eouts)
-        return self.decoder.decode_probs(
-            probs.numpy(), eouts_len, vocab_list, decoding_method,
-            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
-            cutoff_top_n, num_processes)
-
-    @classmethod
-    def from_pretrained(cls, dataloader, config, checkpoint_path):
-        """Build a DeepSpeech2Model model from a pretrained model.
-
-        Parameters
-        ----------
-        dataloader: paddle.io.DataLoader
-        config: yacs.config.CfgNode
-            model configs
-        checkpoint_path: Path or str
-            the path of pretrained model checkpoint, without extension name
-
-        Returns
-        -------
-        DeepSpeech2Model
-            The model built from pretrained result.
-        """
-        model = cls(feat_size=dataloader.collate_fn.feature_size,
-                    dict_size=dataloader.collate_fn.vocab_size,
-                    num_conv_layers=config.model.num_conv_layers,
-                    num_rnn_layers=config.model.num_rnn_layers,
-                    rnn_size=config.model.rnn_layer_size,
-                    use_gru=config.model.use_gru,
-                    share_rnn_weights=config.model.share_rnn_weights)
-        infos = Checkpoint().load_parameters(
-            model, checkpoint_path=checkpoint_path)
-        logger.info(f"checkpoint info: {infos}")
-        layer_tools.summary(model)
-        return model
-
-
-class DeepSpeech2InferModel(DeepSpeech2Model):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=3,
-                 rnn_size=1024,
-                 use_gru=False,
-                 share_rnn_weights=True):
-        super().__init__(
-            feat_size=feat_size,
-            dict_size=dict_size,
-            num_conv_layers=num_conv_layers,
-            num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_size,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
-
-    def forward(self, audio, audio_len):
-        """export model function
-
-        Args:
-            audio (Tensor): [B, T, D]
-            audio_len (Tensor): [B]
-
-        Returns:
-            probs: probs after softmax
-        """
-        eouts, eouts_len = self.encoder(audio, audio_len)
-        probs = self.decoder.softmax(eouts)
-        return probs

@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def __init__(self, clip_norm):
         super().__init__(clip_norm)

+    def __repr__(self):
+        return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
+
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
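
The added __repr__ mainly pays off in the factory log lines; a quick illustration (the clip value is hypothetical):

    clip = ClipGradByGlobalNormWithLog(5.0)
    logger.info(f"GradClip: {clip}")
    # -> GradClip: ClipGradByGlobalNormWithLog(global_clip_norm=5.0)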

@@ -0,0 +1,83 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+from typing import Dict
+from typing import Text
+
+from paddle.optimizer import Optimizer
+from paddle.regularizer import L2Decay
+
+from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.dynamic_import import instance_class
+from deepspeech.utils.log import Log
+
+__all__ = ["OptimizerFactory"]
+
+logger = Log(__name__).getlog()
+
+OPTIMIZER_DICT = {
+    "sgd": "paddle.optimizer:SGD",
+    "momentum": "paddle.optimizer:Momentum",
+    "adadelta": "paddle.optimizer:Adadelta",
+    "adam": "paddle.optimizer:Adam",
+    "adamw": "paddle.optimizer:AdamW",
+}
+
+
+def register_optimizer(cls):
+    """Register optimizer."""
+    alias = cls.__name__.lower()
+    OPTIMIZER_DICT[alias] = cls.__module__ + ":" + cls.__name__
+    return cls
+
+
+def dynamic_import_optimizer(module):
+    """Import Optimizer class dynamically.
+
+    Args:
+        module (str): module_name:class_name or alias in `OPTIMIZER_DICT`
+
+    Returns:
+        type: Optimizer class
+    """
+    module_class = dynamic_import(module, OPTIMIZER_DICT)
+    assert issubclass(module_class,
+                      Optimizer), f"{module} does not implement Optimizer"
+    return module_class
+
+
+class OptimizerFactory():
+    @classmethod
+    def from_args(cls, name: str, args: Dict[Text, Any]):
+        assert "parameters" in args, "parameters not in args."
+        assert "learning_rate" in args, "learning_rate not in args."
+        grad_clip = ClipGradByGlobalNormWithLog(
+            args['grad_clip']) if "grad_clip" in args else None
+        weight_decay = L2Decay(
+            args['weight_decay']) if "weight_decay" in args else None
+        module_class = dynamic_import_optimizer(name.lower())
+        if weight_decay:
+            logger.info(f'WeightDecay: {weight_decay}')
+        if grad_clip:
+            logger.info(f'GradClip: {grad_clip}')
+        logger.info(
+            f"Optimizer: {module_class.__name__} {args['learning_rate']}")
+        args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
+        return instance_class(module_class, args)
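
A short usage sketch of the factory (the model and all values are illustrative):

    model = paddle.nn.Linear(16, 16)  # stand-in model
    optimizer = OptimizerFactory.from_args("adam", {
        "learning_rate": 0.001,
        "parameters": model.parameters(),
        "grad_clip": 5.0,       # becomes ClipGradByGlobalNormWithLog(5.0)
        "weight_decay": 1e-6,   # becomes L2Decay(1e-6)
        "epsilon": None,        # None values are dropped by filter_valid_args
    })

Aliases resolve through OPTIMIZER_DICT, and @register_optimizer lets a project-defined Optimizer subclass join the same lookup table.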

@@ -11,18 +11,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+from typing import Dict
+from typing import Text
 from typing import Union

 from paddle.optimizer.lr import LRScheduler
 from typeguard import check_argument_types

+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.dynamic_import import instance_class
 from deepspeech.utils.log import Log

-__all__ = ["WarmupLR"]
+__all__ = ["WarmupLR", "LRSchedulerFactory"]

 logger = Log(__name__).getlog()

+SCHEDULER_DICT = {
+    "noam": "paddle.optimizer.lr:NoamDecay",
+    "expdecaylr": "paddle.optimizer.lr:ExponentialDecay",
+    "piecewisedecay": "paddle.optimizer.lr:PiecewiseDecay",
+}
+
+
+def register_scheduler(cls):
+    """Register scheduler."""
+    alias = cls.__name__.lower()
+    SCHEDULER_DICT[alias] = cls.__module__ + ":" + cls.__name__
+    return cls
+
+
+def dynamic_import_scheduler(module):
+    """Import Scheduler class dynamically.
+
+    Args:
+        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
+
+    Returns:
+        type: Scheduler class
+    """
+    module_class = dynamic_import(module, SCHEDULER_DICT)
+    assert issubclass(module_class,
+                      LRScheduler), f"{module} does not implement LRScheduler"
+    return module_class
+
+
+@register_scheduler
 class WarmupLR(LRScheduler):
     """The WarmupLR scheduler
     This scheduler is almost same as NoamLR Scheduler except for following
@@ -40,7 +75,8 @@ class WarmupLR(LRScheduler):
                  warmup_steps: Union[int, float]=25000,
                  learning_rate=1.0,
                  last_epoch=-1,
-                 verbose=False):
+                 verbose=False,
+                 **kwargs):
         assert check_argument_types()
         self.warmup_steps = warmup_steps
         super().__init__(learning_rate, last_epoch, verbose)
@@ -64,3 +100,10 @@ class WarmupLR(LRScheduler):
            None
         '''
         self.step(epoch=step)
+
+
+class LRSchedulerFactory():
+    @classmethod
+    def from_args(cls, name: str, args: Dict[Text, Any]):
+        module_class = dynamic_import_scheduler(name.lower())
+        return instance_class(module_class, args)
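
As with the optimizer factory, callers can hand over one over-complete args dict; a sketch with illustrative values:

    args = {"learning_rate": 0.002, "warmup_steps": 25000,
            "gamma": 0.9, "verbose": False}
    lr = LRSchedulerFactory.from_args("warmuplr", args)    # **kwargs absorbs "gamma"
    lr = LRSchedulerFactory.from_args("expdecaylr", args)  # unused keys are filtered out

This is also why WarmupLR now accepts **kwargs: it can be built from the same superset of scheduler arguments as the paddle built-ins.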

@@ -0,0 +1,67 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Text
+
+from deepspeech.utils.log import Log
+from deepspeech.utils.tensor_utils import has_tensor
+
+logger = Log(__name__).getlog()
+
+__all__ = ["dynamic_import", "instance_class"]
+
+
+def dynamic_import(import_path, alias=dict()):
+    """dynamic import module and class
+
+    :param str import_path: syntax 'module_name:class_name'
+        e.g., 'deepspeech.models.u2:U2Model'
+    :param dict alias: shortcut for registered class
+    :return: imported class
+    """
+    if import_path not in alias and ":" not in import_path:
+        raise ValueError("import_path should be one of {} or "
+                         'include ":", e.g. "deepspeech.models.u2:U2Model" : '
+                         "{}".format(set(alias), import_path))
+    if ":" not in import_path:
+        import_path = alias[import_path]
+
+    module_name, objname = import_path.split(":")
+    m = importlib.import_module(module_name)
+    return getattr(m, objname)
+
+
+def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
+    # keep only keys in `valid_keys`, and drop entries whose value is None
+    new_args = {
+        key: val
+        for key, val in args.items() if (key in valid_keys and val is not None)
+    }
+    return new_args
+
+
+def filter_out_tensor(args: Dict[Text, Any]):
+    return {key: val for key, val in args.items() if not has_tensor(val)}
+
+
+def instance_class(module_class, args: Dict[Text, Any]):
+    valid_keys = inspect.signature(module_class).parameters.keys()
+    new_args = filter_valid_args(args, valid_keys)
+    logger.info(
+        f"Instance: {module_class.__name__} {filter_out_tensor(new_args)}.")
+    return module_class(**new_args)
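
A compact sketch of the two entry points (the alias table and kwargs are illustrative):

    SGD = dynamic_import("sgd", {"sgd": "paddle.optimizer:SGD"})  # via alias
    Linear = dynamic_import("paddle.nn:Linear")                   # via module:class
    layer = instance_class(Linear, {"in_features": 8, "out_features": 2,
                                    "bogus_key": 1})  # unknown keys are filtered out

instance_class also logs the kwargs it kept, eliding tensor values via filter_out_tensor so the log lines stay readable.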

@@ -19,11 +19,25 @@ import paddle
 from deepspeech.utils.log import Log

-__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
+__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]

 logger = Log(__name__).getlog()

+
+def has_tensor(val):
+    if isinstance(val, (list, tuple)):
+        for item in val:
+            if has_tensor(item):
+                return True
+    elif isinstance(val, dict):
+        for k, v in val.items():
+            if has_tensor(v):
+                return True
+    else:
+        return paddle.is_tensor(val)
+
+
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
                  padding_value: float=0.0) -> paddle.Tensor:
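
has_tensor recurses through lists, tuples, and dicts; a quick illustration:

    assert has_tensor([1, {"x": paddle.zeros([2])}])    # nested tensor found
    assert not has_tensor({"lr": 0.001, "beta1": 0.9})  # falsy: no tensor anywhere

dynamic_import's filter_out_tensor relies on it to drop tensor-valued entries before logging instantiation args.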

@@ -16,7 +16,7 @@ collator:
   spm_model_prefix: 'data/bpe_unigram_5000'
   mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
-  batch_size: 64
+  batch_size: 32
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80
@@ -73,7 +73,7 @@ model:

 training:
   n_epoch: 120
-  accum_grad: 2
+  accum_grad: 4
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
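
Note that halving batch_size (64 -> 32) while doubling accum_grad (2 -> 4) keeps the effective batch size at 128 utterances per update; only the per-step memory footprint changes.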

@@ -3,18 +3,20 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
+  min_input_len: 0.5 # second
+  max_input_len: 20.0 # second
+  min_output_len: 0.0 # tokens
+  max_output_len: 400.0 # tokens
+  min_output_input_ratio: 0.05
+  max_output_input_ratio: 10.0
+
+collator:
+  mean_std_filepath: ""
   vocab_filepath: data/vocab.txt
   unit_type: 'spm'
   spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 4
-  min_input_len: 0.5
-  max_input_len: 20.0
-  min_output_len: 0.0
-  max_output_len: 400.0
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80

@@ -3,18 +3,20 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  vocab_filepath: data/vocab.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
-  augmentation_config: conf/augmentation.json
-  batch_size: 4
   min_input_len: 0.5 # second
   max_input_len: 20.0 # second
   min_output_len: 0.0 # tokens
   max_output_len: 400.0 # tokens
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
+
+collator:
+  mean_std_filepath: ""
+  vocab_filepath: data/vocab.txt
+  unit_type: 'spm'
+  spm_model_prefix: 'data/bpe_unigram_200'
+  augmentation_config: conf/augmentation.json
+  batch_size: 4
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80

@@ -3,18 +3,20 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
+  min_input_len: 0.5 # second
+  max_input_len: 20.0 # second
+  min_output_len: 0.0 # tokens
+  max_output_len: 400.0 # tokens
+  min_output_input_ratio: 0.05
+  max_output_input_ratio: 10.0
+
+collator:
+  mean_std_filepath: ""
   vocab_filepath: data/vocab.txt
   unit_type: 'spm'
   spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 4
-  min_input_len: 0.5
-  max_input_len: 20.0
-  min_output_len: 0.0
-  max_output_len: 400.0
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
   raw_wav: True # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80

@@ -11,30 +11,29 @@ data:
   max_output_input_ratio: 10.0

 collator:
-  vocab_filepath: data/vocab.txt
   mean_std_filepath: ""
-  augmentation_config: conf/augmentation.json
-  random_seed: 0
+  vocab_filepath: data/vocab.txt
   unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_200'
-  specgram_type: fbank
+  spm_model_prefix: 'data/bpe_unigram_202'
+  augmentation_config: conf/augmentation.json
+  batch_size: 4
+  raw_wav: True # use raw_wav or kaldi feature
+  specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80
   delta_delta: False
-  stride_ms: 10.0
-  window_ms: 20.0
-  n_fft: None
-  max_freq: None
+  dither: 1.0
   target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 25.0
   use_dB_normalization: True
   target_dB: -20
-  dither: 1.0
+  random_seed: 0
   keep_transcription_text: False
-  batch_size: 4
   sortagrad: True
   shuffle_method: batch_shuffle
-  num_workers: 0 #2
-  raw_wav: True # use raw_wav or kaldi feature
+  num_workers: 2

 # network architecture
 model:
