commit
b3d27e4bbb
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for U2 model."""
|
||||
import cProfile
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.training.cli import default_argument_parser
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
from deepspeech.utils.utility import print_arguments
|
||||
|
||||
# Short model name -> "module.path:ClassName" of the tester class; main_sp
# passes this table to `dynamic_import` together with `--model-name`.
model_test_alias = dict(
    u2="deepspeech.exps.u2.model:U2Tester",
    u2_kaldi="deepspeech.exps.u2_kaldi.model:U2Tester", )
|
||||
|
||||
|
||||
def main_sp(config, args):
    """Build the tester selected by ``args.model_name`` and run one mode.

    Args:
        config: experiment configuration, forwarded to the tester class.
        args: parsed command-line namespace. ``model_name`` selects the class
            through ``model_test_alias``; ``run_mode`` selects the action.

    Raises:
        ValueError: if ``args.run_mode`` is not 'test', 'export' or 'align'.
    """
    # Supported run modes and the tester method implementing each of them.
    mode_to_method = {
        'test': 'run_test',
        'export': 'run_export',
        'align': 'run_align',
    }
    # Fail fast on an unknown mode. It used to fall through every branch, so
    # the process performed the whole setup and then exited without doing or
    # reporting anything.
    if args.run_mode not in mode_to_method:
        raise ValueError(
            f"unsupported run mode {args.run_mode!r}, "
            f"expected one of: {', '.join(mode_to_method)}")

    class_obj = dynamic_import(args.model_name, model_test_alias)
    exp = class_obj(config, args)
    exp.setup()

    getattr(exp, mode_to_method[args.run_mode])()
|
||||
|
||||
|
||||
def main(config, args):
    """Entry point handed to the profiler; evaluation always runs in a
    single process, so this only delegates to `main_sp`."""
    main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start from the shared deepspeech argument parser and add the options
    # that only this evaluation entry point uses.
    # NOTE(review): `args.config`, `args.opts` and `args.dump_config` read
    # below are presumably declared by `default_argument_parser` -- confirm.
    parser = default_argument_parser()
    # Key into `model_test_alias` selecting the tester class.
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2_kaldi',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    # Which action `main_sp` runs on the tester.
    parser.add_argument(
        '--run-mode',
        type=str,
        default='test',
        help='run mode, e.g. test, align, export')
    parser.add_argument(
        '--dict-path', type=str, default=None, help='dict path.')
    # Where the ASR results are written.
    parser.add_argument(
        "--result-file", type=str, help="path of save the asr result")
    # Where the exported jit model is written.
    parser.add_argument(
        "--export-path", type=str, help="path of the jit model to save")
    args = parser.parse_args()
    # The module globals are handed over as well, so the module-level names
    # in this script are visible to `print_arguments`.
    print_arguments(args, globals())

    # Configuration: load the file, apply command-line overrides on top,
    # then freeze it so later code cannot modify it.
    config = CfgNode()
    config.set_new_allowed(True)
    config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # Run under cProfile; the statistics go to 'test.profile' in the current
    # working directory.
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats('test.profile')
|
@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trainer for U2 model."""
|
||||
import cProfile
|
||||
import os
|
||||
|
||||
from paddle import distributed as dist
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.training.cli import default_argument_parser
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
from deepspeech.utils.utility import print_arguments
|
||||
|
||||
# Short model name -> "module.path:ClassName" of the trainer class; main_sp
# passes this table to `dynamic_import` together with `--model-name`.
model_train_alias = dict(
    u2="deepspeech.exps.u2.model:U2Trainer",
    u2_kaldi="deepspeech.exps.u2_kaldi.model:U2Trainer", )
|
||||
|
||||
|
||||
def main_sp(config, args):
    """Train in the current process.

    Resolves the trainer class named by ``args.model_name`` through
    ``model_train_alias``, instantiates it with ``(config, args)``, then
    calls its ``setup()`` followed by ``run()``.
    """
    trainer_cls = dynamic_import(args.model_name, model_train_alias)
    trainer = trainer_cls(config, args)
    trainer.setup()
    trainer.run()
|
||||
|
||||
|
||||
def main(config, args):
    """Launch training, spawning worker processes only for multi-GPU runs.

    With ``args.device == "gpu"`` and ``args.nprocs > 1`` the job is started
    through ``dist.spawn`` with ``args.nprocs`` processes; in every other
    case `main_sp` runs directly in this process.
    """
    multi_gpu = args.device == "gpu" and args.nprocs > 1
    if not multi_gpu:
        main_sp(config, args)
        return
    dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start from the shared deepspeech argument parser and add the model
    # selector, which is a key into `model_train_alias`.
    # NOTE(review): `args.config`, `args.opts`, `args.dump_config` and
    # `args.output` read below are presumably declared by
    # `default_argument_parser` -- confirm.
    parser = default_argument_parser()
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2_kaldi',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    # The module globals are handed over as well, so the module-level names
    # in this script are visible to `print_arguments`.
    print_arguments(args, globals())

    # Configuration: load the file, apply command-line overrides on top,
    # then freeze it so later code cannot modify it.
    config = CfgNode()
    config.set_new_allowed(True)
    config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # Run under cProfile; the statistics are written to 'train.profile'
    # inside the directory given by `args.output`.
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats(os.path.join(args.output, 'train.profile'))
|
@ -0,0 +1,654 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains U2 model."""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.frontend.featurizer import TextFeaturizer
|
||||
from deepspeech.frontend.utility import load_dict
|
||||
from deepspeech.io.dataloader import BatchDataLoader
|
||||
from deepspeech.models.u2 import U2Model
|
||||
from deepspeech.training.optimizer import OptimizerFactory
|
||||
from deepspeech.training.scheduler import LRSchedulerFactory
|
||||
from deepspeech.training.trainer import Trainer
|
||||
from deepspeech.utils import ctc_utils
|
||||
from deepspeech.utils import error_rate
|
||||
from deepspeech.utils import layer_tools
|
||||
from deepspeech.utils import mp_tools
|
||||
from deepspeech.utils import text_grid
|
||||
from deepspeech.utils import utility
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
def get_cfg_defaults():
    """Assemble the default configuration of a U2 experiment.

    Returns:
        CfgNode: a node with `model`, `training` and `decoding` sections
        taken from the `params()` defaults of U2Model, U2Trainer and
        U2Tester. It is a clone of the locally built node and accepts new
        keys.
    """
    defaults = CfgNode()
    defaults.model = U2Model.params()
    defaults.training = U2Trainer.params()
    defaults.decoding = U2Tester.params()

    # Hand out a clone rather than the node the sections were attached to.
    cfg = defaults.clone()
    cfg.set_new_allowed(True)
    return cfg
|
||||
|
||||
|
||||
class U2Trainer(Trainer):
|
||||
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
    """Return the default training options.

    Args:
        config: optional node; when given, the defaults are merged into it
            in place.

    Returns:
        CfgNode: the freshly built defaults (never ``config`` itself).
    """
    checkpoint_defaults = dict(
        kbest_n=50,
        latest_n=5, )
    training_defaults = dict(
        n_epoch=50,  # train epochs
        log_interval=100,  # steps
        accum_grad=1,  # accum grad by # steps
        checkpoint=checkpoint_defaults, )
    default = CfgNode(training_defaults)

    if config is not None:
        config.merge_from_other_cfg(default)
    return default
|
||||
|
||||
def __init__(self, config, args):
    """Forward the configuration and the parsed arguments to the base
    class; no extra state is created here."""
    super().__init__(config, args)
|
||||
|
||||
def train_batch(self, batch_index, batch_data, msg):
    """Run forward/backward on one batch and step the optimizer when due.

    Args:
        batch_index: zero-based position of the batch inside the epoch.
        batch_data: sequence unpacked as
            ``(utt, audio, audio_len, text, text_len)``.
        msg: log-line prefix prepared by the caller; the timing and loss
            figures of this batch are appended to it before logging.
    """
    train_conf = self.config.training
    accum_grad = train_conf.accum_grad
    tic = time.time()

    _, audio, audio_len, text, text_len = batch_data
    loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                text_len)
    # Gradients of `accum_grad` consecutive batches are accumulated before
    # each optimizer step, so every batch contributes 1/accum_grad.
    loss /= accum_grad
    loss.backward()
    layer_tools.print_grads(self.model, print_func=None)

    # Report the loss on its original scale, i.e. with the division undone.
    losses_np = {'loss': float(loss) * accum_grad}
    for loss_name, loss_value in (('att_loss', attention_loss),
                                  ('ctc_loss', ctc_loss)):
        if loss_value:
            losses_np[loss_name] = float(loss_value)

    batch_no = batch_index + 1
    if batch_no % accum_grad == 0:
        self.optimizer.step()
        self.optimizer.clear_grad()
        self.lr_scheduler.step()
        self.iteration += 1

    iteration_time = time.time() - tic

    # Everything below only happens on logging batches.
    if batch_no % train_conf.log_interval != 0:
        return

    pieces = [
        "train time: {:>.3f}s, ".format(iteration_time),
        "batch size: {}, ".format(self.config.collator.batch_size),
        "accum: {}, ".format(accum_grad),
        ', '.join('{}: {:>.6f}'.format(name, value)
                  for name, value in losses_np.items()),
    ]
    msg += ''.join(pieces)
    logger.info(msg)

    if dist.get_rank() == 0 and self.visualizer:
        scalars = losses_np.copy()
        scalars.update({"lr": self.lr_scheduler()})
        self.visualizer.add_scalars("step", scalars, self.iteration - 1)
|
||||
|
||||
@paddle.no_grad()
def valid(self):
    """Evaluate the model on the validation loader.

    Batches whose loss is not finite are left out of all statistics.

    Returns:
        tuple: ``(total_loss, num_seen_utts)``. ``total_loss`` is the sum of
        ``loss * batch utterance count`` over the counted batches;
        ``num_seen_utts`` is the number of utterances in those batches plus
        one, because the counter starts at 1.
    """
    self.model.eval()
    logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")

    valid_losses = defaultdict(list)
    # Starting at 1 keeps the divisions below defined even when no batch
    # was counted.
    num_seen_utts = 1
    total_loss = 0.0

    for batch_idx, batch in enumerate(self.valid_loader):
        _, audio, audio_len, text, text_len = batch
        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                    text_len)
        if paddle.isfinite(loss):
            # The first dimension of the audio tensor is taken as the
            # number of utterances in the batch.
            num_utts = batch[1].shape[0]
            num_seen_utts += num_utts
            total_loss += float(loss) * num_utts
            valid_losses['val_loss'].append(float(loss))
            for loss_key, loss_value in (('val_att_loss', attention_loss),
                                         ('val_ctc_loss', ctc_loss)):
                if loss_value:
                    valid_losses[loss_key].append(float(loss_value))

        if (batch_idx + 1) % self.config.training.log_interval == 0:
            valid_dump = {
                name: np.mean(values)
                for name, values in valid_losses.items()
            }
            valid_dump['val_history_loss'] = total_loss / num_seen_utts

            # progress line
            msg = (f"Valid: Rank: {dist.get_rank()}, "
                   f"epoch: {self.epoch}, "
                   f"step: {self.iteration}, "
                   f"batch: {batch_idx + 1}/{len(self.valid_loader)}, ")
            msg += ', '.join('{}: {:>.6f}'.format(name, value)
                             for name, value in valid_dump.items())
            logger.info(msg)

    logger.info('Rank {} Val info val_loss {}'.format(
        dist.get_rank(), total_loss / num_seen_utts))
    return total_loss, num_seen_utts
|
||||
|
||||
def train(self):
    """Run the epoch loop: train, validate, checkpoint, advance.

    Training resumes from a checkpoint when `resume_or_scratch` finds one;
    otherwise the untrained model is saved first under the tag 'init'.
    After each epoch the validation loss (summed across ranks when more
    than one process takes part) is logged and stored with the checkpoint.
    """
    if self.resume_or_scratch():
        # starting from scratch: keep the initial (epoch 0) model
        self.save(tag='init')

    self.lr_scheduler.step(self.iteration)
    if self.parallel:
        self.train_loader.batch_sampler.set_epoch(self.epoch)

    logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
    while self.epoch < self.config.training.n_epoch:
        self.model.train()
        try:
            # time spent waiting for the loader, measured per batch
            data_start_time = time.time()
            for batch_index, batch in enumerate(self.train_loader):
                dataload_time = time.time() - data_start_time
                prefix = [
                    "Train: Rank: {}, ".format(dist.get_rank()),
                    "epoch: {}, ".format(self.epoch),
                    "step: {}, ".format(self.iteration),
                    "batch : {}/{}, ".format(batch_index + 1,
                                             len(self.train_loader)),
                    "lr: {:>.8f}, ".format(self.lr_scheduler()),
                    "data time: {:>.3f}s, ".format(dataload_time),
                ]
                self.train_batch(batch_index, batch, ''.join(prefix))
                data_start_time = time.time()
        except Exception as e:
            logger.error(e)
            raise e

        total_loss, num_seen_utts = self.valid()
        if dist.get_world_size() > 1:
            # all_reduce sums by default, so both figures become global
            # totals over all ranks.
            utts_all = paddle.to_tensor(num_seen_utts)
            dist.all_reduce(utts_all)
            loss_all = paddle.to_tensor(total_loss)
            dist.all_reduce(loss_all)
            cv_loss = float(loss_all / utts_all)
        else:
            cv_loss = total_loss / num_seen_utts

        logger.info(
            'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
        if self.visualizer:
            epoch_scalars = {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}
            self.visualizer.add_scalars('epoch', epoch_scalars, self.epoch)
        self.save(tag=self.epoch, infos={'val_loss': cv_loss})
        self.new_epoch()
|
||||
|
||||
def setup_dataloader(self):
    """Create the train, valid, test and align data loaders.

    All four loaders share the batch size and batching options. Only the
    train loader runs in train mode with the configured augmentation and
    worker count; the other three use no preprocessing and one worker.
    The align loader reads the same manifest as the test loader.
    """
    config = self.config.clone()

    def build_loader(manifest, train_mode, preprocess_conf,
                     n_iter_processes):
        # Options that differ between loaders are parameters; the rest is
        # fixed here.
        return BatchDataLoader(
            json_file=manifest,
            train_mode=train_mode,
            sortagrad=False,
            batch_size=config.collator.batch_size,
            maxlen_in=float('inf'),
            maxlen_out=float('inf'),
            minibatches=0,
            mini_batch_size=1,
            batch_count='auto',
            batch_bins=0,
            batch_frames_in=0,
            batch_frames_out=0,
            batch_frames_inout=0,
            preprocess_conf=preprocess_conf,
            n_iter_processes=n_iter_processes,
            subsampling_factor=1,
            num_encs=1)

    self.train_loader = build_loader(
        config.data.train_manifest, True,
        config.collator.augmentation_config, config.collator.num_workers)
    self.valid_loader = build_loader(config.data.dev_manifest, False, None,
                                     1)
    self.test_loader = build_loader(config.data.test_manifest, False, None,
                                    1)
    self.align_loader = build_loader(config.data.test_manifest, False, None,
                                     1)
    logger.info("Setup train/valid/test/align Dataloader!")
|
||||
|
||||
def setup_model(self):
    """Build the model, LR scheduler and optimizer and store them on self.

    The model's input and output dimensions are taken from the training
    loader (`feat_dim`, `vocab_size`), so `setup_dataloader` has to run
    first. Sets `self.model`, `self.lr_scheduler` and `self.optimizer`.
    """
    config = self.config

    # model: unfreeze just long enough to record the data-dependent sizes
    model_conf = config.model
    model_conf.defrost()
    model_conf.input_dim = self.train_loader.feat_dim
    model_conf.output_dim = self.train_loader.vocab_size
    model_conf.freeze()
    model = U2Model.from_config(model_conf)
    if self.parallel:
        model = paddle.DataParallel(model)
    logger.info(f"{model}")
    layer_tools.print_params(model, logger.info)

    # learning-rate scheduler
    sched_conf = config.scheduler_conf
    scheduler_args = dict(
        learning_rate=sched_conf.lr,
        warmup_steps=sched_conf.warmup_steps,
        gamma=sched_conf.lr_decay,
        d_model=model_conf.encoder_conf.output_size,
        verbose=False, )
    lr_scheduler = LRSchedulerFactory.from_args(config.scheduler,
                                                scheduler_args)

    # optimizer, driven by the scheduler built above
    parameters = model.parameters()
    optim_conf = config.optim_conf
    optim_args = dict(
        grad_clip=optim_conf.global_grad_clip,
        weight_decay=optim_conf.weight_decay,
        learning_rate=lr_scheduler,
        parameters=parameters, )
    optimizer = OptimizerFactory.from_args(config.optim, optim_args)

    self.model = model
    self.lr_scheduler = lr_scheduler
    self.optimizer = optimizer
    logger.info("Setup model/optimizer/lr_scheduler!")
|
||||
|
||||
|
||||
class U2Tester(U2Trainer):
|
||||
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
    """Return the default decoding options.

    Args:
        config: optional node; when given, the defaults are merged into it
            in place.

    Returns:
        CfgNode: the freshly built defaults (never ``config`` itself).
    """
    decoding_defaults = dict(
        alpha=2.5,  # Coef of LM for beam search.
        beta=0.3,  # Coef of WC for beam search.
        cutoff_prob=1.0,  # Cutoff probability for pruning.
        cutoff_top_n=40,  # Cutoff number for pruning.
        # Filepath for language model.
        lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',
        # Decoding method. Options: 'attention', 'ctc_greedy_search',
        # 'ctc_prefix_beam_search', 'attention_rescoring'
        decoding_method='attention',
        # Error rate type for evaluation. Options `wer`, 'cer'
        error_rate_type='wer',
        num_proc_bsearch=8,  # number of CPUs for beam search.
        beam_size=10,  # Beam search width.
        batch_size=16,  # decoding batch size
        ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
        # decoding chunk size. Defaults to -1.
        #   <0: for decoding, use full chunk.
        #   >0: for decoding, use fixed chunk size as set.
        #    0: used for training, it's prohibited here.
        decoding_chunk_size=-1,
        # number of left chunks for decoding. Defaults to -1.
        num_decoding_left_chunks=-1,
        # simulate streaming inference. Defaults to False.
        simulate_streaming=False, )
    default = CfgNode(decoding_defaults)

    if config is not None:
        config.merge_from_other_cfg(default)
    return default
|
||||
|
||||
def __init__(self, config, args):
    """Forward the configuration and the parsed arguments to the base
    class; no extra state is created here."""
    super().__init__(config, args)
|
||||
|
||||
def id2token(self, texts, texts_len, text_feature):
    """Turn token-id sequences back into text.

    Args:
        texts: iterable of id sequences; each supports slicing and
            ``.numpy()``.
        texts_len: iterable of scalar lengths (each supports ``.numpy()``);
            only the first ``n`` ids of the matching sequence are used.
        text_feature: object whose ``defeaturize(list_of_ids)`` does the
            conversion.

    Returns:
        list: one ``defeaturize`` result per sequence, in input order.
    """
    return [
        text_feature.defeaturize(
            ids[:length.numpy().item()].numpy().tolist())
        for ids, length in zip(texts, texts_len)
    ]
|
||||
|
||||
def compute_metrics(self,
                    utts,
                    audio,
                    audio_len,
                    texts,
                    texts_len,
                    fout=None):
    """Decode one batch and score it against the reference transcripts.

    Args:
        utts: utterance ids, one per example.
        audio, audio_len: model inputs handed to ``self.model.decode``.
        texts, texts_len: reference token ids and their lengths.
        fout: optional writable text file; when given, one
            ``"<utt> <hypothesis>"`` line is written per example.

    Returns:
        dict: ``errors_sum``, ``len_refs``, ``num_ins`` (number of
        examples), ``error_rate``, ``error_rate_type``, ``num_frames`` (sum
        of ``audio_len``) and ``decode_time`` (seconds, measured with
        ``time.time()`` around featurizer creation and decoding).
    """
    cfg = self.config.decoding
    if cfg.error_rate_type == 'cer':
        errors_func, error_rate_func = error_rate.char_errors, error_rate.cer
    else:
        errors_func, error_rate_func = error_rate.word_errors, error_rate.wer

    start_time = time.time()
    collator_conf = self.config.collator
    text_feature = TextFeaturizer(
        unit_type=collator_conf.unit_type,
        vocab_filepath=collator_conf.vocab_filepath,
        spm_model_prefix=collator_conf.spm_model_prefix)
    target_transcripts = self.id2token(texts, texts_len, text_feature)
    result_transcripts = self.model.decode(
        audio,
        audio_len,
        text_feature=text_feature,
        decoding_method=cfg.decoding_method,
        lang_model_path=cfg.lang_model_path,
        beam_alpha=cfg.alpha,
        beam_beta=cfg.beta,
        beam_size=cfg.beam_size,
        cutoff_prob=cfg.cutoff_prob,
        cutoff_top_n=cfg.cutoff_top_n,
        num_processes=cfg.num_proc_bsearch,
        ctc_weight=cfg.ctc_weight,
        decoding_chunk_size=cfg.decoding_chunk_size,
        num_decoding_left_chunks=cfg.num_decoding_left_chunks,
        simulate_streaming=cfg.simulate_streaming)
    decode_time = time.time() - start_time

    errors_sum, len_refs, num_ins = 0.0, 0, 0
    scored = zip(utts, target_transcripts, result_transcripts)
    for utt, reference, hypothesis in scored:
        errors, len_ref = errors_func(reference, hypothesis)
        errors_sum += errors
        len_refs += len_ref
        num_ins += 1
        if fout:
            fout.write(utt + " " + hypothesis + "\n")
        logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                    (reference, hypothesis))
        logger.info("One example error rate [%s] = %f" %
                    (cfg.error_rate_type,
                     error_rate_func(reference, hypothesis)))

    return {
        'errors_sum': errors_sum,
        'len_refs': len_refs,
        'num_ins': num_ins,  # num examples
        'error_rate': errors_sum / len_refs,
        'error_rate_type': cfg.error_rate_type,
        'num_frames': audio_len.sum().numpy().item(),
        'decode_time': decode_time,
    }
|
||||
|
||||
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
    """Decode the whole test set and report error rate and RTF.

    Hypotheses are written to ``self.args.result_file`` (one line per
    utterance, by `compute_metrics`). A JSON summary is written next to it,
    with the extension replaced by ``.err``.

    Units: `compute_metrics` reports ``decode_time`` in seconds (it is a
    difference of ``time.time()`` values) and ``num_frames`` in frames, and
    ``stride_ms`` is milliseconds per frame. The audio duration in seconds
    is therefore ``num_frames * stride_ms / 1000``, and the real-time factor
    is decode seconds divided by audio seconds.

    Raises:
        AssertionError: if ``self.args.result_file`` is not set.
    """
    assert self.args.result_file
    self.model.eval()
    logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

    stride_ms = self.config.collator.stride_ms
    error_rate_type = None
    errors_sum, len_refs, num_ins = 0.0, 0, 0
    num_frames = 0.0
    num_time = 0.0  # accumulated decode time, in seconds
    with open(self.args.result_file, 'w') as fout:
        for batch in self.test_loader:
            metrics = self.compute_metrics(*batch, fout=fout)
            num_frames += metrics['num_frames']
            num_time += metrics["decode_time"]
            errors_sum += metrics['errors_sum']
            len_refs += metrics['len_refs']
            num_ins += metrics['num_ins']
            error_rate_type = metrics['error_rate_type']
            # Running RTF: both sides in seconds (the audio duration was
            # previously left in milliseconds, making RTF 1000x too small).
            rtf = num_time / (num_frames * stride_ms / 1000.0)
            logger.info(
                "RTF: %f, Error rate [%s] (%d/?) = %f" %
                (rtf, error_rate_type, num_ins, errors_sum / len_refs))

    rtf = num_time / (num_frames * stride_ms / 1000.0)
    msg = "Test: "
    msg += "epoch: {}, ".format(self.epoch)
    msg += "step: {}, ".format(self.iteration)
    msg += "RTF: {}, ".format(rtf)
    msg += "Final error rate [%s] (%d/%d) = %f" % (
        error_rate_type, num_ins, num_ins, errors_sum / len_refs)
    logger.info(msg)

    # test meta results
    err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
    with open(err_meta_path, 'w') as f:
        data = json.dumps({
            "epoch": self.epoch,
            "step": self.iteration,
            "rtf": rtf,
            error_rate_type: errors_sum / len_refs,
            # milliseconds of audio -> hours
            "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
            # seconds of decoding -> hours (the value is already in
            # seconds, so there is no division by 1000 here)
            "process_hour": num_time / 3600.0,
            "num_examples": num_ins,
            "err_sum": errors_sum,
            "ref_len": len_refs,
            "decode_method": self.config.decoding.decoding_method,
        })
        f.write(data + '\n')
|
||||
|
||||
def run_test(self):
    """Restore the model state via `resume_or_scratch`, then run `test`.

    A Ctrl-C during the test ends the process with exit status -1; an
    interrupt while restoring the state is not intercepted.
    """
    self.resume_or_scratch()
    try:
        self.test()
    except KeyboardInterrupt:
        sys.exit(-1)
|
||||
|
||||
@paddle.no_grad()
def align(self):
    """Force-align every utterance of the align loader with CTC.

    For each utterance the frame-level alignment is appended to
    ``self.args.result_file`` as ``"<key> <alignment>"``. A ``<key>.tier``
    file and a ``<key>.TextGrid`` file are written to an ``align``
    directory next to the result file.

    Exits the process with status 1 when ``config.decoding.batch_size`` is
    greater than 1.

    Raises:
        AssertionError: if ``self.args.result_file`` is unset or does not
            end with ``.align``.
    """
    # NOTE(review): this checks `decoding.batch_size`, while the align
    # loader is built with `collator.batch_size` -- confirm both are 1.
    if self.config.decoding.batch_size > 1:
        logger.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    # xxx.align
    assert self.args.result_file and self.args.result_file.endswith(
        '.align')

    self.model.eval()
    logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")

    # Was `self.config.collater`; every other access uses `collator`.
    stride_ms = self.config.collator.stride_ms
    token_dict = self.args.char_list

    # Loop-invariant values, computed once.
    subsample = utility.get_subsample(self.config)
    # seconds covered by one feature frame
    second_per_frame = 1. / (1000. / stride_ms)
    align_output_path = os.path.join(
        os.path.dirname(self.args.result_file), "align")
    # The per-utterance files are opened for writing below, so the
    # directory has to exist.
    os.makedirs(align_output_path, exist_ok=True)

    with open(self.args.result_file, 'w') as fout:
        # one example in batch
        for batch in self.align_loader:
            key, feat, feats_length, target, target_length = batch

            # 1. Encoder
            encoder_out, encoder_mask = self.model._forward_encoder(
                feat, feats_length)  # (B, maxlen, encoder_dim)
            ctc_probs = self.model.ctc.log_softmax(
                encoder_out)  # (1, maxlen, vocab_size)

            # 2. alignment
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = ctc_utils.forced_align(ctc_probs, target)
            # The values are formatted into the message itself; passing
            # them as extra arguments without placeholders never logged
            # them.
            logger.info(f"align ids: {key[0]} {alignment}")
            fout.write('{} {}\n'.format(key[0], alignment))

            # 3. gen praat
            # segment alignment
            align_segs = text_grid.segment_alignment(alignment)
            logger.info(f"align tokens: {key[0]} {align_segs}")
            # IntervalTier, List["start end token\n"]
            tierformat = text_grid.align_to_tierformat(
                align_segs, subsample, token_dict)
            # write tier
            tier_path = os.path.join(align_output_path, key[0] + ".tier")
            with open(tier_path, 'w') as f:
                f.writelines(tierformat)
            # write textgrid
            textgrid_path = os.path.join(align_output_path,
                                         key[0] + ".TextGrid")
            second_per_example = (
                len(alignment) + 1) * subsample * second_per_frame
            text_grid.generate_textgrid(
                maxtime=second_per_example,
                intervals=tierformat,
                output=textgrid_path)
|
||||
|
||||
def run_align(self):
    """Restore the model state via `resume_or_scratch`, then run `align`.

    A Ctrl-C during alignment ends the process with exit status -1; an
    interrupt while restoring the state is not intercepted.
    """
    self.resume_or_scratch()
    try:
        self.align()
    except KeyboardInterrupt:
        sys.exit(-1)
|
||||
|
||||
def load_inferspec(self):
    """Load the inference model and describe its inputs.

    The model is restored by ``U2InferModel.from_pretrained`` from
    ``self.args.checkpoint_path``, using the test loader and a clone of the
    model config.

    Returns:
        nn.Layer: inference model
        List[paddle.static.InputSpec]: input spec -- the audio features
        (shape ``[1, None, feat_dim]``, float32) and the audio length
        (shape ``[1]``, int64).
    """
    from deepspeech.models.u2 import U2InferModel

    infer_model = U2InferModel.from_pretrained(self.test_loader,
                                               self.config.model.clone(),
                                               self.args.checkpoint_path)
    feat_dim = self.test_loader.feat_dim
    # audio, [B,T,D]
    audio_spec = paddle.static.InputSpec(
        shape=[1, None, feat_dim], dtype='float32')
    # audio_length, [B]
    audio_len_spec = paddle.static.InputSpec(shape=[1], dtype='int64')
    return infer_model, [audio_spec, audio_len_spec]
|
||||
|
||||
def export(self):
    """Convert the inference model to a static graph and save it.

    The model and its input spec come from `load_inferspec`; the result of
    ``paddle.jit.to_static`` is saved to ``self.args.export_path``.
    """
    model, spec = self.load_inferspec()
    assert isinstance(spec, list), type(spec)
    model.eval()
    exported = paddle.jit.to_static(model, input_spec=spec)
    logger.info(f"Export code: {exported.forward.code}")
    paddle.jit.save(exported, self.args.export_path)
|
||||
|
||||
def run_export(self):
    """Run `export`; a Ctrl-C ends the process with exit status -1.

    Unlike the test and align runners, no model state is restored here
    first.
    """
    try:
        self.export()
    except KeyboardInterrupt:
        sys.exit(-1)
|
||||
|
||||
def setup_dict(self):
    """Load the token dictionary and keep it on ``self.args.char_list``.

    The second argument passed to `load_dict` is True only when the model
    name contains "maskctc".
    """
    dict_path = self.args.dict_path
    is_maskctc = "maskctc" in self.args.model_name
    self.args.char_list = load_dict(dict_path, is_maskctc)
|
||||
|
||||
def setup(self):
    """Prepare the experiment before any run mode starts.

    Selects the device, creates the output directory and the checkpointer,
    builds the data loaders and then the model (which reads its sizes from
    the train loader), loads the token dictionary, and resets the iteration
    and epoch counters to 0.
    """
    paddle.set_device(self.args.device)

    # output locations first, then data, then the model that depends on it
    for prepare in (self.setup_output_dir, self.setup_checkpointer,
                    self.setup_dataloader, self.setup_model,
                    self.setup_dict):
        prepare()

    self.iteration, self.epoch = 0, 0
|
||||
|
||||
def setup_output_dir(self):
    """Pick and create the output directory, stored as ``self.output_dir``.

    ``self.args.output`` is used when set. Otherwise the directory two
    levels above ``self.args.checkpoint_path`` is used. ``~`` is expanded
    and missing parent directories are created.
    """
    if self.args.output:
        target = Path(self.args.output).expanduser()
    else:
        checkpoint = Path(self.args.checkpoint_path).expanduser()
        target = checkpoint.parent.parent
    target.mkdir(parents=True, exist_ok=True)

    self.output_dir = target
|
@ -0,0 +1,469 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
__all__ = ["make_batchset"]
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
def batchfy_by_seq(
        sorted_data,
        batch_size,
        max_length_in,
        max_length_out,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        iaxis=0,
        okey="output",
        oaxis=0, ):
    """Cut a sorted utterance list into consecutive batches.

    The size of each batch is decided by its first utterance only
    (presumably the longest one, given sorted input -- confirm with the
    caller): ``batch_size`` is divided by ``1 + factor``, where ``factor``
    counts how many whole times the input or output length exceeds its
    maximum. It never drops below ``min_batch_size``.

    A batch that ends up with fewer than ``min_batch_size`` items is topped
    up with utterances drawn at random from those placed in earlier
    batches.

    :param List[(str, Dict[str, Any])] sorted_data: (key, info) pairs
    :param int batch_size: nominal batch size
    :param int max_length_in: input length above which batches shrink
    :param int max_length_out: output length above which batches shrink
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: reverse the order inside every batch
    :param str ikey: key of the input entries in ``info``
    :param int iaxis: index of the input entry whose length is used
    :param str okey: key of the output entries in ``info``
    :param int oaxis: index of the output entry whose length is used;
        a negative value means the longest of all output entries
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if ``batch_size`` is not positive or there are
        fewer utterances than ``min_batch_size``
    """
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    num_utts = len(sorted_data)
    if num_utts < min_batch_size:
        raise ValueError(
            f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
        )

    batches = []
    begin = 0
    while True:
        _, head_info = sorted_data[begin]
        in_len = int(head_info[ikey][iaxis]["shape"][0])
        if oaxis >= 0:
            out_len = int(head_info[okey][oaxis]["shape"][0])
        else:
            out_len = max(
                int(entry["shape"][0]) for entry in head_info[okey])

        # e.g. in_len = 1000 with max_length_in = 800 gives shrink = 1 and
        # therefore half the nominal batch size; the outer max() keeps the
        # size from reaching 0.
        shrink = max(
            int(in_len / max_length_in), int(out_len / max_length_out))
        cur_size = max(min_batch_size, int(batch_size / (1 + shrink)))
        stop = min(num_utts, begin + cur_size)

        batch = sorted_data[begin:stop]
        if shortest_first:
            batch.reverse()

        # Too few items left: fill up from utterances already batched.
        if len(batch) < min_batch_size:
            n_extra = min_batch_size - len(batch) % min_batch_size
            extra = [
                sorted_data[idx]
                for idx in np.random.randint(0, begin, n_extra)
            ]
            if shortest_first:
                extra.reverse()
            batch.extend(extra)
        batches.append(batch)

        if stop == num_utts:
            break
        begin = stop

    # batch: List[List[Tuple[str, dict]]]
    return batches
|
||||
|
||||
|
||||
def batchfy_by_bin(
        sorted_data,
        batch_bins,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variably sized batches that pack as many bins as fit in `batch_bins`.

    A "bin" is one feature value: ``frames * feature_dim``. While a batch is
    grown, its size is estimated as
    ``(largest output bins so far + input bins of the newest sample) * #samples``.

    :param List[(str, Dict[str, Any])] sorted_data: (utt_id, info) pairs;
        ``info[ikey][0]["shape"]`` and ``info[okey][0]["shape"]`` must be
        ``[frames, dim]``-like sequences
    :param int batch_bins: maximum number of bins in a batch (must be > 0)
    :param int num_batches: if > 0, keep only the first `num_batches` batches
        (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: if true, each batch slice is reversed before
        it is stored
    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[List[Tuple[str, Dict[str, Any]]]] list of batches
    :raises ValueError: if `batch_bins` is not positive, or if a single sample
        on its own already exceeds `batch_bins`
    """
    if batch_bins <= 0:
        raise ValueError(f"invalid batch_bins={batch_bins}")
    length = len(sorted_data)
    # Feature dimensions are taken from the first utterance only.
    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
    odim = int(sorted_data[0][1][okey][0]["shape"][1])
    logger.info("# utts: " + str(len(sorted_data)))
    minibatches = []
    start = 0
    while True:
        # Dynamic batch size depending on size of samples
        b = 0
        next_size = 0
        max_olen = 0
        while next_size < batch_bins and (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
            max_olen = max(max_olen, olen)
            next_size = (max_olen + ilen) * (b + 1)
            if next_size <= batch_bins:
                b += 1
            elif b == 0:
                # Not even the first sample of this batch fits. The previous
                # test (`next_size == 0`) could never be true on this branch.
                raise ValueError(
                    f"Can't fit one sample in batch_bins ({batch_bins}): "
                    f"Please increase the value")
        end = min(length, start + max(min_batch_size, b))
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, topping up a short batch from its predecessor.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # Reached the first batch: merge it into the following one.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        if end == length:
            break
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
            int(np.mean(lengths))) + " samples).")
    return minibatches
|
||||
|
||||
|
||||
def batchfy_by_frame(
        sorted_data,
        max_frames_in,
        max_frames_out,
        max_frames_inout,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variably sized batches bounded by per-batch frame budgets.

    A batch keeps growing while ``longest_input * #samples``,
    ``longest_output * #samples`` and their sum stay within the respective
    limit. A limit of 0 disables that particular check.

    :param List[(str, Dict[str, Any])] sorted_data: (utt_id, info) pairs;
        ``info[ikey][0]["shape"][0]`` / ``info[okey][0]["shape"][0]`` give the
        input / output frame counts
    :param int max_frames_in: Maximum input frames of a batch
    :param int max_frames_out: Maximum output frames of a batch
    :param int max_frames_inout: Maximum input+output frames of a batch
    :param int num_batches: if > 0, keep only the first `num_batches` batches
        (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: if true, each batch slice is reversed before
        it is stored
    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[List[Tuple[str, Dict[str, Any]]]] list of batches
    :raises ValueError: if all three limits are <= 0, or if a single sample
        exceeds one of the enabled limits
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least, one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0")
    length = len(sorted_data)
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value")
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value")
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                # This message used to name --batch-frames-out, pointing the
                # user at the wrong option.
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-inout ({max_frames_inout}): "
                    f"Please increase the value")
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, topping up a short batch from its predecessor.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # Reached the first batch: merge it into the following one.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples" + "(avg " + str(
            int(np.mean(lengths))) + " samples).")

    return minibatches
|
||||
|
||||
|
||||
def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,
                    shortest_first):
    """Split randomly shuffled utterances into fixed-size batches.

    :param Dict[str, dict] data: mapping from utterance id to its info dict
    :param int batch_size: number of utterances per batch (must be > 0)
    :param int min_batch_size: a final batch shorter than this is padded with
        utterances drawn at random (with replacement) from the earlier ones
    :param int num_batches: if > 0, keep only the first `num_batches` batches
        (for debug)
    :param bool shortest_first: if true, each batch slice (and any padding) is
        reversed before it is stored
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if `batch_size` is not positive
    """
    import random

    # Same check and message as batchfy_by_seq; without it a non-positive
    # batch size never advances `start`.
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    logger.info("use shuffled batch.")
    # random.sample() needs a real sequence; a dict view is rejected on
    # Python 3.11+, so materialise the items first.
    items = list(data.items())
    sorted_data = random.sample(items, len(items))
    logger.info("# utts: " + str(len(sorted_data)))
    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        end = min(len(sorted_data), start + batch_size)
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()
        # check each batch is more than minimum batchsize
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)
        if end == len(sorted_data):
            break
        start = end

    # for debugging
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    logger.info("# minibatches: " + str(len(minibatches)))
    return minibatches
|
||||
|
||||
|
||||
# Accepted values for the `count` / `batch_sort_key` arguments of make_batchset.
BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]


def make_batchset(
        data,
        batch_size=0,
        max_length_in=float("inf"),
        max_length_out=float("inf"),
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        batch_sort_key="input",
        count="auto",
        batch_bins=0,
        batch_frames_in=0,
        batch_frames_out=0,
        batch_frames_inout=0,
        iaxis=0,
        oaxis=0, ):
    """Group utterances into mini-batches.

    Utterances are first grouped by their optional "category" value, each
    group is batched on its own, and the per-group results are concatenated:

    >>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt2'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt3'},
    ...         {'category': 'A', 'input': ..., 'utt':'utt4'}]
    >>> make_batchset(data, batch_size=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

    Utterances without "category" all land in one group, so the result is
    simply that of the selected batchfy_by_{count} function.

    :param List[Dict[str, Any]] data: utterance dicts; each needs an 'utt' key
    :param int batch_size: maximum number of sequences in a minibatch.
    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int num_batches: if > 0, keep only the first `num_batches` batches
        (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: sort utterances from shortest to longest if
        true, otherwise from longest to shortest
    :param str batch_sort_key: how to order data before batching, one of
        BATCH_SORT_KEY_CHOICES; "shuffle" requires count "seq"
    :param str count: strategy used to bound a batch, one of
        BATCH_COUNT_CHOICES; "auto" picks the first strategy whose size
        argument is non-zero (batch_size, batch_bins, batch_frames_*)
    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
    :param int batch_frames_in: maximum number of input frames in a minibatch.
    :param int batch_frames_out: maximum number of output frames in a minibatch.
    :param int batch_frames_inout: maximum number of input+output frames in a minibatch.
    :param int iaxis: index into the "input" list (forwarded to batchfy_by_seq)
    :param int oaxis: index into the "output" list (forwarded to batchfy_by_seq)
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: on an unknown `count` / `batch_sort_key`, when `count`
        cannot be auto-detected, or when "shuffle" is combined with a count
        other than "seq"
    """
    # ---- argument validation -------------------------------------------
    if count not in BATCH_COUNT_CHOICES:
        raise ValueError(
            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
        raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be "
                         f"one of {BATCH_SORT_KEY_CHOICES}")

    input_key, output_key = "input", "output"
    sort_axis = 0  # index into the list stored under `batch_sort_key`

    if count == "auto":
        frame_limits = (batch_frames_in, batch_frames_out, batch_frames_inout)
        if batch_size != 0:
            count = "seq"
        elif batch_bins != 0:
            count = "bin"
        elif any(limit != 0 for limit in frame_limits):
            count = "frame"
        else:
            raise ValueError(
                f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}"
            )
        logger.info(f"count is auto detected as {count}")

    if batch_sort_key == "shuffle" and count != "seq":
        raise ValueError(
            "batch_sort_key=shuffle is only available if batch_count=seq")

    # ---- group utterances by category: {category: {utt_id: utt_dict}} ---
    by_category = {}
    for utt in data:
        utt_id = utt['utt']
        by_category.setdefault(utt.get("category"), {})[utt_id] = utt

    def sort_length(item):
        # item is a (utt_id, utt_dict) pair
        return int(item[1][batch_sort_key][sort_axis]["shape"][0])

    # ---- batch every category separately --------------------------------
    per_category = []  # List[List[List[Tuple[str, dict]]]]
    for utts in by_category.values():
        if batch_sort_key == "shuffle":
            per_category.append(
                batchfy_shuffle(utts, batch_size, min_batch_size, num_batches,
                                shortest_first))
            continue

        # long-to-short unless shortest_first is set
        sorted_data = sorted(
            utts.items(), key=sort_length, reverse=not shortest_first)
        logger.info("# utts: " + str(len(sorted_data)))

        if count == "seq":
            batches = batchfy_by_seq(
                sorted_data,
                batch_size=batch_size,
                max_length_in=max_length_in,
                max_length_out=max_length_out,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=input_key,
                iaxis=iaxis,
                okey=output_key,
                oaxis=oaxis, )
        elif count == "bin":
            batches = batchfy_by_bin(
                sorted_data,
                batch_bins=batch_bins,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=input_key,
                okey=output_key, )
        elif count == "frame":
            batches = batchfy_by_frame(
                sorted_data,
                max_frames_in=batch_frames_in,
                max_frames_out=batch_frames_out,
                max_frames_inout=batch_frames_inout,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=input_key,
                okey=output_key, )
        per_category.append(batches)

    # ---- merge -----------------------------------------------------------
    if len(per_category) == 1:
        batches = per_category[0]
    else:
        # Flatten one level; cheaper than "sum(per_category, [])"
        batches = list(itertools.chain.from_iterable(per_category))

    # for debugging
    if num_batches > 0:
        batches = batches[:num_batches]
    logger.info("# minibatches: " + str(len(batches)))

    # batch: List[List[Tuple[str, dict]]]
    return batches
|
@ -0,0 +1,81 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from deepspeech.io.utility import pad_list
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
__all__ = ["CustomConverter"]
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class CustomConverter():
    """Collate one already-loaded mini-batch into padded arrays.

    Args:
        subsampling_factor (int): Keep every `subsampling_factor`-th input
            frame; values <= 1 leave the inputs untouched.
        dtype (np.dtype): Data type the padded inputs are cast to.

    """

    def __init__(self, subsampling_factor=1, dtype=np.float32):
        """Store the conversion settings."""
        self.dtype = dtype
        # value used to pad the target sequences
        self.ignore_id = -1
        self.subsampling_factor = subsampling_factor

    def __call__(self, batch):
        """Pad the inputs and targets of a single mini-batch.

        Args:
            batch (list): A one-element list holding ``((xs, ys), utts)``,
                where `xs` / `ys` are lists of input / target arrays.

        Returns:
            tuple: ``(utts, xs_pad, ilens, ys_pad, olens)``. `xs_pad` is the
            padded input cast to `dtype`, or a ``{"real", "imag"}`` dict of
            two such arrays when the inputs are complex; `ilens` / `olens`
            are arrays with the length of every input / target.

        """
        # the whole mini-batch is wrapped in a one-element list
        assert len(batch) == 1
        (feats, targets), utts = batch[0]
        assert feats[0] is not None, "please check Reader and Augmentation impl."

        # optional frame subsampling
        step = self.subsampling_factor
        if step > 1:
            feats = [feat[::step, :] for feat in feats]

        # input lengths (after subsampling)
        ilens = np.array([feat.shape[0] for feat in feats])

        # pad the inputs with zeros; complex input is split into its real
        # and imaginary parts, which are padded and returned separately
        is_complex = feats[0].dtype.kind == "c"
        if not is_complex:
            xs_pad = pad_list(feats, 0).astype(self.dtype)
        else:
            real_part = pad_list([feat.real for feat in feats], 0)
            imag_part = pad_list([feat.imag for feat in feats], 0)
            xs_pad = {
                "real": real_part.astype(self.dtype),
                "imag": imag_part.astype(self.dtype),
            }

        # a target may be a tuple (multi-output, e.g. speech translation);
        # only its first element is used
        labels = []
        label_lens = []
        for target in targets:
            if isinstance(target, tuple):
                labels.append(np.array(target[0][:]))
                label_lens.append(target[0].shape[0])
            else:
                labels.append(target)
                label_lens.append(target.shape[0])
        ys_pad = pad_list(labels, self.ignore_id)
        olens = np.array(label_lens)

        return utts, xs_pad, ilens, ys_pad, olens
|
@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Text
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import DataLoader
|
||||
|
||||
from deepspeech.frontend.utility import read_manifest
|
||||
from deepspeech.io.batchfy import make_batchset
|
||||
from deepspeech.io.converter import CustomConverter
|
||||
from deepspeech.io.dataset import TransformDataset
|
||||
from deepspeech.io.reader import LoadInputsAndTargets
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
__all__ = ["BatchDataLoader"]
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
                            mode: Text="asr",
                            iaxis=0,
                            oaxis=0):
    """Read the feature dimension and vocabulary size from the first utterance.

    Args:
        data_json (List[Dict[Text, Any]]): utterance dicts; only the first
            one is inspected.
        mode (Text): task mode, only 'asr' is supported.
        iaxis (int): index into the utterance's 'input' list.
        oaxis (int): index into the utterance's 'output' list.

    Returns:
        tuple: ``(feat_dim, vocab_size)``, i.e. ``shape[1]`` of the selected
        input entry and ``shape[1]`` of the selected output entry.

    Raises:
        ValueError: if `mode` is not 'asr'.
    """
    if mode == 'asr':
        # The input entry is selected with `iaxis` (it was previously indexed
        # with `oaxis`, which only worked while both were equal).
        feat_dim = data_json[0]['input'][iaxis]['shape'][1]
        vocab_size = data_json[0]['output'][oaxis]['shape'][1]
    else:
        raise ValueError(f"{mode} mode not support!")
    return feat_dim, vocab_size
|
||||
|
||||
|
||||
class BatchDataLoader():
    """Iterable over pre-computed, variable-size mini-batches of a manifest.

    The manifest is read once, split into mini-batches with `make_batchset`,
    and each mini-batch is loaded (`LoadInputsAndTargets`) and padded
    (`CustomConverter`) when the underlying `paddle.io.DataLoader` fetches it.
    """

    def __init__(self,
                 json_file: str,
                 train_mode: bool,
                 sortagrad: bool=False,
                 batch_size: int=0,
                 maxlen_in: float=float('inf'),
                 maxlen_out: float=float('inf'),
                 minibatches: int=0,
                 mini_batch_size: int=1,
                 batch_count: str='auto',
                 batch_bins: int=0,
                 batch_frames_in: int=0,
                 batch_frames_out: int=0,
                 batch_frames_inout: int=0,
                 preprocess_conf=None,
                 n_iter_processes: int=1,
                 subsampling_factor: int=1,
                 num_encs: int=1):
        """Build the batch list, the reader, the converter and the DataLoader.

        Args:
            json_file (str): manifest path, read with `read_manifest`.
            train_mode (bool): forwarded to the reader as the "train"
                preprocessing argument; also enables shuffling of the batch
                order unless sortagrad is active.
            sortagrad: enabled when equal to -1 or greater than 0; batches
                are then built shortest-first and not shuffled.
            batch_size (int): `batch_size` of `make_batchset`.
            maxlen_in (float): `max_length_in` of `make_batchset`.
            maxlen_out (float): `max_length_out` of `make_batchset`.
            minibatches (int): `num_batches` of `make_batchset` (for debug).
            mini_batch_size (int): `min_batch_size` of `make_batchset`.
            batch_count (str): `count` of `make_batchset`.
            batch_bins (int): `batch_bins` of `make_batchset`.
            batch_frames_in (int): `batch_frames_in` of `make_batchset`.
            batch_frames_out (int): `batch_frames_out` of `make_batchset`.
            batch_frames_inout (int): `batch_frames_inout` of `make_batchset`.
            preprocess_conf: `preprocess_conf` of `LoadInputsAndTargets`.
            n_iter_processes (int): `num_workers` of the DataLoader.
            subsampling_factor (int): `subsampling_factor` of `CustomConverter`.
            num_encs (int): number of encoders; only 1 is implemented.

        Raises:
            NotImplementedError: if `num_encs` is not 1.
        """
        # Fail fast, before any file is read. This used to be
        # `assert NotImplementedError(...)`, which never fails and left
        # `self.converter` undefined until the first batch was fetched.
        if num_encs != 1:
            raise NotImplementedError("not impl CustomConverterMulEnc.")

        self.json_file = json_file
        self.train_mode = train_mode
        self.use_sortagrad = sortagrad == -1 or sortagrad > 0
        self.batch_size = batch_size
        self.maxlen_in = maxlen_in
        self.maxlen_out = maxlen_out
        self.batch_count = batch_count
        self.batch_bins = batch_bins
        self.batch_frames_in = batch_frames_in
        self.batch_frames_out = batch_frames_out
        self.batch_frames_inout = batch_frames_inout
        self.subsampling_factor = subsampling_factor
        self.num_encs = num_encs
        self.preprocess_conf = preprocess_conf
        self.n_iter_processes = n_iter_processes

        # read json data
        self.data_json = read_manifest(json_file)
        self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
            self.data_json, mode='asr')

        # make minibatch list (variable length)
        # NOTE: the attribute name `minibaches` (sic) is kept as-is because
        # it is reachable from outside the class.
        self.minibaches = make_batchset(
            self.data_json,
            batch_size,
            maxlen_in,
            maxlen_out,
            minibatches,  # for debug
            min_batch_size=mini_batch_size,
            shortest_first=self.use_sortagrad,
            count=batch_count,
            batch_bins=batch_bins,
            batch_frames_in=batch_frames_in,
            batch_frames_out=batch_frames_out,
            batch_frames_inout=batch_frames_inout,
            iaxis=0,
            oaxis=0, )

        # data reader
        self.reader = LoadInputsAndTargets(
            mode="asr",
            load_output=True,
            preprocess_conf=preprocess_conf,
            preprocess_args={"train":
                             train_mode},  # Switch the mode of preprocessing
        )

        # Setup a converter
        self.converter = CustomConverter(
            subsampling_factor=subsampling_factor, dtype=np.float32)

        # Every dataset item is already a whole mini-batch, so the DataLoader
        # runs with batch_size=1 and a collate function that just unwraps the
        # one-element list it is handed.
        self.dataset = TransformDataset(
            self.minibaches,
            lambda data: self.converter([self.reader(data, return_uttid=True)]))
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=1,
            shuffle=not self.use_sortagrad if train_mode else False,
            collate_fn=lambda x: x[0],
            num_workers=n_iter_processes, )

    def __repr__(self):
        """Return a one-line summary of the loader configuration."""
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"train_mode: {self.train_mode}, "
        echo += f"sortagrad: {self.use_sortagrad}, "
        echo += f"batch_size: {self.batch_size}, "
        echo += f"maxlen_in: {self.maxlen_in}, "
        echo += f"maxlen_out: {self.maxlen_out}, "
        echo += f"batch_count: {self.batch_count}, "
        echo += f"batch_bins: {self.batch_bins}, "
        echo += f"batch_frames_in: {self.batch_frames_in}, "
        echo += f"batch_frames_out: {self.batch_frames_out}, "
        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
        echo += f"subsampling_factor: {self.subsampling_factor}, "
        echo += f"num_encs: {self.num_encs}, "
        echo += f"num_workers: {self.n_iter_processes}, "
        echo += f"file: {self.json_file}"
        return echo

    def __len__(self):
        """Return the length reported by the underlying DataLoader."""
        return len(self.dataloader)

    def __iter__(self):
        """Return a fresh iterator of the underlying DataLoader."""
        return self.dataloader.__iter__()

    def __call__(self):
        """Same as iterating the loader: return a fresh batch iterator."""
        return self.__iter__()
|
@ -0,0 +1,410 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import OrderedDict
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
import soundfile
|
||||
|
||||
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
__all__ = ["LoadInputsAndTargets"]
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class LoadInputsAndTargets():
|
||||
"""Create a mini-batch from a list of dicts
|
||||
|
||||
>>> batch = [('utt1',
|
||||
... dict(input=[dict(feat='some.ark:123',
|
||||
... filetype='mat',
|
||||
... name='input1',
|
||||
... shape=[100, 80])],
|
||||
... output=[dict(tokenid='1 2 3 4',
|
||||
... name='target1',
|
||||
... shape=[4, 31])]]))
|
||||
>>> l = LoadInputsAndTargets()
|
||||
>>> feat, target = l(batch)
|
||||
|
||||
:param: str mode: Specify the task mode, "asr" or "tts"
|
||||
:param: str preprocess_conf: The path of a json file for pre-processing
|
||||
:param: bool load_input: If False, not to load the input data
|
||||
:param: bool load_output: If False, not to load the output data
|
||||
:param: bool sort_in_input_length: Sort the mini-batch in descending order
|
||||
of the input length
|
||||
:param: bool use_speaker_embedding: Used for tts mode only
|
||||
:param: bool use_second_target: Used for tts mode only
|
||||
:param: dict preprocess_args: Set some optional arguments for preprocessing
|
||||
:param: Optional[dict] preprocess_args: Used for tts mode only
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode="asr",
|
||||
preprocess_conf=None,
|
||||
load_input=True,
|
||||
load_output=True,
|
||||
sort_in_input_length=True,
|
||||
preprocess_args=None,
|
||||
keep_all_data_on_mem=False, ):
|
||||
self._loaders = {}
|
||||
|
||||
if mode not in ["asr"]:
|
||||
raise ValueError("Only asr are allowed: mode={}".format(mode))
|
||||
|
||||
if preprocess_conf is not None:
|
||||
with open(preprocess_conf, 'r') as fin:
|
||||
self.preprocessing = AugmentationPipeline(fin.read())
|
||||
logger.warning(
|
||||
"[Experimental feature] Some preprocessing will be done "
|
||||
"for the mini-batch creation using {}".format(
|
||||
self.preprocessing))
|
||||
else:
|
||||
# If conf doesn't exist, this function don't touch anything.
|
||||
self.preprocessing = None
|
||||
|
||||
self.mode = mode
|
||||
self.load_output = load_output
|
||||
self.load_input = load_input
|
||||
self.sort_in_input_length = sort_in_input_length
|
||||
if preprocess_args is None:
|
||||
self.preprocess_args = {}
|
||||
else:
|
||||
assert isinstance(preprocess_args, dict), type(preprocess_args)
|
||||
self.preprocess_args = dict(preprocess_args)
|
||||
|
||||
self.keep_all_data_on_mem = keep_all_data_on_mem
|
||||
|
||||
def __call__(self, batch, return_uttid=False):
|
||||
"""Function to load inputs and targets from list of dicts
|
||||
|
||||
:param List[Tuple[str, dict]] batch: list of dict which is subset of
|
||||
loaded data.json
|
||||
:param bool return_uttid: return utterance ID information for visualization
|
||||
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
|
||||
:return: list of input feature sequences
|
||||
[(T_1, D), (T_2, D), ..., (T_B, D)]
|
||||
:rtype: list of float ndarray
|
||||
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
|
||||
:rtype: list of int ndarray
|
||||
|
||||
"""
|
||||
x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
|
||||
y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
|
||||
uttid_list = [] # List[str]
|
||||
|
||||
for uttid, info in batch:
|
||||
uttid_list.append(uttid)
|
||||
|
||||
if self.load_input:
|
||||
# Note(kamo): This for-loop is for multiple inputs
|
||||
for idx, inp in enumerate(info["input"]):
|
||||
# {"input":
|
||||
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
|
||||
# "filetype": "hdf5",
|
||||
# "name": "input1", ...}], ...}
|
||||
x = self._get_from_loader(
|
||||
filepath=inp["feat"],
|
||||
filetype=inp.get("filetype", "mat"))
|
||||
x_feats_dict.setdefault(inp["name"], []).append(x)
|
||||
|
||||
if self.load_output:
|
||||
for idx, inp in enumerate(info["output"]):
|
||||
if "tokenid" in inp:
|
||||
# ======= Legacy format for output =======
|
||||
# {"output": [{"tokenid": "1 2 3 4"}])
|
||||
x = np.fromiter(
|
||||
map(int, inp["tokenid"].split()), dtype=np.int64)
|
||||
else:
|
||||
# ======= New format =======
|
||||
# {"input":
|
||||
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
|
||||
# "filetype": "hdf5",
|
||||
# "name": "target1", ...}], ...}
|
||||
x = self._get_from_loader(
|
||||
filepath=inp["feat"],
|
||||
filetype=inp.get("filetype", "mat"))
|
||||
|
||||
y_feats_dict.setdefault(inp["name"], []).append(x)
|
||||
|
||||
if self.mode == "asr":
|
||||
return_batch, uttid_list = self._create_batch_asr(
|
||||
x_feats_dict, y_feats_dict, uttid_list)
|
||||
else:
|
||||
raise NotImplementedError(self.mode)
|
||||
|
||||
if self.preprocessing is not None:
|
||||
# Apply pre-processing all input features
|
||||
for x_name in return_batch.keys():
|
||||
if x_name.startswith("input"):
|
||||
return_batch[x_name] = self.preprocessing(
|
||||
return_batch[x_name], uttid_list,
|
||||
**self.preprocess_args)
|
||||
|
||||
if return_uttid:
|
||||
return tuple(return_batch.values()), uttid_list
|
||||
|
||||
# Doesn't return the names now.
|
||||
return tuple(return_batch.values())
|
||||
|
||||
def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
|
||||
"""Create a OrderedDict for the mini-batch
|
||||
|
||||
:param OrderedDict x_feats_dict:
|
||||
e.g. {"input1": [ndarray, ndarray, ...],
|
||||
"input2": [ndarray, ndarray, ...]}
|
||||
:param OrderedDict y_feats_dict:
|
||||
e.g. {"target1": [ndarray, ndarray, ...],
|
||||
"target2": [ndarray, ndarray, ...]}
|
||||
:param: List[str] uttid_list:
|
||||
Give uttid_list to sort in the same order as the mini-batch
|
||||
:return: batch, uttid_list
|
||||
:rtype: Tuple[OrderedDict, List[str]]
|
||||
"""
|
||||
# handle single-input and multi-input (paralell) asr mode
|
||||
xs = list(x_feats_dict.values())
|
||||
|
||||
if self.load_output:
|
||||
ys = list(y_feats_dict.values())
|
||||
assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
|
||||
|
||||
# get index of non-zero length samples
|
||||
nonzero_idx = list(
|
||||
filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
|
||||
for n in range(1, len(y_feats_dict)):
|
||||
nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
|
||||
else:
|
||||
# Note(kamo): Be careful not to make nonzero_idx to a generator
|
||||
nonzero_idx = list(range(len(xs[0])))
|
||||
|
||||
if self.sort_in_input_length:
|
||||
# sort in input lengths based on the first input
|
||||
nonzero_sorted_idx = sorted(
|
||||
nonzero_idx, key=lambda i: -len(xs[0][i]))
|
||||
else:
|
||||
nonzero_sorted_idx = nonzero_idx
|
||||
|
||||
if len(nonzero_sorted_idx) != len(xs[0]):
|
||||
logger.warning(
|
||||
"Target sequences include empty tokenid (batch {} -> {}).".
|
||||
format(len(xs[0]), len(nonzero_sorted_idx)))
|
||||
|
||||
# remove zero-length samples
|
||||
xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
|
||||
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
|
||||
|
||||
x_names = list(x_feats_dict.keys())
|
||||
if self.load_output:
|
||||
ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
|
||||
y_names = list(y_feats_dict.keys())
|
||||
|
||||
# Keeping x_name and y_name, e.g. input1, for future extension
|
||||
return_batch = OrderedDict([
|
||||
* [(x_name, x) for x_name, x in zip(x_names, xs)],
|
||||
* [(y_name, y) for y_name, y in zip(y_names, ys)],
|
||||
])
|
||||
else:
|
||||
return_batch = OrderedDict(
|
||||
[(x_name, x) for x_name, x in zip(x_names, xs)])
|
||||
return return_batch, uttid_list
|
||||
|
||||
def _get_from_loader(self, filepath, filetype):
|
||||
"""Return ndarray
|
||||
|
||||
In order to make the fds to be opened only at the first referring,
|
||||
the loader are stored in self._loaders
|
||||
|
||||
>>> ndarray = loader.get_from_loader(
|
||||
... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
|
||||
|
||||
:param: str filepath:
|
||||
:param: str filetype:
|
||||
:return:
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
if filetype == "hdf5":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
|
||||
# "filetype": "hdf5",
|
||||
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
|
||||
filepath, key = filepath.split(":", 1)
|
||||
|
||||
loader = self._loaders.get(filepath)
|
||||
if loader is None:
|
||||
# To avoid disk access, create loader only for the first time
|
||||
loader = h5py.File(filepath, "r")
|
||||
self._loaders[filepath] = loader
|
||||
return loader[key][()]
|
||||
elif filetype == "sound.hdf5":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
|
||||
# "filetype": "sound.hdf5",
|
||||
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
|
||||
filepath, key = filepath.split(":", 1)
|
||||
|
||||
loader = self._loaders.get(filepath)
|
||||
if loader is None:
|
||||
# To avoid disk access, create loader only for the first time
|
||||
loader = SoundHDF5File(filepath, "r", dtype="int16")
|
||||
self._loaders[filepath] = loader
|
||||
array, rate = loader[key]
|
||||
return array
|
||||
elif filetype == "sound":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.wav",
|
||||
# "filetype": "sound"},
|
||||
# Assume PCM16
|
||||
if not self.keep_all_data_on_mem:
|
||||
array, _ = soundfile.read(filepath, dtype="int16")
|
||||
return array
|
||||
if filepath not in self._loaders:
|
||||
array, _ = soundfile.read(filepath, dtype="int16")
|
||||
self._loaders[filepath] = array
|
||||
return self._loaders[filepath]
|
||||
elif filetype == "npz":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
|
||||
# "filetype": "npz",
|
||||
filepath, key = filepath.split(":", 1)
|
||||
|
||||
loader = self._loaders.get(filepath)
|
||||
if loader is None:
|
||||
# To avoid disk access, create loader only for the first time
|
||||
loader = np.load(filepath)
|
||||
self._loaders[filepath] = loader
|
||||
return loader[key]
|
||||
elif filetype == "npy":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.npy",
|
||||
# "filetype": "npy"},
|
||||
if not self.keep_all_data_on_mem:
|
||||
return np.load(filepath)
|
||||
if filepath not in self._loaders:
|
||||
self._loaders[filepath] = np.load(filepath)
|
||||
return self._loaders[filepath]
|
||||
elif filetype in ["mat", "vec"]:
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.ark:123",
|
||||
# "filetype": "mat"}]},
|
||||
# In this case, "123" indicates the starting points of the matrix
|
||||
# load_mat can load both matrix and vector
|
||||
if not self.keep_all_data_on_mem:
|
||||
return kaldiio.load_mat(filepath)
|
||||
if filepath not in self._loaders:
|
||||
self._loaders[filepath] = kaldiio.load_mat(filepath)
|
||||
return self._loaders[filepath]
|
||||
elif filetype == "scp":
|
||||
# e.g.
|
||||
# {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
|
||||
# "filetype": "scp",
|
||||
filepath, key = filepath.split(":", 1)
|
||||
loader = self._loaders.get(filepath)
|
||||
if loader is None:
|
||||
# To avoid disk access, create loader only for the first time
|
||||
loader = kaldiio.load_scp(filepath)
|
||||
self._loaders[filepath] = loader
|
||||
return loader[key]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supported: loader_type={}".format(filetype))
|
||||
|
||||
|
||||
class SoundHDF5File():
    """Collecting sound files to a HDF5 file

    Each entry is stored as the encoded bytes of one audio file (written
    with ``soundfile``) and decoded again on access.

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']


    :param: str filepath: path of the HDF5 file
    :param: str mode: mode passed to ``h5py.File``
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
        When omitted it is taken from the second extension of ``filepath``
        (``a.flac.h5`` -> ``flac``), falling back to ``flac``.
    :param: str dtype: dtype passed to ``soundfile.read`` when decoding

    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        """Encode ``data`` (an ``(array, rate)`` pair) and store it as ``name``.

        The audio is written to an in-memory buffer in ``self.format`` and
        the raw bytes are saved as an opaque ``np.void`` value.
        """
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        """Decode the entry ``key`` and return ``(array, rate)``."""
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        # Lazily decode every entry.
        for k in self.file:
            yield self[k]

    def items(self):
        # Lazily yield (key, decoded entry) pairs.
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        # ``len(obj)`` calls ``__len__`` with no extra argument; the previous
        # ``__len__(self, item)`` signature made ``len(obj)`` raise TypeError.
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()
|
@ -1,10 +0,0 @@
|
||||
[
|
||||
{
|
||||
"type": "shift",
|
||||
"params": {
|
||||
"min_shift_ms": -5,
|
||||
"max_shift_ms": 5
|
||||
},
|
||||
"prob": 1.0
|
||||
}
|
||||
]
|
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def main(args):
    """Convert an espnet data JSON file into a JSON-lines manifest.

    Reads ``args.json_file`` and, for every ``key -> value`` pair under its
    top-level ``"utts"`` mapping, writes ``value`` (with the key added under
    ``"utt"``) as one JSON object per line to ``args.manifest_file``.

    Both files are opened as UTF-8 explicitly: records are dumped with
    ``ensure_ascii=False``, so relying on the locale's default encoding
    could fail or corrupt non-ASCII transcripts.

    :param args: namespace with ``json_file`` and ``manifest_file`` paths
    """
    with open(args.json_file, 'r', encoding='utf-8') as fin:
        data_json = json.load(fin)

    # manifest format:
    # {"input": [
    #   {"feat": "dev/deltafalse/feats.1.ark:842920", "name": "input1", "shape": [349, 83]}
    #  ],
    #  "output": [
    #   {"name": "target1", "shape": [12, 5002], "text": "NO APOLLO", "token": "▁NO ▁A PO LL O", "tokenid": "3144 482 352 269 317"}
    #  ],
    #  "utt2spk": "116-288045",
    #  "utt": "116-288045-0019"}
    with open(args.manifest_file, 'w', encoding='utf-8') as fout:
        for key, value in data_json['utts'].items():
            value['utt'] = key
            fout.write(json.dumps(value, ensure_ascii=False))
            fout.write("\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse the two file paths and run the
    # conversion.
    parser = argparse.ArgumentParser(description=__doc__)
    # The input file has no sensible default, so it is required; previously a
    # missing flag surfaced later as a TypeError from open(None).
    parser.add_argument(
        '--json-file',
        type=str,
        default=None,
        required=True,
        help="espnet data json file.")
    # NOTE(review): the default file name is spelled 'maniefst.train' (sic).
    # It is kept unchanged here so the CLI default does not silently move;
    # confirm whether 'manifest.train' was intended.
    parser.add_argument(
        '--manifest-file',
        type=str,
        default='maniefst.train',
        help='manifest data json line file.')
    args = parser.parse_args()
    main(args)
|
@ -0,0 +1,3 @@
|
||||
# Punctuation Restoration
|
||||
|
||||
Please use `https://github.com/745165806/PaddleSpeechTask` for this task.
|
@ -0,0 +1,3 @@
|
||||
TED-En-Zh
|
||||
data
|
||||
exp
|
@ -0,0 +1,10 @@
|
||||
|
||||
# TED En-Zh
|
||||
|
||||
## Dataset
|
||||
|
||||
| Data Subset | Duration in Seconds |
|
||||
| --- | --- |
|
||||
| data/manifest.train | 0.942 ~ 60 |
|
||||
| data/manifest.dev | 1.151 ~ 39 |
|
||||
| data/manifest.test | 1.1 ~ 42.746 |
|
@ -0,0 +1 @@
|
||||
build
|
@ -1,77 +1,56 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
project(deepspeech VERSION 0.1)
|
||||
project(speechnn VERSION 0.1)
|
||||
|
||||
set(CMAKE_VERBOSE_MAKEFILE on)
|
||||
# set std-14
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/src CACHE PATH "Install path prefix." FORCE)
|
||||
endif(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
|
||||
|
||||
# include file
|
||||
include(FetchContent)
|
||||
include(ExternalProject)
|
||||
# fc_patch dir
|
||||
set(FETCHCONTENT_QUIET off)
|
||||
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
|
||||
set(FETCHCONTENT_BASE_DIR ${fc_patch})
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Option Configurations
|
||||
###############################################################################
|
||||
# option configurations
|
||||
option(TEST_DEBUG "option for debug" OFF)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Include third party
|
||||
###############################################################################
|
||||
# #example for include third party
|
||||
# FetchContent_Declare()
|
||||
# # FetchContent_MakeAvailable was not added until CMake 3.14
|
||||
# FetchContent_MakeAvailable()
|
||||
# include_directories()
|
||||
|
||||
# ABSEIL-CPP
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
absl
|
||||
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
|
||||
GIT_TAG "20210324.1"
|
||||
)
|
||||
FetchContent_MakeAvailable(absl)
|
||||
include(cmake/third_party.cmake)
|
||||
|
||||
# libsndfile
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
libsndfile
|
||||
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
|
||||
GIT_TAG "1.0.31"
|
||||
)
|
||||
FetchContent_MakeAvailable(libsndfile)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Add local library
|
||||
###############################################################################
|
||||
# system lib
|
||||
find_package()
|
||||
# if dir have CmakeLists.txt
|
||||
add_subdirectory()
|
||||
# if dir do not have CmakeLists.txt
|
||||
add_library(lib_name STATIC file.cc)
|
||||
target_link_libraries(lib_name item0 item1)
|
||||
add_dependencies(lib_name depend-target)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Library installation
|
||||
###############################################################################
|
||||
install()
|
||||
|
||||
set(CMAKE_VERBOSE_MAKEFILE on)
|
||||
# set std-14
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
###############################################################################
|
||||
# Build binary file
|
||||
###############################################################################
|
||||
add_executable()
|
||||
target_link_libraries()
|
||||
|
||||
# # fc_patch dir
|
||||
# set(FETCHCONTENT_QUIET off)
|
||||
# get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
|
||||
# set(FETCHCONTENT_BASE_DIR ${fc_patch})
|
||||
#
|
||||
#
|
||||
# ###############################################################################
|
||||
# # Option Configurations
|
||||
# ###############################################################################
|
||||
# # option configurations
|
||||
# option(TEST_DEBUG "option for debug" OFF)
|
||||
#
|
||||
#
|
||||
# ###############################################################################
|
||||
# # Add local library
|
||||
# ###############################################################################
|
||||
# # system lib
|
||||
# find_package()
|
||||
# # if the directory has a CMakeLists.txt
|
||||
# add_subdirectory()
|
||||
# # if the directory does not have a CMakeLists.txt
|
||||
# add_library(lib_name STATIC file.cc)
|
||||
# target_link_libraries(lib_name item0 item1)
|
||||
# add_dependencies(lib_name depend-target)
|
||||
#
|
||||
#
|
||||
# ###############################################################################
|
||||
# # Library installation
|
||||
# ###############################################################################
|
||||
# install()
|
||||
#
|
||||
#
|
||||
# ###############################################################################
|
||||
# # Build binary file
|
||||
# ###############################################################################
|
||||
# add_executable()
|
||||
# target_link_libraries()
|
||||
#
|
||||
|
@ -0,0 +1,197 @@
|
||||
include(ExternalProject)
|
||||
# Create a target named "third_party", which can compile external dependencies on all platforms (Windows/Linux/macOS)
|
||||
|
||||
# Where third-party libraries are downloaded and built (user-overridable
# cache entry).
set(THIRD_PARTY_PATH
  "${CMAKE_BINARY_DIR}/third_party"
  CACHE STRING "A path setting third party libraries download & build directories."
)

# Root under which third-party sources are cached between builds
# (user-overridable cache entry).
set(THIRD_PARTY_CACHE_PATH
  "${CMAKE_SOURCE_DIR}"
  CACHE STRING "A path cache third party source code to avoid repeated download."
)

# Build type recorded for third-party projects.
set(THIRD_PARTY_BUILD_TYPE Release)

# Start with no accumulated third-party dependency targets
# (set() with no value clears the variable).
set(third_party_deps)
|
||||
|
||||
|
||||
# Cache function to avoid repeatedly downloading third-party source code.
|
||||
# This function has 4 parameters, URL / REPOSITORY / TAG / DIR:
|
||||
# 1. URL: specify download url of 3rd party
|
||||
# 2. REPOSITORY: specify git REPOSITORY of 3rd party
|
||||
# 3. TAG: specify git tag/branch/commitID of 3rd party
|
||||
# 4. DIR: overwrite the original SOURCE_DIR when cache directory
|
||||
#
|
||||
# The function Return 1 PARENT_SCOPE variables:
|
||||
# - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
|
||||
# and you no longer need to set any download steps in ExternalProject_Add.
|
||||
# For example:
|
||||
# Cache_third_party(${TARGET}
|
||||
# REPOSITORY ${TARGET_REPOSITORY}
|
||||
# TAG ${TARGET_TAG}
|
||||
# DIR ${TARGET_SOURCE_DIR})
|
||||
|
||||
# cache_third_party(<TARGET> [URL <url>] [REPOSITORY <git-url>] [TAG <ref>] [DIR <var-name>])
#
# Builds the download arguments for ExternalProject_Add and returns them to
# the caller as <NAME>_DOWNLOAD_CMD, where <NAME> is TARGET with "extern_" and
# all digits removed, upper-cased (extern_eigen3 -> EIGEN).
#
#   REPOSITORY  git repository; takes precedence over URL
#   TAG         git tag/branch/commit, only used together with REPOSITORY
#   URL         archive URL, used when REPOSITORY is not given
#   DIR         *name* of the caller's source-dir variable; only used when
#               WITH_TP_CACHE is on, in which case that variable is
#               overwritten in the caller's scope with a hash-suffixed
#               directory under ${THIRD_PARTY_CACHE_PATH}/third_party
#
# Fails with FATAL_ERROR when neither REPOSITORY nor URL is given, or when
# WITH_TP_CACHE is on and DIR is missing.
function(cache_third_party TARGET)
  set(options "")
  set(oneValueArgs URL REPOSITORY TAG DIR)
  set(multiValueArgs "")
  # Fixed: this previously expanded the misspelled "${optionps}".
  cmake_parse_arguments(cache_third_party "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # Derive the variable prefix: extern_eigen3 -> eigen3 -> eigen -> EIGEN
  string(REPLACE "extern_" "" TARGET_NAME ${TARGET})
  string(REGEX REPLACE "[0-9]+" "" TARGET_NAME ${TARGET_NAME})
  string(TOUPPER ${TARGET_NAME} TARGET_NAME)

  if(cache_third_party_REPOSITORY)
    set(${TARGET_NAME}_DOWNLOAD_CMD
      GIT_REPOSITORY ${cache_third_party_REPOSITORY})
    if(cache_third_party_TAG)
      list(APPEND ${TARGET_NAME}_DOWNLOAD_CMD
        GIT_TAG ${cache_third_party_TAG})
    endif()
  elseif(cache_third_party_URL)
    set(${TARGET_NAME}_DOWNLOAD_CMD
      URL ${cache_third_party_URL})
  else()
    message(FATAL_ERROR "Download link (Git repo or URL) must be specified for cache!")
  endif()

  if(WITH_TP_CACHE)
    if(NOT cache_third_party_DIR)
      message(FATAL_ERROR "Please input the ${TARGET_NAME}_SOURCE_DIR for overwriting when -DWITH_TP_CACHE=ON")
    endif()
    # Generate and verify cache dir for third_party source code.
    # From here on "REPOSITORY" holds whichever of repository/URL was given.
    set(cache_third_party_REPOSITORY ${cache_third_party_REPOSITORY} ${cache_third_party_URL})
    if(cache_third_party_REPOSITORY AND cache_third_party_TAG)
      # Directory suffix: first 8 hex digits of each of the two MD5 hashes.
      string(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
      string(MD5 HASH_GIT ${cache_third_party_TAG})
      string(SUBSTRING ${HASH_REPO} 0 8 HASH_REPO)
      string(SUBSTRING ${HASH_GIT} 0 8 HASH_GIT)
      string(CONCAT HASH ${HASH_REPO} ${HASH_GIT})
      # overwrite the original SOURCE_DIR when cache directory
      set(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
    elseif(cache_third_party_REPOSITORY)
      # No tag: use the first 16 hex digits of the source's MD5 hash.
      string(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
      string(SUBSTRING ${HASH_REPO} 0 16 HASH)
      # overwrite the original SOURCE_DIR when cache directory
      set(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
    endif()

    if(EXISTS ${${cache_third_party_DIR}})
      # A non-empty cache directory means the sources are already present,
      # so the download step is disabled.
      file(GLOB files ${${cache_third_party_DIR}}/*)
      list(LENGTH files files_len)
      if(files_len GREATER 0)
        list(APPEND ${TARGET_NAME}_DOWNLOAD_CMD DOWNLOAD_COMMAND "")
      endif()
    endif()
    set(${cache_third_party_DIR} ${${cache_third_party_DIR}} PARENT_SCOPE)
  endif()

  # Pass ${TARGET_NAME}_DOWNLOAD_CMD to parent scope, the double quotation marks can't be removed
  set(${TARGET_NAME}_DOWNLOAD_CMD "${${TARGET_NAME}_DOWNLOAD_CMD}" PARENT_SCOPE)
endfunction()
|
||||
|
||||
# unset_var(<VAR_NAME>)
#
# Removes <VAR_NAME> both from the cache and as a normal variable.
# Kept as a macro so the normal-variable unset happens in the caller's scope.
macro(UNSET_VAR VAR_NAME)
  unset(${VAR_NAME} CACHE)  # drop the cache entry
  unset(${VAR_NAME})        # drop the normal variable
endmacro()
|
||||
|
||||
# Function to download the dependencies during compilation
|
||||
# This function has 2 parameters, URL / NAME:
|
||||
# 1. URL: The download url of 3rd dependencies
|
||||
# 2. NAME: The name of the file, which determines the directory name
|
||||
#
|
||||
# file_download_and_uncompress(<URL> <NAME> [MD5 <md5>])
#
# Registers an ExternalProject target "download_<NAME>" that fetches <URL>
# into ${THIRD_PARTY_PATH}/<NAME>/data with every configure/build/install
# step disabled.  In the caller's scope it sets <NAME>_INCLUDE_DIR to that
# data directory and appends the new target to third_party_deps.
function(file_download_and_uncompress URL NAME)
  set(options "")
  set(oneValueArgs MD5)
  set(multiValueArgs "")
  # Prefix "URL" means the optional checksum is parsed into URL_MD5.
  cmake_parse_arguments(URL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  message(STATUS "Download dependence[${NAME}] from ${URL}, MD5: ${URL_MD5}")
  set(${NAME}_INCLUDE_DIR ${THIRD_PARTY_PATH}/${NAME}/data PARENT_SCOPE)

  ExternalProject_Add(
    download_${NAME}
    ${EXTERNAL_PROJECT_LOG_ARGS}
    PREFIX                ${THIRD_PARTY_PATH}/${NAME}
    DOWNLOAD_DIR          ${THIRD_PARTY_PATH}/${NAME}/data/
    SOURCE_DIR            ${THIRD_PARTY_PATH}/${NAME}/data/
    URL                   ${URL}
    URL_MD5               ${URL_MD5}
    TIMEOUT               120
    DOWNLOAD_NO_PROGRESS  1
    # Download only: nothing to update, configure, build or install.
    UPDATE_COMMAND        ""
    CONFIGURE_COMMAND     ""
    BUILD_COMMAND         ""
    INSTALL_COMMAND       ""
  )

  set(third_party_deps ${third_party_deps} download_${NAME} PARENT_SCOPE)
endfunction()
|
||||
|
||||
|
||||
# Reconcile feature flags with what each platform (Windows/macOS) supports,
# printing a message whenever a feature is forced off.
if(APPLE AND WITH_MKL)
  message(WARNING
    "Mac is not supported with MKL in Paddle yet. Force WITH_MKL=OFF.")
  set(WITH_MKL OFF CACHE STRING "Disable MKL for building on mac" FORCE)
endif()

if(WIN32 OR APPLE)
  # XBYAK is switched off unconditionally on these platforms.
  message(STATUS "Disable XBYAK in Windows and MacOS")
  set(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)

  if(WITH_LIBXSMM)
    message(WARNING
      "Windows, Mac are not supported with libxsmm in Paddle yet."
      "Force WITH_LIBXSMM=OFF")
    set(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM in Windows and MacOS" FORCE)
  endif()

  # The remaining packages share one warning / cache-doc template.
  foreach(_tp_unsupported_pkg IN ITEMS BOX_PS PSLIB LIBMCT PSLIB_BRPC)
    if(WITH_${_tp_unsupported_pkg})
      message(WARNING
        "Windows or Mac is not supported with ${_tp_unsupported_pkg} in Paddle yet."
        "Force WITH_${_tp_unsupported_pkg}=OFF")
      set(WITH_${_tp_unsupported_pkg} OFF CACHE STRING "Disable ${_tp_unsupported_pkg} package in Windows and MacOS" FORCE)
    endif()
  endforeach()
endif()

# MKLML follows MKL.
set(WITH_MKLML ${WITH_MKL})

# Choose a default for MKL-DNN only when the user has not set it.
if(NOT DEFINED WITH_MKLDNN)
  if(WITH_MKL AND AVX2_FOUND)
    set(WITH_MKLDNN ON)
  else()
    message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
    set(WITH_MKLDNN OFF)
  endif()
endif()

# DGC is turned off on Windows/macOS, without GPU, or for inference builds.
if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
  set(WITH_DGC OFF)
endif()

# Intended to add --depth=1 to the git clone of external projects.
# NOTE(review): the value is one string containing a space, not the two-item
# list "GIT_SHALLOW;TRUE" -- confirm ExternalProject_Add treats it as the
# GIT_SHALLOW keyword before relying on shallow clones.
if(CMAKE_VERSION VERSION_GREATER "3.5.2")
  set(SHALLOW_CLONE "GIT_SHALLOW TRUE")
endif()
|
||||
|
||||
|
||||
########################### include third_party according to flags ###############################
|
||||
include(third_party/libsndfile) # download, build, install libsndfile
|
||||
include(third_party/boost) # download boost
|
||||
include(third_party/eigen) # download eigen3
|
||||
include(third_party/threadpool) # download threadpool
|
||||
|
||||
|
@ -0,0 +1,13 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
include(ExternalProject)
|
||||
include(FetchContent)
|
||||
|
||||
# Abseil (abseil-cpp), pinned to the 20210324.1 tag.
# FetchContent_MakeAvailable populates the content and, if the fetched tree
# has a CMakeLists.txt, adds it to the build.
FetchContent_Declare(absl
  GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
  GIT_TAG        20210324.1
)
FetchContent_MakeAvailable(absl)
|
||||
|
||||
|
@ -0,0 +1,49 @@
|
||||
include(ExternalProject)
|
||||
|
||||
# Name of the ExternalProject target that downloads boost.
set(BOOST_PROJECT "extern_boost")

# Why such an old boost: releasing PaddlePaddle as a pip package means
# following the manylinux1 standard, which targets Linux kernels and
# compilers as old as possible and recommends CentOS 5.  The earliest CentOS
# that works with NVIDIA CUDA is CentOS 6, a newer boost (e.g. 1.66.0) does
# not build there, and the CentOS 6 devtools package installs boost 1.41.0,
# so 1.41.0 is used here.
set(BOOST_VER "1.41.0")
set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
set(BOOST_URL "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
message(STATUS "BOOST_VERSION: ${BOOST_VER}, BOOST_URL: ${BOOST_URL}")

# Layout under the third-party build tree.
set(BOOST_PREFIX_DIR ${THIRD_PARTY_PATH}/boost)
set(BOOST_SOURCE_DIR ${THIRD_PARTY_PATH}/boost/src/extern_boost)

# Produces BOOST_DOWNLOAD_CMD; BOOST_SOURCE_DIR is passed by *name* so it can
# be redirected to the cache directory when WITH_TP_CACHE is on.
cache_third_party(
  ${BOOST_PROJECT}
  URL ${BOOST_URL}
  DIR BOOST_SOURCE_DIR
)

# Headers are used straight from the source directory.
set(BOOST_INCLUDE_DIR "${BOOST_SOURCE_DIR}" CACHE PATH "boost include directory." FORCE)
set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
include_directories(${BOOST_INCLUDE_DIR})

if(WIN32 AND MSVC_VERSION GREATER_EQUAL 1600)
  add_definitions(-DBOOST_HAS_STATIC_ASSERT)
endif()

# Download only: no update, configure, build or install step.
ExternalProject_Add(
  ${BOOST_PROJECT}
  ${EXTERNAL_PROJECT_LOG_ARGS}
  "${BOOST_DOWNLOAD_CMD}"
  URL_MD5               f891e8c2c9424f0565f0129ad9ab4aff
  PREFIX                ${BOOST_PREFIX_DIR}
  DOWNLOAD_DIR          ${BOOST_SOURCE_DIR}
  SOURCE_DIR            ${BOOST_SOURCE_DIR}
  DOWNLOAD_NO_PROGRESS  1
  UPDATE_COMMAND        ""
  CONFIGURE_COMMAND     ""
  BUILD_COMMAND         ""
  INSTALL_COMMAND       ""
)

# Interface target "boost": depending on it orders a target after the
# download step.
add_library(boost INTERFACE)
add_dependencies(boost ${BOOST_PROJECT})

set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR})
|
@ -0,0 +1,53 @@
|
||||
include(ExternalProject)
|
||||
|
||||
# Eigen is pinned to commit f612df27 (updated on 03/16/2021).
set(EIGEN_REPOSITORY https://gitlab.com/libeigen/eigen.git)
set(EIGEN_TAG f612df273689a19d25b45ca4f8269463207c4fee)

# Layout under the third-party build tree.
set(EIGEN_PREFIX_DIR ${THIRD_PARTY_PATH}/eigen3)
set(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3/src/extern_eigen3)

# Produces EIGEN_DOWNLOAD_CMD; EIGEN_SOURCE_DIR is passed by *name* so it can
# be redirected to the cache directory when WITH_TP_CACHE is on.
cache_third_party(
  extern_eigen3
  REPOSITORY ${EIGEN_REPOSITORY}
  TAG ${EIGEN_TAG}
  DIR EIGEN_SOURCE_DIR
)

if(WIN32)
  add_definitions(-DEIGEN_STRONG_INLINE=inline)
elseif(LINUX AND WITH_ROCM)
  # For HIPCC, Eigen::internal::device::numeric_limits is not
  # EIGEN_DEVICE_FUNC, which causes a compiler error about using a __host__
  # function in a __host__ __device__ one; two patched headers are copied
  # over the fetched sources to work around it.
  # NOTE(review): LINUX and PADDLE_SOURCE_DIR are not set anywhere in this
  # file -- confirm they are defined by the including project.
  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src)
  file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst)
  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorReductionGpu.h native_src1)
  file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h native_dst1)
  set(EIGEN_PATCH_COMMAND cp ${native_src} ${native_dst} && cp ${native_src1} ${native_dst1})
endif()

# Headers are used straight from the source directory.
set(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR})
include_directories(${EIGEN_INCLUDE_DIR})

# Fetch (and optionally patch) only: nothing to configure, build, install
# or test.
ExternalProject_Add(
  extern_eigen3
  ${EXTERNAL_PROJECT_LOG_ARGS}
  ${SHALLOW_CLONE}
  "${EIGEN_DOWNLOAD_CMD}"
  PREFIX             ${EIGEN_PREFIX_DIR}
  SOURCE_DIR         ${EIGEN_SOURCE_DIR}
  PATCH_COMMAND      ${EIGEN_PATCH_COMMAND}
  UPDATE_COMMAND     ""
  CONFIGURE_COMMAND  ""
  BUILD_COMMAND      ""
  INSTALL_COMMAND    ""
  TEST_COMMAND       ""
)

# Interface target "eigen3": depending on it orders a target after the
# fetch step.
add_library(eigen3 INTERFACE)
add_dependencies(eigen3 extern_eigen3)

# sw does not support thread_local semantics
if(WITH_SW)
  add_definitions(-DEIGEN_AVOID_THREAD_LOCAL)
endif()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue