fix for kaldi

pull/768/head
Hui Zhang 3 years ago
parent cd34e733a4
commit ab23eb5710

@ -407,42 +407,3 @@ class GLU(nn.Layer):
# Backfill ``paddle.nn.GLU`` with the local implementation when the installed
# paddle does not provide that attribute.
if not hasattr(paddle.nn, 'GLU'):
    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
    paddle.nn.GLU = GLU
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
    """Pads the input tensor boundaries with a constant value.

    For N-dimensional padding, use paddle.nn.functional.pad().
    """

    def __init__(self, padding: Union[tuple, list, int], value: float):
        """
        Args:
            padding (tuple|list|int): the size of the padding.
                If is int, uses the same padding in all boundaries.
                If a 4-tuple, uses
                (padding_left, padding_right, padding_top, padding_bottom)
            value (float): pad value
        """
        # The base Layer has to be initialised before the instance is used.
        super().__init__()
        # isinstance() accepts a type or a *tuple* of types; passing a list
        # raises TypeError, so the class tuple is spelled explicitly.
        if isinstance(padding, (tuple, list)):
            self.padding = padding
        else:
            self.padding = [padding] * 4
        self.value = value

    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
        """Pad ``xs`` with ``self.value`` using ``self.padding`` (NCHW layout)."""
        return nn.functional.pad(
            xs,
            self.padding,
            mode='constant',
            value=self.value,
            data_format='NCHW')
# Expose the local ConstantPad2d as ``paddle.nn.ConstantPad2d`` when paddle
# itself does not define that attribute.
if not hasattr(paddle.nn, 'ConstantPad2d'):
    logger.warn(
        "register user ConstantPad2d to paddle.nn, remove this when fixed!")
    paddle.nn.ConstantPad2d = ConstantPad2d
########### hack paddle.jit #############
# Alias ``paddle.jit.export`` to ``paddle.jit.to_static`` when the attribute
# is missing from the installed paddle.
if not hasattr(paddle.jit, 'export'):
    logger.warn("register user export to paddle.jit, remove this when fixed!")
    paddle.jit.export = paddle.jit.to_static

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
    """Build the tester, set it up and run the alignment in this process."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_align()
def main(config, args):
    """Entry point of the alignment script; delegates to ``main_sp``."""
    main_sp(config, args)
if __name__ == "__main__":
    parser = default_argument_parser()
    # The parser method is `add_argument` (singular); `add_arguments` is not
    # an argparse API and fails with AttributeError before parsing starts.
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    # Overrides are applied in order: config file first, then CLI options.
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    # Optionally dump the final (frozen) configuration to a file.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)

@ -0,0 +1,48 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
    """Build the tester, set it up and export the model in this process."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_export()
def main(config, args):
    """Entry point of the export script; delegates to ``main_sp``."""
    main_sp(config, args)
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    # Config file overrides come first, command-line options second.
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    # Optionally persist the final configuration.
    if args.dump_config:
        with open(args.dump_config, 'w') as fout:
            print(config, file=fout)

    main(config, args)

@ -0,0 +1,55 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
    """Build the tester, set it up and run the evaluation in this process."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_test()
def main(config, args):
    """Entry point of the evaluation script; delegates to ``main_sp``."""
    main_sp(config, args)
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    # Config file overrides come first, command-line options second.
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    # Optionally persist the final configuration.
    if args.dump_config:
        with open(args.dump_config, 'w') as fout:
            print(config, file=fout)

    # Setting for profiling: run main() under cProfile and dump the stats
    # to 'test.profile' in the current working directory.
    profiler = cProfile.Profile()
    profiler.runcall(main, config, args)
    profiler.dump_stats('test.profile')

@ -0,0 +1,69 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import cProfile
import os
from paddle import distributed as dist
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_alias = {
"u2": "deepspeech.exps.u2.model:U2Trainer",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
}
def main_sp(config, args):
    """Resolve the trainer class named by ``args.model_name`` and run it."""
    # `model_alias` maps short names to "module:Class" import paths.
    trainer_class = dynamic_import(args.model_name, model_alias)
    trainer = trainer_class(config, args)
    trainer.setup()
    trainer.run()
def main(config, args):
    """Run training, spawning one process per GPU when more than one is requested."""
    multi_gpu = args.device == "gpu" and args.nprocs > 1
    if not multi_gpu:
        main_sp(config, args)
        return
    dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2_kaldi',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    print_arguments(args, globals())

    # Start from an empty node that accepts new keys, then load the file
    # given by --config and any command-line overrides.
    config = CfgNode()
    config.set_new_allowed(True)
    config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)

    # Optionally persist the final configuration.
    if args.dump_config:
        with open(args.dump_config, 'w') as fout:
            print(config, file=fout)

    # Setting for profiling: run main() under cProfile and store the stats
    # next to the other training outputs.
    profiler = cProfile.Profile()
    profiler.runcall(main, config, args)
    profiler.dump_stats(os.path.join(args.output, 'train.profile'))

@ -0,0 +1,642 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from deepspeech.io.dataloader import BatchDataLoader
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
def get_cfg_defaults():
    """Return a fresh yacs CfgNode holding the model/training/decoding defaults.

    The returned node is a clone that allows new keys, so callers can merge
    extra settings without touching the defaults built here.
    """
    defaults = CfgNode()
    defaults.model = U2Model.params()
    defaults.training = U2Trainer.params()
    defaults.decoding = U2Tester.params()

    cfg = defaults.clone()
    cfg.set_new_allowed(True)
    return cfg
class U2Trainer(Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
    """Return the default training config; also merge it into ``config`` if given."""
    checkpoint_defaults = dict(
        kbest_n=50,
        latest_n=5, )
    training_defaults = dict(
        n_epoch=50,  # train epochs
        log_interval=100,  # steps
        accum_grad=1,  # accum grad by # steps
        checkpoint=checkpoint_defaults, )
    default = CfgNode(training_defaults)

    if config is not None:
        config.merge_from_other_cfg(default)
    return default
def __init__(self, config, args):
    """Forward ``config`` and ``args`` unchanged to the parent constructor."""
    super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
    """Run forward/backward on one batch and step the optimizer when due.

    Args:
        batch_index: zero-based index of the batch within the epoch.
        batch_data: tuple ``(utt, audio, audio_len, text, text_len)``.
        msg: log-line prefix; timing and loss values are appended to it.
    """
    conf = self.config.training
    tic = time.time()

    utt, audio, audio_len, text, text_len = batch_data
    loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                text_len)
    # loss div by `batch_size * accum_grad`
    loss /= conf.accum_grad
    loss.backward()
    layer_tools.print_grads(self.model, print_func=None)

    # Multiply back so the reported loss is not scaled by accum_grad.
    losses_np = {'loss': float(loss) * conf.accum_grad}
    for name, value in (('att_loss', attention_loss),
                        ('ctc_loss', ctc_loss)):
        if value:
            losses_np[name] = float(value)

    step_no = batch_index + 1
    # Parameters are only updated once every `accum_grad` batches.
    if step_no % conf.accum_grad == 0:
        self.optimizer.step()
        self.optimizer.clear_grad()
        self.lr_scheduler.step()
        self.iteration += 1

    iteration_time = time.time() - tic

    if step_no % conf.log_interval != 0:
        return

    msg += "train time: {:>.3f}s, ".format(iteration_time)
    msg += "batch size: {}, ".format(self.config.collator.batch_size)
    msg += "accum: {}, ".format(conf.accum_grad)
    msg += ', '.join('{}: {:>.6f}'.format(k, v)
                     for k, v in losses_np.items())
    logger.info(msg)

    # Only rank 0 writes scalars to the visualizer.
    if dist.get_rank() == 0 and self.visualizer:
        scalars = dict(losses_np)
        scalars.update({"lr": self.lr_scheduler()})
        self.visualizer.add_scalars("step", scalars, self.iteration - 1)
@paddle.no_grad()
def valid(self):
    """Evaluate the model on ``self.valid_loader``.

    Returns:
        tuple: ``(total_loss, num_seen_utts)`` where ``total_loss`` is the sum
        of per-batch losses weighted by batch size and ``num_seen_utts`` is
        the utterance counter (which starts at 1).
    """
    self.model.eval()
    logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")

    loss_history = defaultdict(list)
    # Note: the counter starts at 1, so the divisions below are never by zero.
    seen_utts = 1
    loss_sum = 0.0

    for step, batch in enumerate(self.valid_loader, start=1):
        utt, audio, audio_len, text, text_len = batch
        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                    text_len)

        # Non-finite losses are left out of the running totals.
        if paddle.isfinite(loss):
            batch_utts = batch[1].shape[0]
            seen_utts += batch_utts
            loss_sum += float(loss) * batch_utts
            loss_history['val_loss'].append(float(loss))

        if attention_loss:
            loss_history['val_att_loss'].append(float(attention_loss))
        if ctc_loss:
            loss_history['val_ctc_loss'].append(float(ctc_loss))

        if step % self.config.training.log_interval == 0:
            valid_dump = {k: np.mean(v) for k, v in loss_history.items()}
            valid_dump['val_history_loss'] = loss_sum / seen_utts

            # logging
            msg = f"Valid: Rank: {dist.get_rank()}, "
            msg += "epoch: {}, ".format(self.epoch)
            msg += "step: {}, ".format(self.iteration)
            msg += "batch: {}/{}, ".format(step, len(self.valid_loader))
            msg += ', '.join('{}: {:>.6f}'.format(k, v)
                             for k, v in valid_dump.items())
            logger.info(msg)

    logger.info('Rank {} Val info val_loss {}'.format(
        dist.get_rank(), loss_sum / seen_utts))
    return loss_sum, seen_utts
def train(self):
    """The training process control by step."""
    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    # script_model = paddle.jit.to_static(self.model)
    # script_model_path = str(self.checkpoint_dir / 'init')
    # paddle.jit.save(script_model, script_model_path)

    if self.resume_or_scratch():
        # save init model, i.e. 0 epoch
        self.save(tag='init')

    self.lr_scheduler.step(self.iteration)
    if self.parallel:
        self.train_loader.batch_sampler.set_epoch(self.epoch)

    logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
    while self.epoch < self.config.training.n_epoch:
        self.model.train()
        try:
            data_start_time = time.time()
            for batch_index, batch in enumerate(self.train_loader):
                # Time spent waiting for the loader to produce this batch.
                dataload_time = time.time() - data_start_time
                msg = "Train: Rank: {}, ".format(dist.get_rank())
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch : {}/{}, ".format(batch_index + 1,
                                                len(self.train_loader))
                msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                msg += "data time: {:>.3f}s, ".format(dataload_time)
                self.train_batch(batch_index, batch, msg)
                data_start_time = time.time()
        except Exception as e:
            logger.error(e)
            raise e

        total_loss, num_seen_utts = self.valid()
        multi_process = dist.get_world_size() > 1
        if multi_process:
            num_seen_utts = paddle.to_tensor(num_seen_utts)
            # the default operator in all_reduce function is sum.
            dist.all_reduce(num_seen_utts)
            total_loss = paddle.to_tensor(total_loss)
            dist.all_reduce(total_loss)
            cv_loss = float(total_loss / num_seen_utts)
        else:
            cv_loss = total_loss / num_seen_utts

        logger.info(
            'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
        if self.visualizer:
            epoch_scalars = {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}
            self.visualizer.add_scalars('epoch', epoch_scalars, self.epoch)

        self.save(tag=self.epoch, infos={'val_loss': cv_loss})
        self.new_epoch()
def setup_dataloader(self):
    """Create the train/valid/test/align ``BatchDataLoader`` instances.

    All four loaders share the same batching arguments. They differ only in
    the manifest, the train flag, the preprocessing config and the number of
    iterator processes.
    """
    config = self.config.clone()

    def build_loader(json_file,
                     train_mode=False,
                     preprocess_conf=None,
                     n_iter_processes=1):
        # Arguments common to every loader are spelled once here.
        return BatchDataLoader(
            json_file=json_file,
            train_mode=train_mode,
            sortagrad=False,
            batch_size=config.collator.batch_size,
            maxlen_in=float('inf'),
            maxlen_out=float('inf'),
            minibatches=0,
            mini_batch_size=1,
            batch_count='auto',
            batch_bins=0,
            batch_frames_in=0,
            batch_frames_out=0,
            batch_frames_inout=0,
            preprocess_conf=preprocess_conf,
            n_iter_processes=n_iter_processes,
            subsampling_factor=1,
            num_encs=1)

    # train/valid dataset, return token ids
    self.train_loader = build_loader(
        config.data.train_manifest,
        train_mode=True,
        preprocess_conf=config.collator.augmentation_config,
        n_iter_processes=config.collator.num_workers)
    self.valid_loader = build_loader(config.data.dev_manifest)

    # test dataset, return raw text
    self.test_loader = build_loader(config.data.test_manifest)
    # The align loader reads the test manifest as well.
    self.align_loader = build_loader(config.data.test_manifest)
    logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
    """Build the model, the LR scheduler and the optimizer from ``self.config``."""
    config = self.config

    # model: input/output dims are taken from the train loader
    model_conf = config.model
    model_conf.defrost()
    model_conf.input_dim = self.train_loader.feat_dim
    model_conf.output_dim = self.train_loader.vocab_size
    model_conf.freeze()
    model = U2Model.from_config(model_conf)

    if self.parallel:
        model = paddle.DataParallel(model)

    logger.info(f"{model}")
    layer_tools.print_params(model, logger.info)

    # lr
    scheduler_conf = config.scheduler_conf
    lr_scheduler = LRSchedulerFactory.from_args(
        config.scheduler,
        {
            "learning_rate": scheduler_conf.lr,
            "warmup_steps": scheduler_conf.warmup_steps,
            "gamma": scheduler_conf.lr_decay,
            "d_model": model_conf.encoder_conf.output_size,
            "verbose": False,
        })

    # opt
    optim_conf = config.optim_conf
    optimizer = OptimizerFactory.from_args(
        config.optim,
        {
            "grad_clip": optim_conf.global_grad_clip,
            "weight_decay": optim_conf.weight_decay,
            "learning_rate": lr_scheduler,
            "parameters": model.parameters(),
        })

    self.model = model
    self.lr_scheduler = lr_scheduler
    self.optimizer = optimizer
    logger.info("Setup model/optimizer/lr_scheduler!")
class U2Tester(U2Trainer):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
    """Return the default decoding config; also merge it into ``config`` if given."""
    decoding_defaults = {
        # Coef of LM for beam search.
        'alpha': 2.5,
        # Coef of WC for beam search.
        'beta': 0.3,
        # Cutoff probability for pruning.
        'cutoff_prob': 1.0,
        # Cutoff number for pruning.
        'cutoff_top_n': 40,
        # Filepath for language model.
        'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm',
        # Decoding method. Options: 'attention', 'ctc_greedy_search',
        # 'ctc_prefix_beam_search', 'attention_rescoring'
        'decoding_method': 'attention',
        # Error rate type for evaluation. Options `wer`, 'cer'
        'error_rate_type': 'wer',
        # # of CPUs for beam search.
        'num_proc_bsearch': 8,
        # Beam search width.
        'beam_size': 10,
        # decoding batch size
        'batch_size': 16,
        # ctc weight for attention rescoring decode mode.
        'ctc_weight': 0.0,
        # decoding chunk size. Defaults to -1.
        # <0: for decoding, use full chunk.
        # >0: for decoding, use fixed chunk size as set.
        # 0: used for training, it's prohibited here.
        'decoding_chunk_size': -1,
        # number of left chunks for decoding. Defaults to -1.
        'num_decoding_left_chunks': -1,
        # simulate streaming inference. Defaults to False.
        'simulate_streaming': False,
    }
    default = CfgNode(decoding_defaults)

    if config is not None:
        config.merge_from_other_cfg(default)
    return default
def __init__(self, config, args):
    """Forward ``config`` and ``args`` unchanged to the parent constructor."""
    super().__init__(config, args)
def ordid2token(self, texts, texts_len):
    """Turn sequences of ord() ids into strings via chr().

    Each sequence in ``texts`` is cut to the matching length in ``texts_len``
    before its ids are converted and joined.
    """

    def decode(ids, length):
        # `length` exposes .numpy(); .item() yields the plain Python number.
        valid = length.numpy().item()
        return ''.join(chr(i) for i in ids[:valid])

    return [decode(text, n) for text, n in zip(texts, texts_len)]
def compute_metrics(self,
                    utts,
                    audio,
                    audio_len,
                    texts,
                    texts_len,
                    fout=None):
    """Decode one batch and accumulate its error statistics.

    Args:
        utts: utterance ids, zipped with the transcripts for the output file.
        audio, audio_len: passed to ``self.model.decode``.
        texts, texts_len: reference ids/lengths, turned into strings with
            ``ordid2token``.
        fout: optional writable file; when truthy, one "utt result" line is
            written per example.

    Returns:
        dict: errors_sum, len_refs, num_ins, error_rate, error_rate_type,
        num_frames and decode_time for this batch.
    """
    cfg = self.config.decoding
    if cfg.error_rate_type == 'cer':
        errors_func = error_rate.char_errors
        error_rate_func = error_rate.cer
    else:
        errors_func = error_rate.word_errors
        error_rate_func = error_rate.wer

    errors_sum, len_refs, num_ins = 0.0, 0, 0

    tic = time.time()
    text_feature = self.test_loader.collate_fn.text_feature
    target_transcripts = self.ordid2token(texts, texts_len)
    result_transcripts = self.model.decode(
        audio,
        audio_len,
        text_feature=text_feature,
        decoding_method=cfg.decoding_method,
        lang_model_path=cfg.lang_model_path,
        beam_alpha=cfg.alpha,
        beam_beta=cfg.beta,
        beam_size=cfg.beam_size,
        cutoff_prob=cfg.cutoff_prob,
        cutoff_top_n=cfg.cutoff_top_n,
        num_processes=cfg.num_proc_bsearch,
        ctc_weight=cfg.ctc_weight,
        decoding_chunk_size=cfg.decoding_chunk_size,
        num_decoding_left_chunks=cfg.num_decoding_left_chunks,
        simulate_streaming=cfg.simulate_streaming)
    decode_time = time.time() - tic

    examples = zip(utts, target_transcripts, result_transcripts)
    for utt, target, result in examples:
        errors, len_ref = errors_func(target, result)
        errors_sum += errors
        len_refs += len_ref
        num_ins += 1
        if fout:
            fout.write(utt + " " + result + "\n")
        logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                    (target, result))
        logger.info("One example error rate [%s] = %f" %
                    (cfg.error_rate_type, error_rate_func(target, result)))

    return dict(
        errors_sum=errors_sum,
        len_refs=len_refs,
        num_ins=num_ins,  # num examples
        error_rate=errors_sum / len_refs,
        error_rate_type=cfg.error_rate_type,
        num_frames=audio_len.sum().numpy().item(),
        decode_time=decode_time)
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
    """Decode the test set, log running/final error rates and write results.

    Transcripts go to ``self.args.result_file``; a JSON summary line goes to
    the same path with the extension replaced by ``.err``.
    """
    assert self.args.result_file
    self.model.eval()
    logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

    stride_ms = self.test_loader.collate_fn.stride_ms
    error_rate_type = None
    errors_sum, len_refs, num_ins = 0.0, 0, 0
    num_frames = 0.0
    num_time = 0.0

    with open(self.args.result_file, 'w') as fout:
        for batch in self.test_loader:
            metrics = self.compute_metrics(*batch, fout=fout)
            num_frames += metrics['num_frames']
            num_time += metrics["decode_time"]
            errors_sum += metrics['errors_sum']
            len_refs += metrics['len_refs']
            num_ins += metrics['num_ins']
            error_rate_type = metrics['error_rate_type']
            # Running figures after every batch.
            rtf = num_time / (num_frames * stride_ms)
            logger.info(
                "RTF: %f, Error rate [%s] (%d/?) = %f" %
                (rtf, error_rate_type, num_ins, errors_sum / len_refs))

    rtf = num_time / (num_frames * stride_ms)
    final_error_rate = errors_sum / len_refs
    msg = "Test: "
    msg += "epoch: {}, ".format(self.epoch)
    msg += "step: {}, ".format(self.iteration)
    msg += "RTF: {}, ".format(rtf)
    msg += "Final error rate [%s] (%d/%d) = %f" % (
        error_rate_type, num_ins, num_ins, final_error_rate)
    logger.info(msg)

    # test meta results
    err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
    with open(err_meta_path, 'w') as meta_file:
        summary = {
            "epoch": self.epoch,
            "step": self.iteration,
            "rtf": rtf,
            error_rate_type: final_error_rate,
            "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
            "process_hour": num_time / 1000.0 / 3600.0,
            "num_examples": num_ins,
            "err_sum": errors_sum,
            "ref_len": len_refs,
            "decode_method": self.config.decoding.decoding_method,
        }
        meta_file.write(json.dumps(summary) + '\n')
def run_test(self):
    """Restore the checkpoint state, then run ``test``.

    A KeyboardInterrupt raised by ``test`` ends the process with status -1.
    """
    self.resume_or_scratch()
    try:
        self.test()
    except KeyboardInterrupt:
        # Ctrl-C during evaluation: exit instead of printing a traceback.
        sys.exit(-1)
@paddle.no_grad()
def align(self):
    """Force-align every example of ``self.align_loader`` with the CTC output.

    For each example the frame-level alignment is appended to
    ``self.args.result_file`` (which must end with ``.align``), and a
    ``<key>.tier`` plus ``<key>.TextGrid`` file is written to an ``align``
    directory next to the result file.

    Exits the process with status 1 when ``decoding.batch_size`` is > 1.
    """
    if self.config.decoding.batch_size > 1:
        logger.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    # xxx.align
    assert self.args.result_file and self.args.result_file.endswith(
        '.align')

    self.model.eval()
    logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")

    # The config section is named `collator` everywhere else in this class
    # (e.g. `config.collator.batch_size`); `collate` is not a config key.
    stride_ms = self.config.collator.stride_ms
    token_dict = self.align_loader.collate_fn.vocab_list

    # Loop-invariant values, computed once.
    subsample = utility.get_subsample(self.config)
    second_per_frame = 1. / (1000. / stride_ms)  # 25ms window, 10ms stride
    align_output_path = os.path.join(
        os.path.dirname(self.args.result_file), "align")
    # The per-utterance files below are opened for writing inside this
    # directory, so it has to exist first.
    os.makedirs(align_output_path, exist_ok=True)

    with open(self.args.result_file, 'w') as fout:
        # one example in batch
        for batch in self.align_loader:
            key, feat, feats_length, target, target_length = batch

            # 1. Encoder
            encoder_out, _ = self.model._forward_encoder(
                feat, feats_length)  # (B, maxlen, encoder_dim)
            ctc_probs = self.model.ctc.log_softmax(
                encoder_out)  # (1, maxlen, vocab_size)

            # 2. alignment
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = ctc_utils.forced_align(ctc_probs, target)
            # Extra positional args without placeholders in the message break
            # %-style log formatting, so the values are formatted up front.
            logger.info(f"align ids {key[0]} {alignment}")
            fout.write('{} {}\n'.format(key[0], alignment))

            # 3. gen praat
            # segment alignment
            align_segs = text_grid.segment_alignment(alignment)
            logger.info(f"align tokens {key[0]} {align_segs}")

            # IntervalTier, List["start end token\n"]
            tierformat = text_grid.align_to_tierformat(
                align_segs, subsample, token_dict)

            # write tier
            tier_path = os.path.join(align_output_path, key[0] + ".tier")
            with open(tier_path, 'w') as f:
                f.writelines(tierformat)

            # write textgrid
            textgrid_path = os.path.join(align_output_path,
                                         key[0] + ".TextGrid")
            second_per_example = (
                len(alignment) + 1) * subsample * second_per_frame
            text_grid.generate_textgrid(
                maxtime=second_per_example,
                intervals=tierformat,
                output=textgrid_path)
def run_align(self):
    """Restore the checkpoint state, then run ``align``.

    A KeyboardInterrupt raised by ``align`` ends the process with status -1.
    """
    self.resume_or_scratch()
    try:
        self.align()
    except KeyboardInterrupt:
        # Ctrl-C during alignment: exit instead of printing a traceback.
        sys.exit(-1)
def load_inferspec(self):
    """infer model and input spec.

    Returns:
        nn.Layer: inference model
        List[paddle.static.InputSpec]: input spec.
    """
    from deepspeech.models.u2 import U2InferModel

    infer_model = U2InferModel.from_pretrained(
        self.test_loader,
        self.config.model.clone(),
        self.args.checkpoint_path)

    feat_dim = self.test_loader.feat_dim
    # audio, [B,T,D]
    audio_spec = paddle.static.InputSpec(
        shape=[1, None, feat_dim], dtype='float32')
    # audio_length, [B]
    audio_len_spec = paddle.static.InputSpec(shape=[1], dtype='int64')
    return infer_model, [audio_spec, audio_len_spec]
def export(self):
    """Convert the inference model to a static graph and save it.

    The result is written to ``self.args.export_path``.
    """
    infer_model, input_spec = self.load_inferspec()
    assert isinstance(input_spec, list), type(input_spec)
    infer_model.eval()

    static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
    logger.info(f"Export code: {static_model.forward.code}")
    paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
    """Run ``export``; a KeyboardInterrupt ends the process with status -1."""
    try:
        self.export()
    except KeyboardInterrupt:
        # Ctrl-C during export: exit instead of printing a traceback.
        sys.exit(-1)
def setup(self):
    """Setup the experiment.

    Selects the device, then prepares the output dir, checkpointer,
    dataloaders and model, and resets the step/epoch counters.
    """
    paddle.set_device(self.args.device)

    self.setup_output_dir()
    self.setup_checkpointer()

    # The loaders are created before the model.
    self.setup_dataloader()
    self.setup_model()

    self.iteration = 0
    self.epoch = 0
def setup_output_dir(self):
    """Create a directory used for output.

    Uses ``args.output`` when it is set; otherwise uses the directory two
    levels above ``args.checkpoint_path``. The chosen path is stored in
    ``self.output_dir``.
    """
    # output dir
    if self.args.output:
        target = Path(self.args.output).expanduser()
    else:
        checkpoint = Path(self.args.checkpoint_path).expanduser()
        target = checkpoint.parent.parent
    target.mkdir(parents=True, exist_ok=True)

    self.output_dir = target

@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Text
import numpy as np
from paddle.io import DataLoader
from deepspeech.frontend.utility import read_manifest
@ -25,6 +31,18 @@ __all__ = ["BatchDataLoader"]
logger = Log(__name__).getlog()
def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
                            mode: Text="asr",
                            iaxis=0,
                            oaxis=0):
    """Read the feature dimension and vocabulary size from the first example.

    Args:
        data_json: manifest entries; only ``data_json[0]`` is inspected. Each
            entry is indexed as ``entry['input'][i]['shape'][1]`` and
            ``entry['output'][j]['shape'][1]``.
        mode: task mode; only ``'asr'`` is supported.
        iaxis: index into the entry's ``'input'`` list.
        oaxis: index into the entry's ``'output'`` list.

    Returns:
        tuple: ``(feat_dim, vocab_size)``.

    Raises:
        ValueError: if ``mode`` is not ``'asr'``.
    """
    if mode != 'asr':
        raise ValueError(f"{mode} mode not support!")

    first = data_json[0]
    # The input side is selected with `iaxis`; indexing it with `oaxis`
    # picks the wrong input whenever the two differ.
    feat_dim = first['input'][iaxis]['shape'][1]
    vocab_size = first['output'][oaxis]['shape'][1]
    return feat_dim, vocab_size
class BatchDataLoader():
def __init__(self,
json_file: str,
@ -62,6 +80,8 @@ class BatchDataLoader():
# read json data
self.data_json = read_manifest(json_file)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')
# make minibatch list (variable length)
self.minibaches = make_batchset(
@ -106,7 +126,7 @@ class BatchDataLoader():
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not use_sortagrad if train_mode else False,
shuffle=not self.use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )

@ -66,8 +66,9 @@ class LoadInputsAndTargets():
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
self.preprocessing = AugmentationPipeline(preprocess_conf)
logging.warning(
with open(preprocess_conf, 'r') as fin:
self.preprocessing = AugmentationPipeline(fin.read())
logger.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(
self.preprocessing))
@ -197,7 +198,7 @@ class LoadInputsAndTargets():
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs[0]):
logging.warning(
logger.warning(
"Target sequences include empty tokenid (batch {} -> {}).".
format(len(xs[0]), len(nonzero_sorted_idx)))

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False

@ -612,32 +612,32 @@ class U2BaseModel(nn.Layer):
best_index = i
return hyps[best_index][0]
#@jit.export
#@jit.to_static
def subsampling_rate(self) -> int:
    """Export interface for c++ call: the model's subsampling rate.

    The value is read from ``self.encoder.embed.subsampling_rate``.
    """
    embed = self.encoder.embed
    return embed.subsampling_rate
#@jit.export
#@jit.to_static
def right_context(self) -> int:
    """Export interface for c++ call: the model's right context.

    The value is read from ``self.encoder.embed.right_context``.
    """
    embed = self.encoder.embed
    return embed.right_context
#@jit.export
#@jit.to_static
def sos_symbol(self) -> int:
    """Export interface for c++ call: the sos symbol id, i.e. ``self.sos``."""
    sos_id = self.sos
    return sos_id
#@jit.export
#@jit.to_static
def eos_symbol(self) -> int:
    """Export interface for c++ call: the eos symbol id, i.e. ``self.eos``."""
    eos_id = self.eos
    return eos_id
@jit.export
@jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
@ -667,7 +667,7 @@ class U2BaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
# @jit.export([
# @jit.to_static([
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
@ -680,7 +680,7 @@ class U2BaseModel(nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.export
@jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,

@ -69,7 +69,7 @@ class ConvGLUBlock(nn.Layer):
dim=0)
self.dropout_residual = nn.Dropout(p=dropout)
self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
self.pad_left = nn.Pad2d((0, 0, kernel_size - 1, 0), 0)
layers = OrderedDict()
if bottlececk_dim == 0:

@ -15,6 +15,7 @@ from typing import Any
from typing import Dict
from typing import Text
import paddle
from paddle.optimizer import Optimizer
from paddle.regularizer import L2Decay
@ -43,6 +44,40 @@ def register_optimizer(cls):
return cls
@register_optimizer
class Noam(paddle.optimizer.Adam):
    """Adam optimizer registered under the name ``Noam``.

    It only changes the constructor defaults (see ``__init__``) and adds a
    readable ``__repr__``.

    See also: espnet/nets/pytorch_backend/transformer/optimizer.py
    """

    def __init__(self,
                 learning_rate=0,
                 beta1=0.9,
                 beta2=0.98,
                 epsilon=1e-9,
                 parameters=None,
                 weight_decay=None,
                 grad_clip=None,
                 lazy_mode=False,
                 multi_precision=False,
                 name=None):
        """Forward every argument unchanged to ``paddle.optimizer.Adam``."""
        super().__init__(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            lazy_mode=lazy_mode,
            multi_precision=multi_precision,
            name=name)

    def __repr__(self):
        """Return a one-line summary of the optimizer hyper-parameters."""
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"learning_rate: {self._learning_rate}, "
        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
        echo += f"epsilon: {self._epsilon}"
        # __repr__ must return a str; without this return it yields None
        # and repr()/str() raise TypeError.
        return echo
def dynamic_import_optimizer(module):
"""Import Optimizer class dynamically.
@ -69,15 +104,18 @@ class OptimizerFactory():
args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay(
args['weight_decay']) if "weight_decay" in args else None
module_class = dynamic_import_optimizer(name.lower())
if weight_decay:
logger.info(f'WeightDecay: {weight_decay}')
logger.info(f'<WeightDecay - {weight_decay}>')
if grad_clip:
logger.info(f'GradClip: {grad_clip}')
logger.info(
f"Optimizer: {module_class.__name__} {args['learning_rate']}")
logger.info(f'<GradClip - {grad_clip}>')
module_class = dynamic_import_optimizer(name.lower())
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
return instance_class(module_class, args)
opt = instance_class(module_class, args)
if "__repr__" in vars(opt):
logger.info(f"{opt}")
else:
logger.info(
f"<Optimizer {module_class.__module__}.{module_class.__name__}> LR: {args['learning_rate']}"
)
return opt

@ -41,22 +41,6 @@ def register_scheduler(cls):
return cls
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class
    """
    scheduler_cls = dynamic_import(module, SCHEDULER_DICT)
    # Anything that is not an LRScheduler subclass is rejected.
    is_scheduler = issubclass(scheduler_cls, LRScheduler)
    assert is_scheduler, f"{module} does not implement LRScheduler"
    return scheduler_cls
@register_scheduler
class WarmupLR(LRScheduler):
"""The WarmupLR scheduler
@ -102,6 +86,41 @@ class WarmupLR(LRScheduler):
self.step(epoch=step)
@register_scheduler
class ConstantLR(LRScheduler):
    """Scheduler whose ``get_lr`` always returns the base learning rate.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ConstantLR`` instance to schedule learning rate.
    """

    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
        """Forward all arguments unchanged to ``LRScheduler``."""
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``self.base_lr``."""
        return self.base_lr
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class
    """
    scheduler_cls = dynamic_import(module, SCHEDULER_DICT)
    # Anything that is not an LRScheduler subclass is rejected.
    is_scheduler = issubclass(scheduler_cls, LRScheduler)
    assert is_scheduler, f"{module} does not implement LRScheduler"
    return scheduler_cls
class LRSchedulerFactory():
@classmethod
def from_args(cls, name: str, args: Dict[Text, Any]):

@ -19,7 +19,7 @@ collator:
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
feat_dim: 83
delta_delta: False
dither: 1.0
target_sample_rate: 16000
@ -38,7 +38,7 @@ collator:
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: transformer
@ -74,20 +74,20 @@ model:
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
optim: adam
optim_conf:
global_grad_clip: 5.0
weight_decay: 1.0e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
lr: 0.004
warmup_steps: 25000
lr_decay: 1.0
decoding:
batch_size: 64

@ -20,6 +20,7 @@ echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--model-name u2_kaldi \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \

@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
MODEL=u2_kaldi
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin

@ -7,4 +7,3 @@
* https://github.com/NVIDIA/FasterTransformer.git
* https://github.com/idiap/fast-transformers

Loading…
Cancel
Save