Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into release_model

pull/838/head
huangyuxin 3 years ago
commit 7e96942c58

@ -1,5 +1,3 @@
[中文版](README_cn.md)
# PaddlePaddle Speech to Any toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
@ -11,7 +9,7 @@
## Features
See [feature list](doc/src/feature_list.md) for more information.
See [feature list](docs/src/feature_list.md) for more information.
## Setup
@ -20,20 +18,20 @@ All tested under:
* python>=3.7
* paddlepaddle>=2.2.0rc
Please see [install](doc/src/install.md).
Please see [install](docs/src/install.md).
## Getting Started
Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
Please see [Getting Started](docs/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md).
## More Information
* [Data Prepration](doc/src/data_preparation.md)
* [Data Augmentation](doc/src/augmentation.md)
* [Ngram LM](doc/src/ngram_lm.md)
* [Benchmark](doc/src/benchmark.md)
* [Relased Model](doc/src/released_model.md)
* [Data Prepration](docs/src/data_preparation.md)
* [Data Augmentation](docs/src/augmentation.md)
* [Ngram LM](docs/src/ngram_lm.md)
* [Benchmark](docs/src/benchmark.md)
* [Relased Model](docs/src/released_model.md)
## Questions and Help
@ -47,4 +45,4 @@ DeepSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Acknowledgement
We depends on many open source repos. See [References](doc/src/reference.md) for more information.
We depends on many open source repos. See [References](docs/src/reference.md) for more information.

@ -1,49 +0,0 @@
[English](README.md)
# PaddlePaddle Speech to Any toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
*DeepSpeech*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别引擎的开源项目,
我们的愿景是为语音识别在工业应用和学术研究上,提供易于使用、高效、小型化和可扩展的工具,包括训练,推理,以及 部署。
## 特性
参看 [特性列表](doc/src/feature_list.md)。
## 安装
在以下环境测试验证过:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.2.0rc
参看 [安装](doc/src/install.md)。
## 开始
请查看 [开始](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/s0/README.md)。
## 更多信息
* [数据处理](doc/src/data_preparation.md)
* [数据增强](doc/src/augmentation.md)
* [语言模型](doc/src/ngram_lm.md)
* [Benchmark](doc/src/benchmark.md)
* [Relased Model](doc/src/released_model.md)
## 问题和帮助
欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。
## License
DeepSpeech 遵循[Apache-2.0开源协议](./LICENSE)。
## 感谢
开发中参考一些优秀的仓库,详情参见 [References](doc/src/reference.md)。

@ -1,191 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search parameters tuning for DeepSpeech2 model."""
import functools
import sys
import numpy as np
from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import error_rate
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
def tune(config, args):
"""Tune parameters alpha and beta incrementally."""
if not args.num_alphas >= 0:
raise ValueError("num_alphas must be non-negative!")
if not args.num_betas >= 0:
raise ValueError("num_betas must be non-negative!")
config.defrost()
config.data.manfiest = config.data.dev_manifest
config.data.augmentation_config = ""
config.data.keep_transcription_text = True
dev_dataset = ManifestDataset.from_config(config)
valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(valid_loader, config,
args.checkpoint_path)
model.eval()
# decoders only accept string encoded in utf-8
vocab_list = valid_loader.dataset.vocab_list
errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors
# create grid for search
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
params_grid = [(alpha, beta) for alpha in cand_alphas
for beta in cand_betas]
err_sum = [0.0 for i in range(len(params_grid))]
err_ave = [0.0 for i in range(len(params_grid))]
num_ins, len_refs, cur_batch = 0, 0, 0
# initialize external scorer
model.decoder.init_decode(args.alpha_from, args.beta_from,
config.decoding.lang_model_path, vocab_list,
config.decoding.decoding_method)
## incremental tuning parameters over multiple batches
print("start tuning ...")
for infer_data in valid_loader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break
def ordid2token(texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
return trans
audio, audio_len, text, text_len = infer_data
target_transcripts = ordid2token(text, text_len)
num_ins += audio.shape[0]
# model infer
eouts, eouts_len = model.encoder(audio, audio_len)
probs = model.decoder.softmax(eouts)
# grid search
for index, (alpha, beta) in enumerate(params_grid):
print(f"tuneing: alpha={alpha} beta={beta}")
result_transcripts = model.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list,
config.decoding.decoding_method,
config.decoding.lang_model_path, alpha, beta,
config.decoding.beam_size, config.decoding.cutoff_prob,
config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)
for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
err_sum[index] += errors
# accumulate the length of references of every batchπ
# in the first iteration
if args.alpha_from == alpha and args.beta_from == beta:
len_refs += len_ref
err_ave[index] = err_sum[index] / len_refs
if index % 2 == 0:
sys.stdout.write('.')
sys.stdout.flush()
print("tuneing: one grid done!")
# output on-line tuning result at the end of current batch
err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
" min [%s] = %f" %
(cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1],
config.decoding.error_rate_type, err_ave_min))
cur_batch += 1
# output WER/CER at every (alpha, beta)
print("\nFinal %s:\n" % config.decoding.error_rate_type)
for index in range(len(params_grid)):
print("(alpha, beta) = (%s, %s), [%s] = %f" %
("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
config.decoding.error_rate_type, err_ave[index]))
err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" %
(cur_batch, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1]))
print("finish tuning")
def main(config, args):
tune(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_batches', int, -1, "# of batches tuning on. "
"Default -1, on whole dev set.")
add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
add_arg('batch_size', int, 256, "# of samples per batch.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.data.batch_size = args.batch_size
config.decoding.beam_size = args.beam_size
config.decoding.num_proc_bsearch = args.num_proc_bsearch
config.decoding.cutoff_prob = args.cutoff_prob
config.decoding.cutoff_top_n = args.cutoff_top_n
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)

@ -35,12 +35,14 @@ from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.reporter import report
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -66,7 +68,9 @@ class DeepSpeech2Trainer(Trainer):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
batch_size = self.config.collator.batch_size
accum_grad = self.config.training.accum_grad
start = time.time()
# forward
@ -77,7 +81,7 @@ class DeepSpeech2Trainer(Trainer):
}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
if (batch_index + 1) % accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
@ -92,19 +96,18 @@ class DeepSpeech2Trainer(Trainer):
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
if (batch_index + 1) % accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
for k, v in losses_np.items():
report(k, v)
report("batch_size", batch_size)
report("accum", accum_grad)
report("step_cost", iteration_time)
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
@ -147,10 +150,9 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self):
config = self.config.clone()
config.defrost()
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.freeze()
with UpdateConfig(config):
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)

@ -18,11 +18,11 @@ import os
from paddle import distributed as dist
from deepspeech.exps.u2.config import get_cfg_defaults
# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
from deepspeech.exps.u2.model import U2Trainer as Trainer
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
from deepspeech.exps.u2.model import U2Trainer as Trainer
# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
def main_sp(config, args):

@ -17,6 +17,7 @@ import os
import sys
import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
@ -33,6 +34,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
@ -43,6 +46,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -100,7 +104,8 @@ class U2Trainer(Trainer):
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
# When using cpu w/o DDP, model does not have `no_sync`
context = self.model.no_sync if self.parallel else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
@ -118,14 +123,13 @@ class U2Trainer(Trainer):
iteration_time = time.time() - start
if (batch_index + 1) % train_conf.log_interval == 0:
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
for k, v in losses_np.items():
report(k, v)
report("batch_size", self.config.collator.batch_size)
report("accum", train_conf.accum_grad)
report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
@ -182,9 +186,10 @@ class U2Trainer(Trainer):
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.iteration)
# lr will resotre from optimizer ckpt
# self.lr_scheduler.step(self.iteration)
if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
self.train_loader.batch_sampler.set_epoch(self.epoch)
@ -196,14 +201,31 @@ class U2Trainer(Trainer):
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report('step/total',
(batch_index + 1) / len(self.train_loader))
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips[sent./sec]'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += ","
if (batch_index + 1
) % self.config.training.log_interval == 0:
logger.info(msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
@ -312,10 +334,11 @@ class U2Trainer(Trainer):
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:
@ -558,7 +581,7 @@ class U2Tester(U2Trainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
@ -566,26 +589,25 @@ class U2Tester(U2Trainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
@ -593,7 +615,7 @@ class U2Tester(U2Trainer):
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()

@ -32,6 +32,7 @@ from deepspeech.training.trainer import Trainer
from deepspeech.training.updaters.trainer import Trainer as NewTrainer
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -121,10 +122,10 @@ class U2Trainer(Trainer):
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:

@ -41,6 +41,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -205,6 +206,7 @@ class U2Trainer(Trainer):
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
data_start_time = time.time()
except Exception as e:
logger.error(e)
@ -318,10 +320,10 @@ class U2Trainer(Trainer):
# model
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
model_conf.freeze()
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
@ -544,9 +546,8 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.config.collater.stride_ms
token_dict = self.args.char_list
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
@ -555,7 +556,7 @@ class U2Tester(U2Trainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
@ -563,26 +564,25 @@ class U2Tester(U2Trainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
@ -590,7 +590,7 @@ class U2Tester(U2Trainer):
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()

@ -47,6 +47,7 @@ from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -222,6 +223,7 @@ class U2STTrainer(Trainer):
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
data_start_time = time.time()
except Exception as e:
logger.error(e)
@ -344,10 +346,10 @@ class U2STTrainer(Trainer):
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model = U2STModel.from_config(model_conf)
if self.parallel:
@ -586,7 +588,7 @@ class U2STTester(U2STTrainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
@ -594,26 +596,25 @@ class U2STTester(U2STTrainer):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
@ -621,7 +622,7 @@ class U2STTester(U2STTrainer):
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()

@ -76,19 +76,19 @@ class ManifestDataset(Dataset):
Args:
manifest_path (str): manifest josn file path
max_input_len ([type], optional): maximum output seq length,
max_input_len ([type], optional): maximum output seq length,
in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length,
max_output_len (float, optional): maximum input seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length,
min_output_len (float, optional): minimum input seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio.
Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio.
Defaults to 0.05.
"""
super().__init__()

@ -48,6 +48,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add
from deepspeech.utils.utility import UpdateConfig
__all__ = ["U2Model", "U2InferModel"]
@ -297,8 +298,8 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
@ -403,7 +404,7 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.size(1)
maxlen = encoder_out.shape[1]
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
@ -454,7 +455,7 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
maxlen = encoder_out.shape[1]
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
@ -582,7 +583,7 @@ class U2BaseModel(nn.Layer):
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
@ -689,13 +690,13 @@ class U2BaseModel(nn.Layer):
Returns:
paddle.Tensor: decoder output, (B, L)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
assert encoder_out.shape[0] == 1
num_hyps = hyps.shape[0]
assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
@ -750,7 +751,7 @@ class U2BaseModel(nn.Layer):
Returns:
List[List[int]]: transcripts.
"""
batch_size = feats.size(0)
batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.fatal(
@ -778,7 +779,7 @@ class U2BaseModel(nn.Layer):
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.size(0) == 1
assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
feats,
feats_lengths,
@ -788,7 +789,7 @@ class U2BaseModel(nn.Layer):
simulate_streaming=simulate_streaming)
hyps = [hyp]
elif decoding_method == 'attention_rescoring':
assert feats.size(0) == 1
assert feats.shape[0] == 1
hyp = self.attention_rescoring(
feats,
feats_lengths,
@ -903,10 +904,10 @@ class U2Model(U2BaseModel):
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
with UpdateConfig(config):
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
model = cls.from_config(config)
if checkpoint_path:

@ -42,6 +42,7 @@ from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"]
@ -339,8 +340,8 @@ class U2STBaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
@ -495,13 +496,13 @@ class U2STBaseModel(nn.Layer):
Returns:
paddle.Tensor: decoder output, (B, L)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
assert encoder_out.shape[0] == 1
num_hyps = hyps.shape[0]
assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
@ -556,7 +557,7 @@ class U2STBaseModel(nn.Layer):
Returns:
List[List[int]]: transcripts.
"""
batch_size = feats.size(0)
batch_size = feats.shape[0]
if decoding_method == 'fullsentence':
hyps = self.translate(
@ -686,10 +687,10 @@ class U2STModel(U2STBaseModel):
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
with UpdateConfig(config):
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
model = cls.from_config(config)
if checkpoint_path:

@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
n_batch = value.size(0)
n_batch = value.shape[0]
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype)
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
ones = paddle.ones((x.size(2), x.size(3)))
x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
ones = paddle.ones((x.shape[2], x.shape[3]))
x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
return x
@ -205,7 +206,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

@ -122,7 +122,7 @@ class TransformerDecoder(nn.Layer):
# tgt_mask: (B, 1, L)
tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m

@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = x.shape[1]
assert offset + x.size(1) < self.max_len
assert offset + x.shape[1] < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb
@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
assert offset + x.size(1) < self.max_len
assert offset + x.shape[1] < self.max_len
x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]

@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
mask_pad = ~masks
@ -206,11 +206,11 @@ class BaseEncoder(nn.Layer):
chunk computation
List[paddle.Tensor]: conformer cnn cache
"""
assert xs.size(0) == 1 # batch size must be one
assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32)
tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None:
@ -220,25 +220,25 @@ class BaseEncoder(nn.Layer):
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None:
cache_size = subsampling_cache.size(1) #T
cache_size = subsampling_cache.shape[1] #T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.size(1))
offset=offset - cache_size, size=xs.shape[1])
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = xs.size(1)
next_cache_start = xs.shape[1]
else:
next_cache_start = xs.size(1) - required_cache_size
next_cache_start = xs.shape[1] - required_cache_size
r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = []
r_conformer_cnn_cache = []
@ -302,7 +302,7 @@ class BaseEncoder(nn.Layer):
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
subsampling_cache: Optional[paddle.Tensor] = None
elayers_output_cache: Optional[List[paddle.Tensor]] = None
@ -318,10 +318,10 @@ class BaseEncoder(nn.Layer):
chunk_xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
outputs.append(y)
offset += y.size(1)
offset += y.shape[1]
ys = paddle.cat(outputs, 1)
# fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.size(1)], dtype=paddle.bool)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks

@ -43,28 +43,57 @@ def default_argument_parser():
"""
parser = argparse.ArgumentParser()
# yapf: disable
# data and output
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
train_group = parser.add_argument_group(
title='Train Options', description=None)
train_group.add_argument(
"--seed",
type=int,
default=None,
help="seed to use for paddle, np and random. None or 0 for random, else set seed."
)
train_group.add_argument(
"--device",
type=str,
default='gpu',
choices=["cpu", "gpu"],
help="device cpu and gpu are supported.")
train_group.add_argument(
"--nprocs",
type=int,
default=1,
help="number of parallel processes. 0 for cpu.")
train_group.add_argument(
"--config", metavar="CONFIG_FILE", help="config file.")
train_group.add_argument(
"--output", metavar="CKPT_DIR", help="path to save checkpoint.")
train_group.add_argument(
"--checkpoint_path", type=str, help="path to load checkpoint")
train_group.add_argument(
"--opts",
type=str,
default=[],
nargs='+',
help="overwrite --config file, passing in LIST[KEY VALUE] pairs")
train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.")
# load from saved checkpoint
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
# running
parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
help="device type to use, cpu and gpu are supported.")
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
# overwrite extra config and default config
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--opts", type=str, default=[], nargs='+',
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapd: enable
profile_group = parser.add_argument_group(
title='Benchmark Options', description=None)
profile_group.add_argument(
'--profiler-options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
profile_group.add_argument(
'--benchmark-batch-size',
type=int,
default=None,
help='batch size for benchmark.')
profile_group.add_argument(
'--benchmark-max-step',
type=int,
default=None,
help='max iteration for benchmark.')
return parser

@ -20,8 +20,8 @@ from paddle.nn import Layer
from . import extension
from ..reporter import DictSummary
from ..reporter import ObsScope
from ..reporter import report
from ..reporter import scope
from ..timer import Timer
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
@ -78,7 +78,7 @@ class StandardEvaluator(extension.Extension):
summary = DictSummary()
for batch in self.dataloader:
observation = {}
with scope(observation):
with ObsScope(observation):
# main evaluation computation here.
with paddle.no_grad():
self.evaluate_sync(self.evaluate_core(batch))

@ -101,7 +101,7 @@ class Snapshot(extension.Extension):
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz"
path = self.checkpoint_dir / f"{num}.np"
# add the new one
trainer.updater.save(path)

@ -19,7 +19,7 @@ OBSERVATIONS = None
@contextlib.contextmanager
def scope(observations):
def ObsScope(observations):
# make `observation` the target to report to.
# it is basically a dictionary that stores temporary observations
global OBSERVATIONS

@ -11,18 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
from collections import OrderedDict
from pathlib import Path
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.utils import mp_tools
from deepspeech.utils import profiler
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
__all__ = ["Trainer"]
@ -95,11 +101,21 @@ class Trainer():
self.checkpoint_dir = None
self.iteration = 0
self.epoch = 0
self.rank = dist.get_rank()
logger.info(f"Rank: {self.rank}/{dist.get_world_size()}")
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
if self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
self.config.training.log_interval = 1
logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
def setup(self):
"""Setup the experiment.
"""
@ -183,13 +199,23 @@ class Trainer():
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def after_train_batch(self):
if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
profiler.add_profiler_step(self.args.profiler_options)
logger.info(
f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
sys.exit(
f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
def train(self):
"""The training process control by epoch."""
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch)
# lr will resotre from optimizer ckpt
# self.lr_scheduler.step(self.epoch)
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
@ -201,14 +227,29 @@ class Trainer():
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report('step/total',
(batch_index + 1) / len(self.train_loader))
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips[sent./sec]'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += ","
logger.info(msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)

@ -24,7 +24,7 @@ import tqdm
from deepspeech.training.extensions.extension import Extension
from deepspeech.training.extensions.extension import PRIORITY_READER
from deepspeech.training.reporter import scope
from deepspeech.training.reporter import ObsScope
from deepspeech.training.triggers import get_trigger
from deepspeech.training.triggers.limit_trigger import LimitTrigger
from deepspeech.training.updaters.updater import UpdaterBase
@ -144,7 +144,7 @@ class Trainer():
# you can use `report` freely in Updater.update()
# updating parameters and state
with scope(self.observation):
with ObsScope(self.observation):
update()
p.update()

@ -84,19 +84,19 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
(ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
(ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb
for t in range(1, ctc_probs.size(0)): # T
for t in range(1, ctc_probs.shape[0]): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
@ -110,13 +110,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
# TODO(Hui Zhang): zeros not support paddle.int16
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
@ -124,11 +122,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[paddle.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
for t in range(ctc_probs.shape[0] - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
for t in range(0, ctc_probs.size(0)):
for t in range(0, ctc_probs.shape[0]):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment

@ -0,0 +1,119 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
'''
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
"profile_path=model.profile"
"batch_range=[50, 60]; profile_path=model.profile"
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports following key-value pair:
batch_range - a integer list, e.g. [100, 110].
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
'max', 'min' or 'ave.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
'AllOpDetail'.
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
'''
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
'batch_range': [10, 20],
'state': 'All',
'sorted_key': 'total',
'tracer_option': 'Default',
'profile_path': '/tmp/profile',
'exit_on_finished': True
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
if not options_str:
return
for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=')
if key == 'batch_range':
value_list = value.replace('[', '').replace(']', '').split(',')
value_list = list(map(int, value_list))
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
1] > value_list[0]:
self._options[key] = value_list
elif key == 'exit_on_finished':
self._options[key] = value.lower() in ("yes", "true", "t", "1")
elif key in [
'state', 'sorted_key', 'tracer_option', 'profile_path'
]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
'''
Enable the operator-level timing using PaddlePaddle's profiler.
The profiler uses a independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
'''
if options_str is None:
return
global _profiler_step_id
global _profiler_options
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
logger.info(f"Profiler: {options_str}")
logger.info(f"Profiler: {_profiler_options._options}")
if _profiler_step_id == _profiler_options['batch_range'][0]:
paddle.utils.profiler.start_profiler(_profiler_options['state'],
_profiler_options['tracer_option'])
elif _profiler_step_id == _profiler_options['batch_range'][1]:
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
_profiler_options['profile_path'])
if _profiler_options['exit_on_finished']:
sys.exit(0)
_profiler_step_id += 1

@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
max_len = max([s.size(0) for s in sequences])
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
@ -91,7 +91,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
@ -139,7 +139,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.size(0)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
@ -165,8 +165,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))

@ -16,15 +16,27 @@ import distutils.util
import math
import os
import random
from contextlib import contextmanager
from typing import List
import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
__all__ = [
"UpdateConfig", "seed_all", 'print_arguments', 'add_arguments', "log_add"
]
@contextmanager
def UpdateConfig(config):
"""Update yacs config"""
config.defrost()
yield
config.freeze()
def seed_all(seed: int=210329):
"""freeze random generator seed."""
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 206 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 108 KiB

@ -1,16 +0,0 @@
# Benchmarks
## Acceleration with Multi-GPUs
We compare the training time with 1, 2, 4, 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars.
<img src="../images/multi_gpu_speedup.png" width=450>
| # of GPU | Acceleration Rate |
| -------- | --------------: |
| 1 | 1.00 X |
| 2 | 1.98 X |
| 4 | 3.73 X |
| 8 | 6.95 X |
`utils/profile.sh` provides such a demo profiling tool, you can change it as need.

Before

Width:  |  Height:  |  Size: 93 KiB

After

Width:  |  Height:  |  Size: 93 KiB

Before

Width:  |  Height:  |  Size: 93 KiB

After

Width:  |  Height:  |  Size: 93 KiB

@ -20,7 +20,7 @@ The arcitecture of the model is shown in Fig.1.
### Data Preparation
#### Vocabulary
For English data, the vocabulary dictionary is composed of 26 English characters with " ' ", space, \<blank\> and \<eos\>. The \<blank\> represents the blank label in CTC, the \<unk\> represents the unknown character and the <eos> represents the start and the end characters. For mandarin, the vocabulary dictionary is composed of chinese characters statisticed from the training set and three additional characters are added. The added characters are \<blank\>, \<unk\> and \<eos\>. For both English and mandarin data, we set the default indexs that \<blank\>=0, \<unk\>=1 and \<eos\>= last index.
For English data, the vocabulary dictionary is composed of 26 English characters with " ' ", space, \<blank\> and \<eos\>. The \<blank\> represents the blank label in CTC, the \<unk\> represents the unknown character and the \<eos\> represents the start and the end characters. For mandarin, the vocabulary dictionary is composed of chinese characters statisticed from the training set and three additional characters are added. The added characters are \<blank\>, \<unk\> and \<eos\>. For both English and mandarin data, we set the default indexs that \<blank\>=0, \<unk\>=1 and \<eos\>= last index.
```
# The code to build vocabulary
cd examples/aishell/s0
@ -65,17 +65,21 @@ python3 ../../../utils/compute_mean_std.py \
```
### Encoder
The Backbone is composed of two 2D convolution subsampling layers and a number of stacked single direction rnn layers. The 2D convolution subsampling layers extract feature represention from the raw audio feature and reduce the length of audio feature at the same time. After passing through the convolution subsampling layers, then the feature represention are input into the stacked rnn layers. For rnn layers, LSTM cell and GRU cell are provided. Adding one fully connected (fc) layer after rnn layer is optional, if the number of rnn layers is less than 5, adding one fc layer after rnn layers is recommand.
The encoder is composed of two 2D convolution subsampling layers and a number of stacked single direction rnn layers. The 2D convolution subsampling layers extract feature represention from the raw audio feature and reduce the length of audio feature at the same time. After passing through the convolution subsampling layers, then the feature represention are input into the stacked rnn layers. For the stacked rnn layers, LSTM cell and GRU cell are provided to use. Adding one fully connected (fc) layer after the stacked rnn layers is optional. If the number of stacked rnn layers is less than 5, adding one fc layer after stacked rnn layers is recommand.
The code of Encoder is in:
```
vi deepspeech/models/ds2_online/deepspeech2.py
```
### Decoder
To got the character possibilities of each frame, the feature represention of each frame output from the backbone are input into a projection layer which is implemented as a dense layer to do projection. The output dim of the projection layer is same with the vocabulary size. After projection layer, the softmax function is used to make frame-level feature representation be the possibilities of characters. While making model inference, the character possibilities of each frame are input into the CTC decoder to get the final speech recognition results.
The code of Encoder is in:
To got the character possibilities of each frame, the feature represention of each frame output from the encoder are input into a projection layer which is implemented as a dense layer to do feature projection. The output dim of the projection layer is same with the vocabulary size. After projection layer, the softmax function is used to transform the frame-level feature representation be the possibilities of characters. While making model inference, the character possibilities of each frame are input into the CTC decoder to get the final speech recognition results.
The code of the decoder is in:
```
# The code of constructing the decoder in model
vi deepspeech/models/ds2_online/deepspeech2.py
# The code of CTC Decoder
vi deepspeech/modules/ctc.py
```
@ -119,7 +123,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
```
By using the command above, the training process can be started. There are 5 stages in run.sh, and the first 3 stages are used for training process. The stage 0 is used for data preparation, in which the dataset will be downloaded, and the manifest files of the datasets, vocabulary dictionary and CMVN file will be generated in "./data/". The stage 1 is used for training the model, the log files and model checkpoint is saved in "exp/deepspeech2_online/". The stage 2 is used to generated final model for predicting by averaging the top-k model parameters based on validation loss.
By using the command above, the training process can be started. There are 5 stages in "run.sh", and the first 3 stages are used for training process. The stage 0 is used for data preparation, in which the dataset will be downloaded, and the manifest files of the datasets, vocabulary dictionary and CMVN file will be generated in "./data/". The stage 1 is used for training the model, the log files and model checkpoint is saved in "exp/deepspeech2_online/". The stage 2 is used to generated final model for predicting by averaging the top-k model parameters based on validation loss.
## Testing Process
Using the command below, you can test the deepspeech2 online model.
@ -152,7 +157,7 @@ After the training process, we use stage 3,4,5 for testing process. The stage 3
## Non-Streaming
The deepspeech2 offline model is similarity to the deepspeech2 online model. The main difference between them is the offline model use the bi-directional rnn layers while the online model use the single direction rnn layers and the fc layer is not used.
The deepspeech2 offline model is similarity to the deepspeech2 online model. The main difference between them is the offline model use the stacked bi-directional rnn layers while the online model use the single direction rnn layers and the fc layer is not used. For the stacked bi-directional rnn layers in the offline model, the rnn cell and gru cell are provided to use.
The arcitecture of the model is shown in Fig.2.
<p align="center">
@ -162,9 +167,9 @@ The arcitecture of the model is shown in Fig.2.
For data preparation, decoder, the deepspeech2 offline model is same with the deepspeech2 online model.
For data preparation and decoder, the deepspeech2 offline model is same with the deepspeech2 online model.
The code of encoder and decoder for deepspeech2 offline model is in:
The code of encoder and decoder for deepspeech2 offline model is in:
```
vi deepspeech/models/ds2/deepspeech2.py
```

@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin
## Prerequisites
- Python >= 3.7
- PaddlePaddle 2.0.0 or later (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
- PaddlePaddle latest version (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
## Setup
## Setup (Important)
- Make sure these libraries or tools installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox, and `swig`, e.g. installing them via `apt-get`:
```bash
sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
```
The version of `swig` should >= 3.0
or, installing them via `yum`:

@ -1,5 +1,7 @@
# Reference
We refer these repos to build `model` and `engine`:
* [delta](https://github.com/Delta-ML/delta.git)
* [espnet](https://github.com/espnet/espnet.git)
* [kaldi](https://github.com/kaldi-asr/kaldi.git)

@ -19,6 +19,7 @@ fi
mkdir -p exp
# seed may break model convergence
seed=10086
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True

@ -1,37 +1,49 @@
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
echo "using ${device}..."
mkdir -p exp
seed=10086
if [ ${seed} != 0]; then
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
config_path=$1
ckpt_name=$2
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--seed ${seed} \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
if [ ${seed} != 0 ]; then
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi

@ -19,8 +19,9 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
if [ ${seed} != 0]; then
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi

@ -0,0 +1,58 @@
# [CC-CEDICT](https://cc-cedict.org/wiki/)
What is CC-CEDICT?
CC-CEDICT is a continuation of the CEDICT project.
The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary.
CEDICT was started by Paul Andrew Denisowski in October 1997.
For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort,
with users providing entries and corrections to the main file.
## Parse CC-CEDICT to Json format
1. Parse to Json
```
run.sh
```
2. Result
```
exp/
|-- cedict
`-- cedict.json
0 directories, 2 files
```
```
4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict
3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json
```
```
==> exp/cedict <==
# CC-CEDICT
# Community maintained free Chinese-English dictionary.
#
# Published by MDBG
#
# License:
# Creative Commons Attribution-ShareAlike 4.0 International License
# https://creativecommons.org/licenses/by-sa/4.0/
#
# Referenced works:
==> exp/cedict.json <==
{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"}
{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"}
{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"}
{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"}
{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"}
{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"}
{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"}
{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"}
{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"}
{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"}
```

@ -1,5 +0,0 @@
# Download Baker dataset
Baker dataset has to be downloaded mannually and moved to 'data/', because you will have to pass the CATTCHA from a browswe to download the dataset.
Download URL https://test.data-baker.com/#/data/index/source.

@ -0,0 +1,3 @@
# G2P
* zh - Chinese G2P

@ -0,0 +1,93 @@
# G2P
* WS
jieba
* G2P
pypinyin
* Tone sandhi
simple
We recommend using [Paraket](https://github.com/PaddlePaddle/Parakeet] [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P.
The phoneme set should be changed, you can reference `examples/thchs30/a0/data/dict/syllable.lexicon`.
## Download Baker dataset
[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded mannually and moved to './data',
because you will have to pass the `CATTCHA` from a browswe to download the dataset.
## RUN
```
. path.sh
./run.sh
```
## Result
```
exp/
|-- 000001-010000.txt
|-- ref.pinyin
|-- trans.jieba.pinyin
`-- trans.pinyin
0 directories, 4 files
```
```
4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt
01707896391c2de9b6fc4a39654be942 exp/ref.pinyin
43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin
8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin
```
```
==> exp/000001-010000.txt <==
000001 卡尔普#2陪外孙#1玩滑梯#4。
ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 假语村言#2别再#1拥抱我#4。
jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 宝马#1配挂#1跛骡鞍#3貂蝉#1怨枕#2董翁榻#4。
bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 邓小平#2与#1撒切尔#2会晤#4。
deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。
lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
==> exp/ref.pinyin <==
000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.jieba.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
==> exp/trans.pinyin <==
000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
```

@ -1,4 +1,4 @@
export MAIN_ROOT=`realpath ${PWD}/../../`
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

@ -6,16 +6,19 @@ stage=-1
stop_stage=100
exp_dir=exp
data_dir=data
data=data
source ${MAIN_ROOT}/utils/parse_options.sh || exit -1
mkdir -p ${exp_dir}
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then
test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in ${data}; exit -1; }
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then
echo "stage 0: Extracting Prosody Labeling"
bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir}
bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data}
fi
# convert transcription in chinese into pinyin with pypinyin or jieba+pypinyin

@ -11,7 +11,7 @@ data:
max_output_input_ratio: .inf
collator:
batch_size: 15
batch_size: 20
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
@ -45,7 +45,7 @@ model:
training:
n_epoch: 50
accum_grad: 4
accum_grad: 1
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06

@ -20,7 +20,8 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi

@ -19,17 +19,17 @@
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
"replace_with_zero": true
},
"prob": 1.0
}

@ -33,7 +33,7 @@ collator:
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
num_workers: 2
# network architecture
@ -74,7 +74,7 @@ model:
training:
n_epoch: 120
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam

@ -19,7 +19,8 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
@ -31,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} != 0]; then
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi

@ -19,7 +19,8 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi

@ -0,0 +1,3 @@
# Ngram LM
* s0 - kenlm ngram lm

@ -2,6 +2,95 @@
Train chinese chararctor ngram lm by [kenlm](https://github.com/kpu/kenlm).
## Run
```
. path.sh
bash run.sh
```
## Results
```
exp/
|-- text
|-- text.char.tn
|-- text.word.tn
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
|-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
`-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
0 directories, 7 files
```
```
3ae083627b9b6cef1a82d574d8483f97 exp/text
d97da252d2a63a662af22f98af30cb8c exp/text.char.tn
c18b03005bd094dbfd9b46442be361fd exp/text.word.tn
73dbf50097896eda33985e11e1ba9a3a exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa
01334e2044c474b99c4f2ffbed790626 exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
36a42de548045b54662411ae7982c77f exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa
332422803ffd73dd7ffd16cd2b0abcd5 exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin
```
```
==> exp/text <==
少先队员因该为老人让坐
祛痘印可以吗?有效果吗?
不知这款牛奶口感怎样? 小孩子喝行吗!
是转基因油?
我家宝宝13斤用多大码的
会起坨吗?
请问给送上楼吗?
亲是送赁上门吗
送货时候有外包装没有还是直接发货过来
会不会有坏的?
==> exp/text.char.tn <==
少 先 队 员 因 该 为 老 人 让 坐
祛 痘 印 可 以 吗 有 效 果 吗
不 知 这 款 牛 奶 口 感 怎 样 小 孩 子 喝 行 吗
是 转 基 因 油
我 家 宝 宝 十 三 斤 用 多 大 码 的
会 起 坨 吗
请 问 给 送 上 楼 吗
亲 是 送 赁 上 门 吗
送 货 时 候 有 外 包 装 没 有 还 是 直 接 发 货 过 来
会 不 会 有 坏 的
==> exp/text.word.tn <==
少先队员 因该 为 老人 让 坐
祛痘 印 可以 吗 有 效果 吗
不知 这 款 牛奶 口感 怎样 小孩子 喝行 吗
是 转基因 油
我家 宝宝 十三斤 用多大码 的
会起 坨 吗
请问 给 送 上楼 吗
亲是 送赁 上门 吗
送货 时候 有 外包装 没有 还是 直接 发货 过来
会 不会 有坏 的
==> exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa <==
\data\
ngram 1=587
ngram 2=395
ngram 3=100
ngram 4=2
ngram 5=0
\1-grams:
-3.272324 <unk> 0
0 <s> -0.36706257
==> exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa <==
\data\
ngram 1=689
ngram 2=1398
ngram 3=1506
\1-grams:
-3.1755018 <unk> 0
0 <s> -0.23069073
-1.2318869 </s> 0
-3.067262 少先队员 -0.051341705
```

@ -1,7 +1,96 @@
# [SentencePiece Model](https://github.com/google/sentencepiece)
## Run
Train a `spm` model for English tokenizer.
```
. path.sh
bash run.sh
```
## Results
```
data/
└── lang_char
├── input.bpe
├── input.decode
├── input.txt
├── train_unigram100.model
├── train_unigram100_units.txt
└── train_unigram100.vocab
1 directory, 6 files
```
```
b5a230c26c61db5c36f34e503102f936 data/lang_char/input.bpe
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.decode
ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.txt
124bf3fe7ce3b73b1994234c15268577 data/lang_char/train_unigram100.model
0df2488cc8eaace95eb12713facb5cf0 data/lang_char/train_unigram100_units.txt
46360cac35c751310e8e8ffd3a034cb5 data/lang_char/train_unigram100.vocab
```
```
==> data/lang_char/input.bpe <==
▁mi ster ▁quilter ▁ is ▁the ▁a p ost le ▁o f ▁the ▁mi d d le ▁c las s es ▁ and ▁we ▁ar e ▁g l a d ▁ to ▁we l c om e ▁h is ▁g o s pe l
▁ n or ▁ is ▁mi ster ▁quilter ' s ▁ma nne r ▁ l ess ▁in ter es t ing ▁tha n ▁h is ▁ma t ter
▁h e ▁ t e ll s ▁us ▁tha t ▁ at ▁ t h is ▁f es t ive ▁ s e ason ▁o f ▁the ▁ y e ar ▁w ith ▁ ch r is t m a s ▁ and ▁ro a s t ▁be e f ▁ l o om ing ▁be fore ▁us ▁ s i mile s ▁d r a w n ▁f r om ▁ e at ing ▁ and ▁it s ▁re s u l t s ▁o c c ur ▁m ost ▁re a di l y ▁ to ▁the ▁ mind
▁h e ▁ ha s ▁g r a v e ▁d o u b t s ▁w h e t h er ▁ s i r ▁f r e d er ic k ▁ l eig h to n ' s ▁w or k ▁ is ▁re all y ▁gre e k ▁a f ter ▁ all ▁ and ▁c a n ▁di s c o v er ▁in ▁it ▁b u t ▁li t t le ▁o f ▁ro ck y ▁it ha c a
▁li nne ll ' s ▁ p ic tur es ▁ar e ▁a ▁ s or t ▁o f ▁ u p ▁g u ar d s ▁ and ▁ at ▁ em ▁painting s ▁ and ▁m ason ' s ▁ e x q u is i t e ▁ i d y ll s ▁ar e ▁a s ▁ n at ion a l ▁a s ▁a ▁ j ing o ▁ p o em ▁mi ster ▁b i r k e t ▁f o ster ' s ▁ l and s c a pe s ▁ s mile ▁ at ▁on e ▁m u ch ▁in ▁the ▁ s a m e ▁w a y ▁tha t ▁mi ster ▁c ar k er ▁us e d ▁ to ▁f las h ▁h is ▁ t e e t h ▁ and ▁mi ster ▁ j o h n ▁c o ll i er ▁g ive s ▁h is ▁ s i t ter ▁a ▁ ch e er f u l ▁ s l a p ▁on ▁the ▁b a ck ▁be fore ▁h
e ▁ s a y s ▁li k e ▁a ▁ s ha m p o o er ▁in ▁a ▁ tur k is h ▁b at h ▁ n e x t ▁ma n
▁it ▁ is ▁o b v i o u s l y ▁ u nne c ess ar y ▁for ▁us ▁ to ▁ p o i n t ▁o u t ▁h o w ▁ l u m i n o u s ▁the s e ▁c rit ic is m s ▁ar e ▁h o w ▁d e l ic at e ▁in ▁ e x p r ess ion
▁on ▁the ▁g e n er a l ▁ p r i n c i p l es ▁o f ▁ar t ▁mi ster ▁quilter ▁w rit es ▁w ith ▁ e qual ▁ l u c i di t y
▁painting ▁h e ▁ t e ll s ▁us ▁ is ▁o f ▁a ▁di f f er e n t ▁ qual i t y ▁ to ▁ma t h em at ic s ▁ and ▁f i nish ▁in ▁ar t ▁ is ▁a d d ing ▁m or e ▁f a c t
▁a s ▁for ▁ e t ch ing s ▁the y ▁ar e ▁o f ▁ t w o ▁ k i n d s ▁b rit is h ▁ and ▁for eig n
▁h e ▁ l a ment s ▁m ost ▁b i t ter l y ▁the ▁di v or c e ▁tha t ▁ ha s ▁be e n ▁ma d e ▁be t w e e n ▁d e c or at ive ▁ar t ▁ and ▁w ha t ▁we ▁us u all y ▁c all ▁ p ic tur es ▁ma k es ▁the ▁c u s t om ar y ▁a p pe a l ▁ to ▁the ▁ las t ▁ j u d g ment ▁ and ▁re mind s ▁us ▁tha t ▁in ▁the ▁gre at ▁d a y s ▁o f ▁ar t ▁mi c ha e l ▁a n g e l o ▁w a s ▁the ▁f ur nish ing ▁ u p h o l ster er
==> data/lang_char/input.decode <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/input.txt <==
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter's manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man
it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression
on the general principles of art mister quilter writes with equal lucidity
painting he tells us is of a different quality to mathematics and finish in art is adding more fact
as for etchings they are of two kinds british and foreign
he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer
==> data/lang_char/train_unigram100_units.txt <==
<blank> 0
<unk> 1
' 2
a 3
all 4
and 5
ar 6
ason 7
at 8
b 9
==> data/lang_char/train_unigram100.vocab <==
<unk> 0
<s> 0
</s> 0
▁ -2.01742
e -2.7203
s -2.82989
t -2.99689
l -3.53267
n -3.84935
o -3.88229
```

@ -19,7 +19,8 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
@ -31,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \
--output exp/${ckpt_name} \
--seed ${seed}
if [ ${seed} != 0 ]; then
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi

@ -1,3 +0,0 @@
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example.

@ -19,7 +19,8 @@ echo "using ${device}..."
mkdir -p exp
seed=10086
# seed may break model convergence
seed=0
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi

@ -48,7 +48,7 @@ training:
n_epoch: 10
accum_grad: 1
lr: 1e-5
lr_decay: 1.0
lr_decay: 0.8
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1

@ -1,35 +1,44 @@
#!/bin/bash
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
profiler_options=
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
mkdir -p exp
seed=10086
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
config_path=$1
ckpt_name=$2
model_type=$3
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--model_type ${model_type} \
--profiler-options "${profiler_options}" \
--seed ${seed}
if [ ${seed} != 0 ]; then

@ -1,33 +0,0 @@
#!/bin/bash
if [ $# != 1 ];then
echo "usage: tune ckpt_path"
exit 1
fi
# grid-search for hyper-parameters in language model
python3 -u ${BIN_DIR}/tune.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--num_batches=-1 \
--batch_size=128 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--checkpoint_path ${1}
if [ $? -ne 0 ]; then
echo "Failed in tuning!"
exit 1
fi
exit 0

@ -19,17 +19,17 @@
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true,
"warp_mode": "PIL"
"replace_with_zero": true
},
"prob": 1.0
}

@ -1,36 +1,49 @@
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
mkdir -p exp
seed=10086
if [ ${seed} != 0 ]; then
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
exit -1
fi
config_path=$1
ckpt_name=$2
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--seed ${seed} \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
if [ ${seed} != 0 ]; then
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi

@ -0,0 +1 @@
exp

@ -0,0 +1,36 @@
# Regular expression based text normalization for Chinese
For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example.
## Run
```
. path.sh
bash run.sh
```
## Results
```
exp/
`-- normalized.txt
0 directories, 1 file
```
```
aff31f8aa08e2a7360228c9ce5886b98 exp/normalized.txt
```
```
今天的最低气温达到零下十度.
只要有四分之三十三的人同意,就可以通过决议。
一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
如果剩下的百分之三十点六是过去,那么还有百分之六十九点四.
事情发生在二零二零年三月三十一日的上午八点.
警方正在找一支点二二口径的手枪。
欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务
充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。
快速解除流量封顶请按星号键腾讯王卡产品介绍、使用说明、特权及活动请按九查询话费、套餐余量、积分及活动返款请按一手机上网流量开通及取消请按二<EFBFBD><EFBFBD><EFBFBD>本机号码及本号所使用套餐请按四密码修改及重置请按五紧急开机请按六挂失请按七查询充值记录请按八其它自助服务及工服务请按零
```

@ -0,0 +1,2 @@
old-pd_env.txt
pd_env.txt

@ -0,0 +1,11 @@
# Benchmark Test
## Data
* Aishell
## Docker
```
registry.baidubce.com/paddlepaddle/paddle 2.1.1-gpu-cuda10.2-cudnn7 59d5ec1de486
```

@ -0,0 +1,49 @@
#!/bin/bash
CUR_DIR=${PWD}
ROOT_DIR=../../
# 提供可稳定复现性能的脚本默认在标准docker环境内py37执行
# collect env info
bash ${ROOT_DIR}/utils/pd_env_collect.sh
#cat pd_env.txt
# 1 安装该模型需要的依赖 (如需开启优化策略请注明)
#pushd ${ROOT_DIR}/tools; make; popd
#source ${ROOT_DIR}/tools/venv/bin/activate
#pushd ${ROOT_DIR}; bash setup.sh; popd
# 2 拷贝该模型需要数据、预训练模型
# 执行目录:需说明
#pushd ${ROOT_DIR}/examples/aishell/s1
pushd ${ROOT_DIR}/examples/tiny/s1
mkdir -p exp/log
. path.sh
#bash local/data.sh &> exp/log/data.log
# 3 批量运行如不方便批量12需放到单个模型中
model_mode_list=(conformer transformer)
fp_item_list=(fp32)
bs_item_list=(32 64 96)
for model_mode in ${model_mode_list[@]}; do
for fp_item in ${fp_item_list[@]}; do
for bs_item in ${bs_item_list[@]}
do
echo "index is speed, 1gpus, begin, ${model_name}"
run_mode=sp
CUDA_VISIBLE_DEVICES=0 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode} # (5min)
sleep 60
echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_name}"
run_mode=mp
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
sleep 60
done
done
done
popd # aishell/s1

@ -0,0 +1,57 @@
#!/bin/bash
set -xe
# 运行示例CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
# 参数说明
function _set_params(){
run_mode=${1:-"sp"} # 单卡sp|多卡mp
batch_size=${2:-"64"}
fp_item=${3:-"fp32"} # fp32|fp16
max_iter=${4:-"500"} # 可选,如果需要修改代码提前中断
model_name=${5:-"model_name"}
run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR 后续QA设置该参数
# 以下不用修改
device=${CUDA_VISIBLE_DEVICES//,/ }
arr=(${device})
num_gpu_devices=${#arr[*]}
log_file=${run_log_path}/${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
}
function _train(){
echo "Train on ${num_gpu_devices} GPUs"
echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
train_cmd="--benchmark-batch-size ${batch_size}
--benchmark-max-step ${max_iter}
conf/${model_name}.yaml ${model_name}"
case ${run_mode} in
sp) train_cmd="bash local/train.sh "${train_cmd}"" ;;
mp)
train_cmd="bash local/train.sh "${train_cmd}"" ;;
*) echo "choose run_mode(sp or mp)"; exit 1;
esac
# 以下不用修改
timeout 15m ${train_cmd} > ${log_file} 2>&1
if [ $? -ne 0 ];then
echo -e "${model_name}, FAIL"
export job_fail_flag=1
else
echo -e "${model_name}, SUCCESS"
export job_fail_flag=0
fi
trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM
if [ $run_mode = "mp" -a -d mylog ]; then
rm ${log_file}
cp mylog/workerlog.0 ${log_file}
fi
}
_set_params $@
_train

@ -1,146 +0,0 @@
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
def frame(x: Tensor,
num_samples: Tensor,
win_length: int,
hop_length: int,
clip: bool = True) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (N, T), batched waveform.
num_samples : Tensor
Shape (N, ), number of samples of each waveform.
win_length : int
Window length.
hop_length : int
Number of samples shifted between ajancent frames.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default True
Returns
-------
frames : Tensor
Shape (N, T', win_length).
num_frames : Tensor
Shape (N, ) number of valid frames
"""
assert hop_length <= win_length
num_frames = (num_samples - win_length) // hop_length
padding = (0, 0)
if not clip:
num_frames += 1
# NOTE: pad hop_length - 1 to the right to ensure that there is at most
# one frame dangling to the righe edge
padding = (0, hop_length - 1)
weight = paddle.eye(win_length).unsqueeze(1)
frames = F.conv1d(x.unsqueeze(1),
weight,
padding=padding,
stride=(hop_length, ))
return frames, num_frames
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
Parameters
------------
n_fft : int
Number of samples in a frame.
hop_length : int
Number of samples shifted between adjacent frames.
win_length : int
Length of the window.
clip: bool
Whether to clip audio is necesaary.
"""
def __init__(self,
n_fft: int,
hop_length: int,
win_length: int,
window_type: str = None,
clip: bool = True):
super().__init__()
self.hop_length = hop_length
self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft
self.clip = clip
# calculate window
if window_type is None:
window = np.ones(win_length)
elif window_type == "hann":
window = np.hanning(win_length)
elif window_type == "hamming":
window = np.hamming(win_length)
else:
raise ValueError("Not supported yet!")
if win_length < n_fft:
window = F.pad(window, (0, n_fft - win_length))
elif win_length > n_fft:
window = window[:n_fft]
# (n_bins, n_fft) complex
kernel_size = min(n_fft, win_length)
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
w = np.expand_dims(w, 1)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor
Number of samples of each waveform.
Returns
------------
D : Tensor
Shape(N, T', n_bins, 2) Spectrogram.
num_frames: Tensor
Shape (N,) number of samples of each spectrogram
"""
num_frames = (num_samples - self.win_length) // self.hop_length
padding = (0, 0)
if not self.clip:
num_frames += 1
padding = (0, self.hop_length - 1)
batch_size, _, _ = paddle.shape(x)
x = x.unsqueeze(-1)
D = F.conv1d(self.weight,
x,
stride=(self.hop_length, ),
padding=padding,
data_format="NLC")
D = paddle.reshape(D, [batch_size, -1, self.n_bin, 2])
return D, num_frames

@ -0,0 +1,201 @@
import paddle
import numpy as np
from typing import Tuple, Optional, Union
# https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/src/feat/feature-window.cc#L109
def povey_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = (0.5 - 0.5 * np.cos(a * i) )**0.85
return win
def hann_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.5 - 0.5 * np.cos(a * i)
return win
def sine_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = np.sin(0.5 * a * i)
return win
def hamm_window(frame_len:int) -> np.ndarray:
win = np.empty(frame_len)
a = 2 * np.pi / (frame_len -1)
for i in range(frame_len):
win[i] = 0.54 - 0.46 * np.cos(a * i)
return win
def get_window(wintype:Optional[str], winlen:int) -> np.ndarray:
"""get window function
Args:
wintype (Optional[str]): window type.
winlen (int): window length in samples.
Raises:
ValueError: not support window.
Returns:
np.ndarray: window coeffs.
"""
# calculate window
if not wintype or wintype == 'rectangular':
window = np.ones(winlen)
elif wintype == "hann":
window = hann_window(winlen)
elif wintype == "hamm":
window = hamm_window(winlen)
elif wintype == "povey":
window = povey_window(winlen)
else:
msg = f"{wintype} Not supported yet!"
raise ValueError(msg)
return window
def dft_matrix(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# https://en.wikipedia.org/wiki/Discrete_Fourier_transform
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
n = np.arange(0, n_fft, 1.)
wsin = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
wcos = np.empty((n_bin, kernel_size)) #[Cout, kernel_size]
for k in range(n_bin): # Only half of the bins contain useful info
wsin[k,:] = -np.sin(2*np.pi*k*n/n_fft)[:kernel_size]
wcos[k,:] = np.cos(2*np.pi*k*n/n_fft)[:kernel_size]
w_real = wcos
w_imag = wsin
return w_real, w_imag, kernel_size
def dft_matrix_fast(n_fft:int, winlen:int=None, n_bin:int=None) -> Tuple[np.ndarray, np.ndarray, int]:
# (n_bins, n_fft) complex
if n_bin is None:
n_bin = 1 + n_fft // 2
if winlen is None:
winlen = n_bin
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/fft/_pocketfft.py#L49
kernel_size = min(n_fft, winlen)
# https://en.wikipedia.org/wiki/DFT_matrix
# https://ccrma.stanford.edu/~jos/st/Matrix_Formulation_DFT.html
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin, :kernel_size]
w_real = weight.real
w_imag = weight.imag
return w_real, w_imag, kernel_size
def bin2hz(bin:Union[List[int], np.ndarray], N:int, sr:int)->List[float]:
"""FFT bins to Hz.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Args:
bins (List[int] or np.ndarray): bin index.
N (int): the number of samples, or FFT points.
sr (int): sampling rate.
Returns:
List[float]: Hz's.
"""
hz = bin * float(sr) / N
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identical sized array is returned.
"""
return 1127 * np.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identical sized array is returned.
"""
return 700 * (np.exp(mel/1127.0)-1)
def rms_to_db(rms: float):
"""Root Mean Square to dB.
Args:
rms ([float]): root mean square
Returns:
float: dB
"""
return 20.0 * math.log10(max(1e-16, rms))
def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.
dB = dBFS + 3.0103
dBFS = db - 3.0103
e.g. 0 dB = -3.0103 dBFS
Args:
rms ([float]): root mean square
Returns:
float: dBFS
"""
return rms_to_db(rms) - 3.0103
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
def gain_db_to_ratio(gain_db: float):
"""dB to ratio
Args:
gain_db (float): gain in dB
Returns:
float: scale in amp
"""
return math.pow(10.0, gain_db / 20.0)

Binary file not shown.

@ -0,0 +1,266 @@
from typing import Tuple
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
import soundfile as sf
from .common import get_window
from .common import dft_matrix
def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2d=True)->Tuple[int, np.ndarray]:
"""load wav file.
Args:
wavpath (str): wav path.
sr (int, optional): expect sample rate. Defaults to None.
dtype (str, optional): wav data bits. Defaults to 'int16'.
Returns:
Tuple[int, np.ndarray]: sr (int), wav (int16) [T, C].
"""
wav, r_sr = sf.read(wavpath, start=start, stop=stop, dtype=dtype, always_2d=always_2d)
if sr:
assert sr == r_sr
return r_sr, wav
def write(wavpath:str, wav:np.ndarray, sr:int, dtype='PCM_16'):
"""write wav file.
Args:
wavpath (str): file path to save.
wav (np.ndarray): wav data.
sr (int): data samplerate.
dtype (str, optional): wav bit format. Defaults to 'PCM_16'.
"""
sf.write(wavpath, wav, sr, subtype=dtype)
def frames(x: Tensor,
num_samples: Tensor,
sr: int,
win_length: float,
stride_length: float,
clip: bool = False) -> Tuple[Tensor, Tensor]:
"""Extract frames from audio.
Parameters
----------
x : Tensor
Shape (B, T), batched waveform.
num_samples : Tensor
Shape (B, ), number of samples of each waveform.
sr: int
Sampling Rate.
win_length : float
Window length in ms.
stride_length : float
Stride length in ms.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
default True
Returns
-------
frames : Tensor
Shape (B, T', win_length).
num_frames : Tensor
Shape (B, ) number of valid frames
"""
assert stride_length <= win_length
stride_length = int(stride_length * sr)
win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length
padding = (0, 0)
if not clip:
num_frames += 1
need_samples = num_frames * stride_length + win_length
padding = (0, need_samples - num_samples - 1)
weight = paddle.eye(win_length).unsqueeze(1) #[win_length, 1, win_length]
frames = F.conv1d(x.unsqueeze(-1),
weight,
padding=padding,
stride=(stride_length, ),
data_format='NLC')
return frames, num_frames
def dither(signal:Tensor, dither_value=1.0)->Tensor:
"""dither frames for log compute.
Args:
signal (Tensor): [B, T, D]
dither_value (float, optional): [scalar]. Defaults to 1.0.
Returns:
Tensor: [B, T, D]
"""
D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal
def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc.
Args:
signal (Tensor): [B, T, D]
Returns:
Tensor: [B, T, D]
"""
signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
"""perform preemphasis on the input signal.
Args:
signal (Tensor): [B, T, D], The signal to filter.
coeff (float, optional): [scalar].The preemphasis coefficient. 0 is no filter, Defaults to 0.97.
Returns:
Tensor: [B, T, D]
"""
return paddle.concat([
(1-coeff)*signal[:, :, 0:1],
signal[:, :, 1:] - coeff * signal[:, :, :-1]
], axis=-1)
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters
------------
n_fft : int
Number of samples in a frame.
sr: int
Number of Samplilng rate.
stride_length : float
Number of samples shifted between adjacent frames.
win_length : float
Length of the window.
clip: bool
Whether to clip audio is necesaary.
"""
def __init__(self,
n_fft: int,
sr: int,
win_length: float,
stride_length: float,
dither:float=0.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
clip: bool = False):
super().__init__()
self.sr = sr
self.win_length = win_length
self.stride_length = stride_length
self.dither = dither
self.preemph_coeff = preemph_coeff
self.remove_dc_offset = remove_dc_offset
self.window_type = window_type
self.clip = clip
self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin
)
# calculate window
window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (kernel_size, 2 * n_bins)
w = np.transpose(w)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
Parameters
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor [shape=(B,)]
Number of samples of each waveform.
Returns
------------
C : Tensor
Shape(B, T', n_bins, 2) Spectrogram.
num_frames: Tensor
Shape (B,) number of samples of each spectrogram
"""
batch_size = paddle.shape(num_samples)
F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither:
F = dither(F, self.dither)
if self.remove_dc_offset:
F = remove_dc_offset(F)
if self.preemph_coeff:
F = preemphasis(F)
C = paddle.matmul(F, self.weight) # [B, T, K] [K, 2 * n_bins]
C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin])
C = C.transpose([0, 1, 3, 2])
return C, nframe
def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum |X_k|^2.
Args:
C (Tensor): [B, T, C, 2]
Returns:
Tensor: [B, T, C]
"""
real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum |X_k|.
Args:
C (Tensor): [B, T, C, 2]
eps (float): epsilon.
Returns:
Tensor: [B, T, C]
"""
pspec = powspec(C)
return paddle.sqrt(pspec + eps)
def logspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute log-spectrum 20log10X_k.
Args:
C (Tensor): [description]
eps ([type], optional): [description]. Defaults to 1e-10.
Returns:
Tensor: [description]
"""
spec = magspec(C)
return 20 * paddle.log10(spec + eps)

@ -0,0 +1,533 @@
from typing import Tuple
import numpy as np
import paddle
import unittest
import decimal
import numpy
import math
import logging
from pathlib import Path
from scipy.fftpack import dct
from third_party.paddle_audio.frontend import kaldi
def round_half_up(number):
return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))
def rolling_window(a, window, step=1):
# http://ellisvalentiner.com/post/2017-03-21-np-strides-trick
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]
def do_dither(signal, dither_value=1.0):
signal += numpy.random.normal(size=signal.shape) * dither_value
return signal
def do_remove_dc_offset(signal):
signal -= numpy.mean(signal)
return signal
def do_preemphasis(signal, coeff=0.97):
"""perform preemphasis on the input signal.
:param signal: The signal to filter.
:param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.
:returns: the filtered signal.
"""
return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])
def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param winfunc: the analysis window to apply to each frame. By default no window is applied.
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: an array of frames. Size is NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype is 'povey':
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = numpy.zeros(frames.shape)
for frm in range(frames.shape[0]):
frames[frm,:] = do_dither(frames[frm,:], dither) # dither
frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset
raw_frames[frm,:] = frames[frm,:]
frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize
return frames * win, raw_frames
def magspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warn(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return numpy.absolute(complex_spec)
def powspec(frames, NFFT):
"""Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.
"""
return numpy.square(magspec(frames, NFFT))
def mfcc(signal,samplerate=16000,winlen=0.025,winstep=0.01,numcep=13,
nfilt=23,nfft=512,lowfreq=20,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,
ceplifter=22,useEnergy=True,wintype='povey'):
"""Compute MFCC features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param numcep: the number of cepstrum to return, default 13
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param ceplifter: apply a lifter to final cepstral coefficients. 0 is no lifter. Default is 22.
:param appendEnergy: if this is true, the zeroth cepstral coefficient is replaced with the log of the total frame energy.
:param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming
:returns: A numpy array of size (NUMFRAMES by numcep) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither,remove_dc_offset,preemph,wintype)
feat = numpy.log(feat)
feat = dct(feat, type=2, axis=1, norm='ortho')[:,:numcep]
feat = lifter(feat,ceplifter)
if useEnergy: feat[:,0] = numpy.log(energy) # replace first cepstral coefficient with log of frame energy
return feat
def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
"""Compute Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:param winfunc: the analysis window to apply to each frame. By default no window is applied. You can use numpy window functions here e.g. winfunc=numpy.hamming
winfunc=lambda x:numpy.ones((x,))
:returns: 2 values. The first is a numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector. The
second return value is the energy in each frame (total energy, unwindowed)
"""
highfreq= highfreq or samplerate/2
frames,raw_frames = sigproc.framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)
pspec = sigproc.powspec(frames,nfft) # nearly the same until this part
energy = numpy.sum(raw_frames**2,1) # this stores the raw energy in each frame
energy = numpy.where(energy == 0,numpy.finfo(float).eps,energy) # if energy is zero, we get problems with log
fb = get_filterbanks(nfilt,nfft,samplerate,lowfreq,highfreq)
feat = numpy.dot(pspec,fb.T) # compute the filterbank energies
feat = numpy.where(feat == 0,numpy.finfo(float).eps,feat) # if feat is zero, we get problems with log
return feat,energy
def logfbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=64,highfreq=None,dither=1.0,remove_dc_offset=True,preemph=0.97,wintype='hamming'):
"""Compute log Mel-filterbank energy features from an audio signal.
:param signal: the audio signal from which to compute features. Should be an N*1 array
:param samplerate: the samplerate of the signal we are working with.
:param winlen: the length of the analysis window in seconds. Default is 0.025s (25 milliseconds)
:param winstep: the step between successive windows in seconds. Default is 0.01s (10 milliseconds)
:param nfilt: the number of filters in the filterbank, default 26.
:param nfft: the FFT size. Default is 512.
:param lowfreq: lowest band edge of mel filters. In Hz, default is 0.
:param highfreq: highest band edge of mel filters. In Hz, default is samplerate/2
:param preemph: apply preemphasis filter with preemph as coefficient. 0 is no filter. Default is 0.97.
:returns: A numpy array of size (NUMFRAMES by nfilt) containing features. Each row holds 1 feature vector.
"""
feat,energy = fbank(signal,samplerate,winlen,winstep,nfilt,nfft,lowfreq,highfreq,dither, remove_dc_offset,preemph,wintype)
return numpy.log(feat)
def hz2mel(hz):
"""Convert a value in Hertz to Mels
:param hz: a value in Hz. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Mels. If an array was passed in, an identical sized array is returned.
"""
return 1127 * numpy.log(1+hz/700.0)
def mel2hz(mel):
"""Convert a value in Mels to Hertz
:param mel: a value in Mels. This can also be a numpy array, conversion proceeds element-wise.
:returns: a value in Hertz. If an array was passed in, an identical sized array is returned.
"""
return 700 * (numpy.exp(mel/1127.0)-1)
def get_filterbanks(nfilt=26,nfft=512,samplerate=16000,lowfreq=0,highfreq=None):
"""Compute a Mel-filterbank. The filters are stored in the rows, the columns correspond
to fft bins. The filters are returned as an array of size nfilt * (nfft/2 + 1)
:param nfilt: the number of filters in the filterbank, default 20.
:param nfft: the FFT size. Default is 512.
:param samplerate: the samplerate of the signal we are working with. Affects mel spacing.
:param lowfreq: lowest band edge of mel filters, default 0 Hz
:param highfreq: highest band edge of mel filters, default samplerate/2
:returns: A numpy array of size nfilt * (nfft/2 + 1) containing filterbank. Each row holds 1 filter.
"""
highfreq= highfreq or samplerate/2
assert highfreq <= samplerate/2, "highfreq is greater than samplerate/2"
# compute points evenly spaced in mels
lowmel = hz2mel(lowfreq)
highmel = hz2mel(highfreq)
# check kaldi/src/feat/Mel-computations.h
fbank = numpy.zeros([nfilt,nfft//2+1])
mel_freq_delta = (highmel-lowmel)/(nfilt+1)
for j in range(0,nfilt):
leftmel = lowmel+j*mel_freq_delta
centermel = lowmel+(j+1)*mel_freq_delta
rightmel = lowmel+(j+2)*mel_freq_delta
for i in range(0,nfft//2):
mel=hz2mel(i*samplerate/nfft)
if mel>leftmel and mel<rightmel:
if mel<centermel:
fbank[j,i]=(mel-leftmel)/(centermel-leftmel)
else:
fbank[j,i]=(rightmel-mel)/(rightmel-centermel)
return fbank
def lifter(cepstra, L=22):
"""Apply a cepstral lifter the the matrix of cepstra. This has the effect of increasing the
magnitude of the high frequency DCT coeffs.
:param cepstra: the matrix of mel-cepstra, will be numframes * numcep in size.
:param L: the liftering coefficient to use. Default is 22. L <= 0 disables lifter.
"""
if L > 0:
nframes,ncoeff = numpy.shape(cepstra)
n = numpy.arange(ncoeff)
lift = 1 + (L/2.)*numpy.sin(numpy.pi*n/L)
return lift*cepstra
else:
# values of L <= 0, do nothing
return cepstra
def delta(feat, N):
"""Compute delta features from a feature vector sequence.
:param feat: A numpy array of size (NUMFRAMES by number of features) containing features. Each row holds 1 feature vector.
:param N: For each frame, calculate delta features based on preceding and following N frames
:returns: A numpy array of size (NUMFRAMES by number of features) containing delta features. Each row holds 1 delta feature vector.
"""
if N < 1:
raise ValueError('N must be an integer >= 1')
NUMFRAMES = len(feat)
denominator = 2 * sum([i**2 for i in range(1, N+1)])
delta_feat = numpy.empty_like(feat)
padded = numpy.pad(feat, ((N, N), (0, 0)), mode='edge') # padded version of feat
for t in range(NUMFRAMES):
delta_feat[t] = numpy.dot(numpy.arange(-N, N+1), padded[t : t+2*N+1]) / denominator # [t : t+2*N+1] == [(N+t)-N : (N+t)+N+1]
return delta_feat
##### modify for test ######
def framesig_without_dither_dc_preemphasize(sig, frame_len, frame_step, wintype='hamming', stride_trick=True):
"""Frame a signal into overlapping frames.
:param sig: the audio signal to frame.
:param frame_len: length of each frame measured in samples.
:param frame_step: number of samples after the start of the previous frame that the next frame should begin.
:param winfunc: the analysis window to apply to each frame. By default no window is applied.
:param stride_trick: use stride trick to compute the rolling window and window multiplication faster
:returns: an array of frames. Size is NUMFRAMES by frame_len.
"""
slen = len(sig)
frame_len = int(round_half_up(frame_len))
frame_step = int(round_half_up(frame_step))
if slen <= frame_len:
numframes = 1
else:
numframes = 1 + (( slen - frame_len) // frame_step)
# check kaldi/src/feat/feature-window.h
padsignal = sig[:(numframes-1)*frame_step+frame_len]
if wintype is 'povey':
win = numpy.empty(frame_len)
for i in range(frame_len):
win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85
elif wintype == '':
win = numpy.ones(frame_len)
elif wintype == 'hann':
win = numpy.hanning(frame_len)
else: # the hamming window
win = numpy.hamming(frame_len)
if stride_trick:
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
else:
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
indices = numpy.array(indices, dtype=numpy.int32)
frames = padsignal[indices]
win = numpy.tile(win, (numframes, 1))
frames = frames.astype(numpy.float32)
raw_frames = frames
return frames * win, raw_frames
def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None, wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
return frames_with_win, raw_frames
def complexspec(frames, NFFT):
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
:param frames: the array of frames. Each row is a frame.
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
"""
if numpy.shape(frames)[1] > NFFT:
logging.warn(
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
numpy.shape(frames)[1], NFFT)
complex_spec = numpy.fft.rfft(frames, NFFT)
return complex_spec
def stft_with_window(signal,samplerate=16000,winlen=0.025,winstep=0.01,
nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97,
wintype='hamming'):
frames_with_win, raw_frames = framesig_without_dither_dc_preemphasize(signal, winlen*samplerate, winstep*samplerate, wintype)
spec = magspec(frames_with_win, nfft) # nearly the same until this part
scomplex = complexspec(frames_with_win, nfft)
rspec = magspec(raw_frames, nfft)
rcomplex = complexspec(raw_frames, nfft)
return spec, scomplex, rspec, rcomplex
class TestKaldiFE(unittest.TestCase):
def setUp(self):
self. this_dir = Path(__file__).parent
self.wavpath = str(self.this_dir / 'english.wav')
self.winlen=0.025 # ms
self.winstep=0.01 # ms
self.nfft=512
self.lowfreq = 0
self.highfreq = None
self.wintype='hamm'
self.nfilt=40
paddle.set_device('cpu')
def test_read(self):
import scipy.io.wavfile as wav
rate, sig = wav.read(self.wavpath)
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
self.assertTrue(np.all(sig == wav))
self.assertEqual(rate, sr)
def test_frames(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
_, fs = frames(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
t_fs, t_nframe = kaldi.frames(t_wav, t_wavlen, sr, self.winlen, self.winstep, clip=False)
t_fs = t_fs.astype(fs.dtype)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1]
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magsepc_winprocess(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
fs, _= framesig(wav, self.winlen*sr, self.winstep*sr,
dither=0.0, preemph=0.97, remove_dc_offset=True, wintype='povey', stride_trick=True)
spec = magspec(fs, self.nfft) # nearly the same until this part
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(
self.nfft, sr, self.winlen, self.winstep,
window_type='povey', dither=0.0, preemph_coeff=0.97, remove_dc_offset=True, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(spec.dtype)
t_spec = kaldi.magspec(t_stft)[0]
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(spec), 1)
self.assertTrue(np.allclose(t_spec.numpy(), spec, atol=1e-1))
def test_powspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
stft_win = np.square(stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, dither=0.0, preemph_coeff=0.0, remove_dc_offset=False, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0]
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))
# from python_speech_features import mfcc
# from python_speech_features import delta
# from python_speech_features import logfbank
# import scipy.io.wavfile as wav
# (rate,sig) = wav.read("english.wav")
# # note that generally nfilt=40 is used for speech recognition
# fbank_feat = logfbank(sig,nfilt=23,lowfreq=20,dither=0,wintype='povey')
# # the computed fbank coefficents of english.wav with dimension [110,23]
# # [ 12.2865 12.6906 13.1765 15.714 16.064 15.7553 16.5746 16.9205 16.6472 16.1302 16.4576 16.7326 16.8864 17.7215 18.88 19.1377 19.1495 18.6683 18.3886 20.3506 20.2772 18.8248 18.1899
# # 11.9198 13.146 14.7215 15.8642 17.4288 16.394 16.8238 16.1095 16.4297 16.6331 16.3163 16.5093 17.4981 18.3429 19.6555 19.6263 19.8435 19.0534 19.001 20.0287 19.7707 19.5852 19.1112
# # ...
# # ...
# # the same with that using kaldi commands: compute-fbank-feats --dither=0.0
# mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')
# # the computed mfcc coefficents of english.wav with dimension [110,13]
# # [ 17.1337 -23.3651 -7.41751 -7.73686 -21.3682 -8.93884 -3.70843 4.68346 -16.0676 12.782 -7.24054 8.25089 10.7292
# # 17.1692 -23.3028 -5.61872 -4.0075 -23.287 -20.6101 -5.51584 -6.15273 -14.4333 8.13052 -0.0345329 2.06274 -0.564298
# # ...
# # ...
# # the same with that using kaldi commands: compute-mfcc-feats --dither=0.0
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,167 @@
#!/usr/bin/env bash
unset GREP_OPTIONS
set -u # Check for undefined variables
die() {
# Print a message and exit with code 1.
#
# Usage: die <error_message>
# e.g., die "Something bad happened."
echo $@
exit 1
}
echo "Collecting system information..."
OUTPUT_FILE=pd_env.txt
python_bin_path=$(which python || which python3 || die "Cannot find Python binary")
{
echo
echo '== check python ==================================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_python.py
import platform
print("""python version: %s
python branch: %s
python build version: %s
python compiler version: %s
python implementation: %s
""" % (
platform.python_version(),
platform.python_branch(),
platform.python_build(),
platform.python_compiler(),
platform.python_implementation(),
))
EOF
${python_bin_path} /tmp/check_python.py 2>&1 >> ${OUTPUT_FILE}
{
echo
echo '== check os platform ==============================================='
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_os.py
import platform
print("""os: %s
os kernel version: %s
os release version: %s
os platform: %s
linux distribution: %s
linux os distribution: %s
mac version: %s
uname: %s
architecture: %s
machine: %s
""" % (
platform.system(),
platform.version(),
platform.release(),
platform.platform(),
platform.linux_distribution(),
platform.dist(),
platform.mac_ver(),
platform.uname(),
platform.architecture(),
platform.machine(),
))
EOF
${python_bin_path} /tmp/check_os.py 2>&1 >> ${OUTPUT_FILE}
{
echo
echo '== are we in docker ============================================='
num=`cat /proc/1/cgroup | grep docker | wc -l`;
if [ $num -ge 1 ]; then
echo "Yes"
else
echo "No"
fi
echo
echo '== compiler ====================================================='
c++ --version 2>&1
echo
echo '== check pips ==================================================='
pip list 2>&1 | grep "proto\|numpy\|paddlepaddle"
echo
echo '== check for virtualenv ========================================='
${python_bin_path} -c "import sys;print(hasattr(sys, \"real_prefix\"))"
echo
echo '== paddlepaddle import ============================================'
} >> ${OUTPUT_FILE}
cat <<EOF > /tmp/check_pd.py
import paddle as pd;
pd.set_device('cpu')
print("pd.version.full_version = %s" % pd.version.full_version)
print("pd.version.commit = %s" % pd.version.commit)
print("pd.__version__ = %s" % pd.__version__)
print("Sanity check: %r" % pd.zeros([1,2,3])[:1])
EOF
${python_bin_path} /tmp/check_pd.py 2>&1 >> ${OUTPUT_FILE}
LD_DEBUG=libs ${python_bin_path} -c "import paddle" 2>>${OUTPUT_FILE} > /tmp/loadedlibs
{
grep libcudnn.so /tmp/loadedlibs
echo
echo '== env =========================================================='
if [ -z ${LD_LIBRARY_PATH+x} ]; then
echo "LD_LIBRARY_PATH is unset";
else
echo LD_LIBRARY_PATH ${LD_LIBRARY_PATH} ;
fi
if [ -z ${DYLD_LIBRARY_PATH+x} ]; then
echo "DYLD_LIBRARY_PATH is unset";
else
echo DYLD_LIBRARY_PATH ${DYLD_LIBRARY_PATH} ;
fi
echo
echo '== nvidia-smi ==================================================='
nvidia-smi 2>&1
echo
echo '== cuda libs ==================================================='
} >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libcudart*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
find /usr/local -type f -name 'libudnn*' 2>/dev/null | grep cuda | grep -v "\\.cache" >> ${OUTPUT_FILE}
{
echo
echo '== paddlepaddle installed from info =================='
pip show paddlepaddle-gpu
echo
echo '== python version =============================================='
echo '(major, minor, micro, releaselevel, serial)'
python -c 'import sys; print(sys.version_info[:])'
echo
echo '== bazel version ==============================================='
bazel version
echo '== cmake version ==============================================='
cmake --version
} >> ${OUTPUT_FILE}
# Remove any words with google.
mv $OUTPUT_FILE old-$OUTPUT_FILE
grep -v -i google old-${OUTPUT_FILE} > $OUTPUT_FILE
echo "Wrote environment to ${OUTPUT_FILE}. You can review the contents of that file."
echo "and use it to populate the fields in the github issue template."
echo
echo "cat ${OUTPUT_FILE}"
echo
Loading…
Cancel
Save