diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 8ab9a26e..2b6e2433 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -184,40 +185,42 @@ class U2Trainer(Trainer): self.save(tag='init') self.lr_scheduler.step(self.iteration) - if self.parallel: + if self.parallel and hasattr(self.train_loader, 'batch_sampler'): self.train_loader.batch_sampler.set_epoch(self.epoch) logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 140ee947..095dfe34 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -190,35 +191,37 @@ class U2Trainer(Trainer): logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index ef5938b7..8dca1654 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2_st import U2STModel from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import bleu_score from deepspeech.utils import ctc_utils @@ -207,35 +208,37 @@ class U2STTrainer(Trainer): logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index a01766da..fd8f1547 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -115,7 +115,8 @@ class U2BaseModel(nn.Layer): ctc_weight: float=0.5, ignore_id: int=IGNORE_ID, lsm_weight: float=0.0, - length_normalized_loss: bool=False): + length_normalized_loss: bool=False, + **kwargs): assert 0.0 <= ctc_weight <= 1.0, ctc_weight super().__init__() @@ -661,9 +662,7 @@ class U2BaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - # @jit.to_static([ - # paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D] - # ]) + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -830,6 +829,7 @@ class U2Model(U2BaseModel): Returns: int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc """ + # cmvn if configs['cmvn_file'] is not None: mean, istd = load_cmvn(configs['cmvn_file'], configs['cmvn_file_type']) @@ -839,11 +839,13 @@ class U2Model(U2BaseModel): else: global_cmvn = None + # input & output dim input_dim = configs['input_dim'] vocab_size = configs['output_dim'] assert input_dim != 0, input_dim assert vocab_size != 0, vocab_size + # encoder encoder_type = configs.get('encoder', 'transformer') logger.info(f"U2 Encoder type: {encoder_type}") if encoder_type == 'transformer': @@ -855,17 +857,21 @@ class U2Model(U2BaseModel): else: raise ValueError(f"not support encoder type:{encoder_type}") + # decoder decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropoutrate'], reduction=True, # sum batch_average=True, # sum / batch_size - grad_norm_type='instance') + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, decoder, ctc diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py index 7dae3745..6737a549 100644 --- a/deepspeech/models/u2_st.py +++ b/deepspeech/models/u2_st.py @@ -413,26 +413,26 @@ class U2STBaseModel(nn.Layer): best_hyps = best_hyps[:, 1:] return best_hyps - @jit.to_static + # @jit.to_static def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - @jit.to_static + # @jit.to_static def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - @jit.to_static + # @jit.to_static def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - @jit.to_static + # @jit.to_static def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ @@ -468,7 +468,7 @@ class U2STBaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - @jit.to_static + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -643,14 +643,16 @@ class U2STModel(U2STBaseModel): decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropout_rate'], reduction=True, # sum batch_average=True, # sum / batch_size - grad_norm_type='instance') + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, (st_decoder, decoder, ctc) else: diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 399e84e2..023a1923 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -36,16 +36,16 @@ class CTCLoss(nn.Layer): f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") # instance for norm_by_times - # batchsize for norm_by_batchsize + # batch for norm_by_batchsize # frame for norm_by_total_logits_len - assert grad_norm_type in ('instance', 'batchsize', 'frame', None) + assert grad_norm_type in ('instance', 'batch', 'frame', None) self.norm_by_times = False self.norm_by_batchsize = False self.norm_by_total_logits_len = False logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") if grad_norm_type == 'instance': self.norm_by_times = True - if grad_norm_type == 'batchsize': + if grad_norm_type == 'batch': self.norm_by_times = True if grad_norm_type == 'frame': self.norm_by_total_logits_len = True diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index 6393197a..87b36aca 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -47,9 +47,10 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square = layers.reduce_sum(square) sum_square_list.append(sum_square) - # debug log - logger.debug( - f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") + # debug log, not dump all since slow down train process + if i < 10: + logger.debug( + f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") # all parameters have been filterd out if len(sum_square_list) == 0: @@ -75,9 +76,10 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) - # debug log - logger.debug( - f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" - ) + # debug log, not dump all since slow down train process + if i < 10: + logger.debug( + f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" + ) return params_and_grads diff --git a/doc/src/deepspeech_architecture.md b/doc/src/deepspeech_architecture.md index c4c102ba..ffa37aff 100644 --- a/doc/src/deepspeech_architecture.md +++ b/doc/src/deepspeech_architecture.md @@ -183,5 +183,3 @@ bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deeps cd examples/aishell/s0 bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml ``` - - diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 3e606788..6f8ae135 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 4b1430c5..a4248459 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -71,6 +71,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/chunk_conformer.yaml b/examples/librispeech/s1/conf/chunk_conformer.yaml index 0de1aefe..92db20f6 100644 --- a/examples/librispeech/s1/conf/chunk_conformer.yaml +++ b/examples/librispeech/s1/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index f782a037..e0bc3135 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 6d825f05..78be249c 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index bc2ec606..4aa7b915 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -33,7 +33,7 @@ collator: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 2 + num_workers: 0 # network architecture @@ -67,6 +67,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index c57946a6..17a9e28d 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -20,7 +20,7 @@ echo "using ${device}..." mkdir -p exp seed=10086 -if [ ${seed} != 0]; then +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi diff --git a/examples/librispeech/s2/conf/chunk_conformer.yaml b/examples/librispeech/s2/conf/chunk_conformer.yaml index 0de1aefe..92db20f6 100644 --- a/examples/librispeech/s2/conf/chunk_conformer.yaml +++ b/examples/librispeech/s2/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/chunk_transformer.yaml b/examples/librispeech/s2/conf/chunk_transformer.yaml index f782a037..e0bc3135 100644 --- a/examples/librispeech/s2/conf/chunk_transformer.yaml +++ b/examples/librispeech/s2/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/conformer.yaml b/examples/librispeech/s2/conf/conformer.yaml index 955b6108..9a727413 100644 --- a/examples/librispeech/s2/conf/conformer.yaml +++ b/examples/librispeech/s2/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index 4c60913e..edf5b81d 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -58,6 +58,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/ted_en_zh/t0/conf/transformer.yaml b/examples/ted_en_zh/t0/conf/transformer.yaml index 755e0446..1aad86d2 100644 --- a/examples/ted_en_zh/t0/conf/transformer.yaml +++ b/examples/ted_en_zh/t0/conf/transformer.yaml @@ -68,6 +68,8 @@ model: model_conf: asr_weight: 0.0 ctc_weight: 0.0 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml index bc1f8890..0144c40d 100644 --- a/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml +++ b/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml @@ -68,6 +68,8 @@ model: model_conf: asr_weight: 0.5 ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/timit/s1/conf/transformer.yaml b/examples/timit/s1/conf/transformer.yaml index eb191d0b..c3b51996 100644 --- a/examples/timit/s1/conf/transformer.yaml +++ b/examples/timit/s1/conf/transformer.yaml @@ -66,6 +66,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 96da3d9f..be2e82f9 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index 1adb91c4..93439a85 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index b40e77e3..9bb67c44 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index fd5adbde..5127e8c6 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -66,6 +66,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false