Merge pull request #879 from PaddlePaddle/debug

fix ctc loss not converging
Jackwaterveg 4 years ago committed by GitHub
commit ad6d09d622

@ -16,7 +16,7 @@
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.2.0rc
* paddlepaddle==2.1.2
Please see [install](docs/src/install.md).

@ -353,3 +353,31 @@ if not hasattr(paddle.Tensor, 'tolist'):
logger.debug(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
# hack loss
def ctc_loss(logits,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean',
norm_by_times=True):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
input_lengths, label_lengths)
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
return loss_out
logger.debug(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss
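For reference, a hedged sketch (not part of the patch) of how the overridden entry point behaves; the dummy shapes follow the padded warpctc convention above and are illustrative only:

import paddle
import paddle.nn.functional as F  # after the hack, F.ctc_loss is the function above

T, B, V = 50, 4, 30
logits = paddle.randn([T, B, V])                       # (Tmax, B, vocab), pre-softmax
labels = paddle.randint(1, V, [B, 10], dtype='int32')  # (B, Lmax); blank=0 excluded
input_lengths = paddle.full([B], T, dtype='int64')
label_lengths = paddle.full([B], 10, dtype='int64')

# Dispatches to warpctc with norm_by_times=True by default, then divides
# by label_lengths before taking the batch mean, as defined above.
loss = F.ctc_loss(logits, labels, input_lengths, label_lengths,
                  blank=0, reduction='mean')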

@ -28,7 +28,6 @@
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
const std::string kSPACE = "<space>";
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,

@ -15,10 +15,12 @@
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <string>
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const std::string kSPACE = "<space>";
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();

@ -26,7 +26,6 @@
#include "decoder_utils.h"
using namespace lm::ngram;
const std::string kSPACE = "<space>";
Scorer::Scorer(double alpha,
double beta,

@ -33,10 +33,9 @@ if __name__ == "__main__":
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument("--model_type")
parser.add_argument(
"--model_type", type=str, default='offline', help="offline/online")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args)
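The same entry-script cleanup recurs in the hunks below: a bare add_argument plus a post-parse None check is folded into an argparse default. A minimal sketch of the before/after:

import argparse

parser = argparse.ArgumentParser()

# before: default is None, patched up after parsing
#   parser.add_argument("--model_type")
#   if args.model_type is None: args.model_type = 'offline'

# after: the default lives in the parser and shows up in --help
parser.add_argument(
    "--model_type", type=str, default='offline', help="offline/online")
args = parser.parse_args([])
assert args.model_type == 'offline'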

@ -30,14 +30,13 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html

@ -36,11 +36,10 @@ if __name__ == "__main__":
#load jit model from
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument("--model_type")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html

@ -177,15 +177,14 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
parser.add_argument("--audio_file")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument("--audio_file", type=str, help='audio file path')
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
if not os.path.isfile(args.audio_file):
print("Please input the audio file path")
sys.exit(-1)

@ -35,10 +35,9 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args, globals())

@ -16,7 +16,6 @@ import os
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
@ -87,7 +86,8 @@ class DeepSpeech2Trainer(Trainer):
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
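A standalone sketch of the guard introduced above: no_sync exists only on DataParallel-wrapped models, so single-process runs must fall back to nullcontext while still accumulating gradients across micro-batches. Names below are illustrative.

from contextlib import nullcontext

def accum_context(model, parallel: bool):
    # DDP-wrapped models expose no_sync(); plain models do not.
    return model.no_sync if (hasattr(model, "no_sync") and parallel) else nullcontext

# inside the train step, for non-boundary micro-batches:
#   with accum_context(self.model, self.parallel)():
#       loss.backward()   # gradients accumulate locally, no all-reduce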
@ -385,13 +385,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(msg)
self.autolog.report()
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
@paddle.no_grad()
def export(self):
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
@ -408,40 +402,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args):
@ -645,38 +605,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_lens = output_lens_handle.copy_to_cpu()
return output_probs, output_lens
def run_test(self):
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(self.args.export_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
def setup_model(self):
super().setup_model()
speedyspeech_config = inference.Config(

@ -14,12 +14,10 @@
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
@ -44,8 +42,6 @@ from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
@ -106,7 +102,8 @@ class U2Trainer(Trainer):
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
# When using cpu w/o DDP, model does not have `no_sync`
context = self.model.no_sync if self.parallel else nullcontext
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
@ -550,78 +547,12 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
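The duplicated align() bodies removed here (and in the u2_kaldi and u2_st hunks below) are consolidated into ctc_utils.ctc_align; a hedged sketch of what that helper now covers, abridged from the deleted code (forced_align and the text_grid steps are the same utilities the old bodies used):

def ctc_align(model, loader, batch_size, stride_ms, token_dict, result_file):
    assert batch_size == 1 and result_file.endswith('.align')
    model.eval()
    for key, feat, feats_length, target, target_length in loader:
        encoder_out, _ = model._forward_encoder(feat, feats_length)
        ctc_probs = model.ctc.log_softmax(encoder_out).squeeze(0)
        alignment = forced_align(ctc_probs, target.squeeze(0))
        # then segment_alignment -> align_to_tierformat -> generate_textgrid,
        # writing one .tier and one .TextGrid per utterance, as before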
def load_inferspec(self):
"""infer model and input spec.
@ -643,6 +574,7 @@ class U2Tester(U2Trainer):
]
return infer_model, input_spec
@paddle.no_grad()
def export(self):
infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
@ -650,37 +582,3 @@ class U2Tester(U2Trainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -14,11 +14,9 @@
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
@ -39,8 +37,6 @@ from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
@ -105,7 +101,8 @@ class U2Trainer(Trainer):
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
@ -524,78 +521,12 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
@ -617,6 +548,7 @@ class U2Tester(U2Trainer):
]
return infer_model, input_spec
@paddle.no_grad()
def export(self):
infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
@ -625,43 +557,11 @@ class U2Tester(U2Trainer):
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup_dict(self):
# load dictionary for debug log
self.args.char_list = load_dict(self.args.dict_path,
"maskctc" in self.args.model_name)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
super().setup()
self.setup_dict()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -14,11 +14,9 @@
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
@ -42,8 +40,6 @@ from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
@ -110,7 +106,8 @@ class U2STTrainer(Trainer):
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
@ -544,78 +541,12 @@ class U2STTester(U2STTrainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
@ -637,6 +568,7 @@ class U2STTester(U2STTrainer):
]
return infer_model, input_spec
@paddle.no_grad()
def export(self):
infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
@ -644,37 +576,3 @@ class U2STTester(U2STTrainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -15,6 +15,7 @@
import json
from collections.abc import Sequence
from inspect import signature
from pprint import pformat
import numpy as np
@ -22,10 +23,10 @@ from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.log import Log
__all__ = ["AugmentationPipeline"]
logger = Log(__name__).getlog()
__all__ = ["AugmentationPipeline"]
import_alias = dict(
volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor",
shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
@ -111,6 +112,8 @@ class AugmentationPipeline():
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
'feature')
logger.info(
f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
@ -197,8 +200,10 @@ class AugmentationPipeline():
aug_confs = audio_confs
elif aug_type == 'feature':
aug_confs = feature_confs
else:
elif aug_type == 'all':
aug_confs = all_confs
else:
raise ValueError(f"Not support: {aug_type}")
augmentors = [
self._get_augmentor(config["type"], config["params"])

@ -29,10 +29,10 @@ class SpecAugmentor(AugmentorBase):
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533
"""
def __init__(self,
@ -61,7 +61,7 @@ class SpecAugmentor(AugmentorBase):
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
super().__init__()
@ -133,7 +133,7 @@ class SpecAugmentor(AugmentorBase):
return self._time_mask
def __repr__(self):
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"
def time_warp(self, x, mode='PIL'):
"""time warp for spec augment

@ -51,12 +51,14 @@ class SpeechFeaturizer():
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self.feature_size = self.audio_feature.feature_size
self.text_feature = TextFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
maskctc=maskctc)
self.vocab_size = self.text_feature.vocab_size
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.

@ -12,12 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
import sentencepiece as spm
from ..utility import BLANK
from ..utility import EOS
from ..utility import load_dict
from ..utility import MASKCTC
from ..utility import SOS
from ..utility import SPACE
from ..utility import UNK
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["TextFeaturizer"]
@ -76,7 +84,7 @@ class TextFeaturizer():
"""Convert text string to a list of token indices.
Args:
text (str): Text.
text (str): Text to process.
Returns:
List[int]: List of token indices.
@ -199,13 +207,24 @@ class TextFeaturizer():
"""Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
logger.info(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id

@ -49,7 +49,11 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
with open(dict_path, "r") as f:
dictionary = f.readlines()
char_list = [entry.strip().split(" ")[0] for entry in dictionary]
# first token is `<blank>`
# multi line: `<blank> 0\n`
# one line: `<blank>`
# space is replaced with <space>
char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
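A small worked example of the two unit-file layouts the new comment describes (ids illustrative); note that entry[:-1] trims only the trailing newline, where strip() would also eat significant leading or trailing spaces in a token column:

lines = ["<blank> 0\n", "<unk> 1\n", "<space> 2\n", "a 3\n"]   # multi-column
# or one token per line: ["<blank>\n", "<unk>\n", "<space>\n", "a\n"]
char_list = [entry[:-1].split(" ")[0] for entry in lines]
assert char_list == ["<blank>", "<unk>", "<space>", "a"]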

@ -41,6 +41,13 @@ def conv_output_size(I, F, P, S):
return (I - F + 2 * P - S) // S
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
@ -106,9 +113,10 @@ class ConvBn(nn.Layer):
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
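The masking workaround above in isolation, as a hedged sketch: the boolean mask is cast to the feature dtype because elementwise bool-by-float multiply was unsupported in this Paddle version (see the linked PR on type promotion).

import paddle

x = paddle.randn([2, 1, 8, 6])                       # [B, C, D, T]
masks = paddle.to_tensor([[1, 1, 1, 1, 0, 0],
                          [1, 1, 1, 0, 0, 0]]).astype(paddle.bool)  # [B, T]
masks = masks.unsqueeze(1).unsqueeze(1)              # [B, 1, 1, T]
x = x.multiply(masks.astype(x.dtype))                # zero out padded frames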

@ -218,14 +218,18 @@ class DeepSpeech2Model(nn.Layer):
DeepSpeech2Model
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id)
model = cls(
#feat_size=dataloader.collate_fn.feature_size,
feat_size=dataloader.dataset.feature_size,
#dict_size=dataloader.collate_fn.vocab_size,
dict_size=dataloader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@ -244,36 +248,22 @@ class DeepSpeech2Model(nn.Layer):
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, audio, audio_len):
"""export model function

@ -308,8 +308,8 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len

@ -255,22 +255,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
blank_id=0, # index of blank in vocab.txt
))
ctc_grad_norm_type='instance', ))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
def __init__(
self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0,
ctc_grad_norm_type='instance', ):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
@ -290,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type='instance')
grad_norm_type=ctc_grad_norm_type)
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
@ -348,16 +350,18 @@ class DeepSpeech2ModelOnline(nn.Layer):
DeepSpeech2ModelOnline
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id)
model = cls(
feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@ -376,42 +380,24 @@ class DeepSpeech2ModelOnline(nn.Layer):
DeepSpeech2ModelOnline
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id)
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru,
blank_id=blank_id)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):

@ -164,7 +164,10 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att = None
@ -319,7 +322,8 @@ class U2BaseModel(nn.Layer):
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
@ -405,7 +409,9 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.shape[1]
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
# (TODO Hui Zhang): bool no support reduce_sum
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
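The same dtype limitation shows up for reductions; a minimal sketch of the cast-before-sum idiom used in both hunks above (cast(paddle.int64) and astype(paddle.int) are the same workaround):

import paddle

mask = paddle.ones([4, 1, 20], dtype=paddle.bool)    # [B, 1, T]
lens = mask.squeeze(1).cast(paddle.int64).sum(1)     # [B] valid-frame counts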

@ -165,7 +165,10 @@ class U2STBaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch
start = time.time()
@ -362,7 +365,8 @@ class U2STBaseModel(nn.Layer):
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step

@ -1,170 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from paddle.nn import functional as F
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
# When Pstart == Pend == 0
# O = (I - F - S) // S
# https://iq.opengenus.org/output-size-of-convolution/
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
# Output width = (Input width + padding width right + padding width left - kernel width) / (stride width) + 1
return (I - F + 2 * P - S) // S
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type kernel_size: int|tuple|list
:param num_channels_in: Number of input channels.
:type num_channels_in: int
:param num_channels_out: Number of output channels.
:type num_channels_out: int
:param stride: The x dimension of the stride. Or input a tuple for two
image dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
:type padding: int|tuple|list
:param act: Activation type, relu|brelu
:type act: string
:return: Batch norm layer after convolution layer.
:rtype: Variable
"""
def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
padding, act):
super().__init__()
assert len(kernel_size) == 2
assert len(stride) == 2
assert len(padding) == 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.conv = nn.Conv2D(
num_channels_in,
num_channels_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_attr=None,
bias_attr=False,
data_format='NCHW')
self.bn = nn.BatchNorm2D(
num_channels_out,
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.act = F.relu if act == 'relu' else brelu
def forward(self, x, x_len):
"""
x(Tensor): audio, shape [B, C, D, T]
"""
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
) // self.stride[1] + 1
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len
class ConvStack(nn.Layer):
"""Convolution group with stacked convolution layers.
:param feat_size: audio feature dim.
:type feat_size: int
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
"""
def __init__(self, feat_size, num_stacks):
super().__init__()
self.feat_size = feat_size # D
self.num_stacks = num_stacks
self.conv_in = ConvBn(
num_channels_in=1,
num_channels_out=32,
kernel_size=(41, 11), #[D, T]
stride=(2, 3),
padding=(20, 5),
act='brelu')
out_channel = 32
convs = [
ConvBn(
num_channels_in=32,
num_channels_out=out_channel,
kernel_size=(21, 11),
stride=(2, 1),
padding=(10, 5),
act='brelu') for i in range(num_stacks - 1)
]
self.conv_stack = nn.LayerList(convs)
# conv output feat_dim
output_height = (feat_size - 1) // 2 + 1
for i in range(self.num_stacks - 1):
output_height = (output_height - 1) // 2 + 1
self.output_height = out_channel * output_height
def forward(self, x, x_len):
"""
x: shape [B, C, D, T]
x_len : shape [B]
"""
x, x_len = self.conv_in(x, x_len)
for i, conv in enumerate(self.conv_stack):
x, x_len = conv(x, x_len)
return x, x_len

@ -124,7 +124,9 @@ class TransformerDecoder(nn.Layer):
# m: (1, L, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
# TODO(Hui Zhang): not support & for tensor
# tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt)
for layer in self.decoders:
@ -135,7 +137,9 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
# TODO(Hui Zhang): reduce_sum not support bool type
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens
def forward_one_step(

@ -162,7 +162,8 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
mask_pad = ~masks
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import paddle
from paddle import nn
from paddle.nn import functional as F
@ -32,18 +34,19 @@ class CTCLoss(nn.Layer):
# last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average
logger.info(
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
# instance for norm_by_times
# batch for norm_by_batchsize
# frame for norm_by_total_logits_len
assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False
self.norm_by_batchsize = False
self.norm_by_total_logits_len = False
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance':
if grad_norm_type is None:
# no grad norm
pass
elif grad_norm_type == 'instance':
self.norm_by_times = True
elif grad_norm_type == 'batch':
self.norm_by_batchsize = True
@ -51,6 +54,22 @@ class CTCLoss(nn.Layer):
self.norm_by_total_logits_len = True
else:
raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
self.kwargs = {
"norm_by_times": self.norm_by_times,
"norm_by_batchsize": self.norm_by_batchsize,
"norm_by_total_logits_len": self.norm_by_total_logits_len,
}
# Derive only the args which the func has
try:
param = inspect.signature(self.loss.forward).parameters
except ValueError:
# Some callables, e.g. built-in functions, expose no signature and raise here
param = {}
self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
_notin = {k: v for k, v in self.kwargs.items() if k not in param}
logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
#self.loss_fn = partial(self.loss.forward, **_kwargs)
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.
@ -70,14 +89,8 @@ class CTCLoss(nn.Layer):
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss(
logits,
ys_pad,
hlens,
ys_lens,
norm_by_times=self.norm_by_times,
norm_by_batchsize=self.norm_by_batchsize,
norm_by_total_logits_len=self.norm_by_total_logits_len)
#loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
loss = self.loss(logits, ys_pad, hlens, ys_lens)
if self.batch_average:
# Batch-size average
loss = loss / B
@ -118,8 +131,8 @@ class LabelSmoothingLoss(nn.Layer):
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
True, normalize loss by sequence length;
False, normalize loss by batch size.
Defaults to False.
"""
@ -136,7 +149,7 @@ class LabelSmoothingLoss(nn.Layer):
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (paddle.Tensor): prediction (batch, seqlen, class)
target (paddle.Tensor):
@ -152,7 +165,7 @@ class LabelSmoothingLoss(nn.Layer):
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
ignore = (target == self.padding_idx) # (B,)
ignore = target == self.padding_idx # (B,)
#TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index
target = target.masked_fill(ignore, 0) # avoid -1 index
@ -163,8 +176,10 @@ class LabelSmoothingLoss(nn.Layer):
kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
total = len(target) - int(ignore.sum())
#TODO(Hui Zhang): sum not support bool type
#total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum())
denom = total if self.normalize_length else B
#TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
#numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
return numer / denom
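The signature-filtering idiom added to CTCLoss, as a standalone hedged sketch: probe the wrapped callable's parameters once and keep only the kwargs it declares, so the newer norm flags degrade gracefully on older paddle builds.

import inspect

def supported_kwargs(func, **kwargs):
    try:
        params = inspect.signature(func).parameters
    except ValueError:   # some callables (e.g. builtins) expose no signature
        params = {}
    return {k: v for k, v in kwargs.items() if k in params}

def ctc_forward(logits, labels, norm_by_times=False):   # stand-in for loss.forward
    return norm_by_times

kept = supported_kwargs(ctc_forward, norm_by_times=True, norm_by_batchsize=True)
assert kept == {"norm_by_times": True}   # unsupported flag silently dropped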

@ -69,7 +69,8 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
return ~make_pad_mask(lengths)
#return ~make_pad_mask(lengths)
return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor:
@ -91,7 +92,12 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]]
"""
ret = paddle.ones([size, size], dtype=paddle.bool)
return paddle.tril(ret)
#TODO(Hui Zhang): tril not support bool
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask(
@ -180,13 +186,15 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
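The remaining boolean workarounds in this hunk and in the decoder/encoder hunks above, collected into one sketch: logical_not and logical_and replace ~ and &, and tril takes a float round-trip because it rejected bool inputs in this Paddle version.

import paddle

m = paddle.to_tensor([[True, True, False]])
inv = m.logical_not()                             # instead of ~m
both = m.logical_and(paddle.ones_like(m))         # instead of m & other

size = 3
ret = paddle.ones([size, size], dtype=paddle.float32)
sub = paddle.tril(ret).astype(paddle.bool)        # subsequent_mask workaround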

@ -1,314 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.tanh
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
r = self._gate_activation(x_r + h_r)
z = self._gate_activation(x_z + h_z)
c = self._activation(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
return x, x_len

@ -14,6 +14,7 @@
import sys
import time
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
import paddle
@ -27,6 +28,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import profiler
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
@ -102,13 +104,28 @@ class Trainer():
self.iteration = 0
self.epoch = 0
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._train = True
logger.info(f"Rank: {self.rank}/{dist.get_world_size()}")
# print deps version
all_version()
logger.info(f"Rank: {self.rank}/{self.world_size}")
# set device
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
# set random seed if needed
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
# profiler and benchmark options
if self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
@ -116,17 +133,18 @@ class Trainer():
logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
@contextmanager
def eval(self):
self._train = False
yield
self._train = True
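A self-contained sketch of this flag-flipping pattern (the _Demo class is hypothetical; only the shape of eval() mirrors the code above):

from contextlib import contextmanager

class _Demo:
    def __init__(self):
        self._train = True

    @contextmanager
    def eval(self):
        self._train = False
        yield
        self._train = True

d = _Demo()
with d.eval():
    assert d._train is False  # test/export/align run in this state
assert d._train is True  # training flag restored on exit

Note that yield is not wrapped in try/finally, so an exception inside the block leaves _train False; wrapping it would make the context manager exception-safe.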
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
@ -161,9 +179,9 @@ class Trainer():
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
})
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
@ -181,8 +199,8 @@ class Trainer():
if infos:
# just restore ckpt
# lr will restore from the optimizer ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] + 1
scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
@ -300,37 +318,74 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
with Timer("Training Done: {}"):
try:
try:
with Timer("Training Done: {}"):
self.train()
except KeyboardInterrupt:
exit(-1)
finally:
self.destroy()
except KeyboardInterrupt:
exit(-1)
finally:
self.destroy()
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.resume_or_scratch()
self.test()
except KeyboardInterrupt:
exit(-1)
def run_export(self):
"""Do Model Export"""
try:
with Timer("Export Done: {}"):
with self.eval():
self.export()
except KeyboardInterrupt:
exit(-1)
def run_align(self):
"""Do CTC alignment"""
try:
with Timer("Align Done: {}"):
with self.eval():
self.resume_or_scratch()
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
if self.args.output:
output_dir = Path(self.args.output).expanduser()
elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
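For example (a hypothetical path): with --checkpoint_path exp/deepspeech2/checkpoints/avg_1 and no --output, the output directory resolves two levels up, to exp/deepspeech2.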
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
self.checkpoint_dir = self.output_dir / "checkpoints"
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.log_dir = output_dir / "log"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.test_dir = output_dir / "test"
self.test_dir.mkdir(parents=True, exist_ok=True)
self.decode_dir = output_dir / "decode"
self.decode_dir.mkdir(parents=True, exist_ok=True)
self.export_dir = output_dir / "export"
self.export_dir.mkdir(parents=True, exist_ok=True)
self.visual_dir = output_dir / "visual"
self.visual_dir.mkdir(parents=True, exist_ok=True)
self.config_dir = output_dir / "conf"
self.config_dir.mkdir(parents=True, exist_ok=True)
@mp_tools.rank_zero_only
def destroy(self):
@ -352,7 +407,7 @@ class Trainer():
unexpected behaviors.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
visualizer = SummaryWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer
@mp_tools.rank_zero_only
@ -362,7 +417,14 @@ class Trainer():
It is saved into ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
config_file = self.config_dir / "config.yaml"
if self._train and config_file.exists():
time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.gmtime())
target_path = self.config_dir / ".".join(
[time_stamp, "config.yaml"])
config_file.rename(target_path)
with open(config_file, 'wt') as f:
print(self.config, file=f)
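With this change an existing conf/config.yaml from a previous run is first renamed to a timestamped backup (e.g. 2021_08_12_06_30_15.config.yaml, a hypothetical name) before the new config is written, so earlier configs are never silently overwritten.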
def train_batch(self):
@ -376,6 +438,24 @@ class Trainer():
"""
raise NotImplementedError("valid should be implemented.")
@paddle.no_grad()
def test(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("test should be implemented.")
@paddle.no_grad()
def export(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("export should be implemented.")
@paddle.no_grad()
def align(self):
"""The align. A subclass should implement this method in Tester.
"""
raise NotImplementedError("align should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.

@ -39,13 +39,13 @@ class Checkpoint():
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
def save_parameters(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
Args:

@ -16,6 +16,8 @@ from typing import List
import sys
from pathlib import Path

import numpy as np
import paddle
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
@ -87,14 +89,16 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
(ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_path = (paddle.zeros(
(ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.shape[0]): # T
for s in range(len(y_insert_blank)): # 2L+1
@ -110,9 +114,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32)
@ -130,3 +136,85 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment
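A toy check of forced_align (hypothetical probabilities; assumes the truncated signature takes a blank id defaulting to 0):

import paddle

# 4 frames, vocab {0: blank, 1: 'a', 2: 'b'}, label sequence "ab"
ctc_probs = paddle.log(paddle.to_tensor([
    [0.1, 0.8, 0.1],  # frame 0: mostly 'a'
    [0.8, 0.1, 0.1],  # frame 1: mostly blank
    [0.1, 0.1, 0.8],  # frame 2: mostly 'b'
    [0.1, 0.1, 0.8],  # frame 3: mostly 'b'
]))
y = paddle.to_tensor([1, 2], dtype='int64')
print(forced_align(ctc_probs, y))  # expected best path: [1, 0, 2, 2]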
# ctc_align(
#     self.config,
#     self.model,
#     self.align_loader,
#     self.config.decoding.batch_size,
#     self.align_loader.collate_fn.stride_ms,
#     self.align_loader.collate_fn.vocab_list,
#     self.args.result_file,
# )
def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
result_file):
"""ctc alignment.
Args:
config (yacs.config.CfgNode): experiment config, needed for the subsample factor.
model (nn.Layer): U2 Model.
dataloader (io.DataLoader): dataloader.
batch_size (int): decoding batchsize.
stride_ms (int): audio feature stride in ms unit.
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
result_file (str): alignment output file, e.g. xxx.align.
"""
if batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
assert result_file and result_file.endswith('.align')
model.eval()
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
with open(result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(dataloader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
token_dict)
# write tier
align_output_path = Path(result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms)  # e.g. a 10 ms stride gives 0.01 s per encoder frame
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
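A worked example of the timing arithmetic above (hypothetical numbers):

# 10 ms stride, subsample 4, 50 alignment tokens
second_per_frame = 1. / (1000. / 10)  # 0.01 s per encoder frame
second_per_example = (50 + 1) * 4 * second_per_frame
print(second_per_example)  # 2.04 s of audio covered by the TextGrid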

@ -19,6 +19,8 @@ import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
editdistance.eval("a", "b")
def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference
@ -90,6 +92,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
hyp_words = list(filter(None, hypothesis.split(delimiter)))
edit_distance = _levenshtein_distance(ref_words, hyp_words)
# `editdistance.eval` is less precise than `_levenshtein_distance`
# edit_distance = editdistance.eval(ref_words, hyp_words)
return float(edit_distance), len(ref_words)
@ -121,6 +124,7 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))
edit_distance = _levenshtein_distance(reference, hypothesis)
# `editdistance.eval` is less precise than `_levenshtein_distance`
# edit_distance = editdistance.eval(reference, hypothesis)
return float(edit_distance), len(reference)
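A small sketch of how these helpers feed WER/CER (values hypothetical):

# reference has 4 words; one substitution ("fox" -> "box")
dist, ref_len = word_errors("the quick brown fox", "the quick brown box")
print(dist, ref_len)  # 1.0 4
print(dist / ref_len)  # WER = 0.25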

@ -120,14 +120,15 @@ class Autolog:
model_precision="fp32"):
import auto_log
pid = os.getpid()
if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
if os.environ.get('CUDA_VISIBLE_DEVICES', None):
gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
infer_config = inference.Config()
infer_config.enable_use_gpu(100, gpu_id)
else:
gpu_id = None
infer_config = inference.Config()
autolog = auto_log.AutoLogger(
self.autolog = auto_log.AutoLogger(
model_name=model_name,
model_precision=model_precision,
batch_size=batch_size,
@ -139,7 +140,6 @@ class Autolog:
gpu_ids=gpu_id,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
self.autolog = autolog
def getlog(self):
return self.autolog

@ -183,7 +183,13 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
#     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
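A toy check of the masked accuracy (hypothetical tensors; assumes the third parameter is ignore_label and that the repo's paddle.Tensor.view patch is in effect):

import paddle

pad_targets = paddle.to_tensor([[1, 2, -1],
                                [3, -1, -1]])  # [B, Lmax], -1 is padding
logits = paddle.zeros([2 * 3, 5])  # [B * Lmax, vocab]
for i, tok in enumerate([1, 2, 0, 3, 0, 0]):  # argmax predictions per position
    logits[i, tok] = 1.0
print(th_accuracy(logits, pad_targets, ignore_label=-1))  # 1.0: all 3 real positions match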

@ -16,17 +16,35 @@ import distutils.util
import math
import os
import random
import sys
from contextlib import contextmanager
from pprint import pformat
from typing import List
import numpy as np
import paddle
import soundfile
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"UpdateConfig", "seed_all", 'print_arguments', 'add_arguments', "log_add"
"all_version", "UpdateConfig", "seed_all", 'print_arguments',
'add_arguments', "log_add"
]
def all_version():
vers = {
"python": sys.version,
"paddle": paddle.__version__,
"paddle_commit": paddle.version.commit,
"soundfile": soundfile.__version__,
}
logger.info(f"Deps Module Version:{pformat(vers.items())}")
@contextmanager
def UpdateConfig(config):
"""Update yacs config"""
@ -35,7 +53,7 @@ def UpdateConfig(config):
config.freeze()
def seed_all(seed: int=210329):
def seed_all(seed: int=20210329):
"""freeze random generator seed."""
np.random.seed(seed)
random.seed(seed)
@ -61,7 +79,7 @@ def print_arguments(args, info=None):
if info:
filename = info["__file__"]
filename = os.path.basename(filename)
print(f"----------- {filename} Configuration Arguments -----------")
print(f"----------- {filename} Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("-----------------------------------------------------------")

@ -1,6 +1,6 @@
export MAIN_ROOT=${PWD}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:/usr/local/bin:${PATH}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}:/usr/local/bin/
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C

@ -4,6 +4,7 @@ export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -4,6 +4,7 @@ export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -4,6 +4,7 @@ export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -31,14 +31,13 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -3,6 +3,7 @@ export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

@ -1,4 +1,5 @@
coverage
editdistance
gpustat
jsonlines
kaldiio
@ -19,4 +20,3 @@ tqdm
typeguard
visualdl==2.2.0
yacs
editdistance

@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase):
def test_make_non_pad_mask(self):
res = make_non_pad_mask(self.lengths)
res2 = ~make_pad_mask(self.lengths)
res2 = make_pad_mask(self.lengths).logical_not()
self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())
def test_make_pad_mask(self):
res = make_pad_mask(self.lengths)
res1 = ~make_non_pad_mask(self.lengths)
res1 = make_non_pad_mask(self.lengths).logical_not()
self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
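The switch from the ~ operator to logical_not() is presumably because bitwise inversion was not reliably overloaded for boolean paddle.Tensor in this release; the explicit method call expresses the same operation through a supported API.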
