Merge pull request #743 from LittleChenCc/develop
ST-related code and script for TED-En-Zh translation
commit 566f636cc6
@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,48 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments


def main_sp(config, args):
    exp = Tester(config, args)
    exp.setup()
    exp.run_export()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
@@ -0,0 +1,55 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile

from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTester as Tester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments

# TODO(hui zhang): dynamic load


def main_sp(config, args):
    exp = Tester(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # profiling setup
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats('test.profile')
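
Note: the profile dumped above can be inspected offline with the standard-library pstats module — a minimal sketch, assuming test.profile sits in the working directory:

import pstats

stats = pstats.Stats('test.profile')
stats.strip_dirs().sort_stats('cumulative').print_stats(20)  # top 20 by cumulative time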
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import cProfile
import os

from paddle import distributed as dist

from deepspeech.exps.u2_st.config import get_cfg_defaults
from deepspeech.exps.u2_st.model import U2STTrainer as Trainer
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments


def main_sp(config, args):
    exp = Trainer(config, args)
    exp.setup()
    exp.run()


def main(config, args):
    if args.device == "gpu" and args.nprocs > 1:
        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    else:
        main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # profiling setup
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats(os.path.join(args.output, 'train.profile'))
@@ -0,0 +1,41 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from yacs.config import CfgNode

from deepspeech.exps.u2_st.model import U2STTester
from deepspeech.exps.u2_st.model import U2STTrainer
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.u2_st import U2STModel

_C = CfgNode()

_C.data = ManifestDataset.params()

_C.collator = SpeechCollator.params()

_C.model = U2STModel.params()

_C.training = U2STTrainer.params()

_C.decoding = U2STTester.params()


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for the U2 ST experiment."""
    # Return a clone so that the defaults will not be altered.
    # This is for the "local variable" use pattern.
    config = _C.clone()
    config.set_new_allowed(True)
    return config
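
Note: each component registers its own defaults through a params() classmethod, and the pieces compose into one config tree. A minimal sketch of the same yacs pattern, with illustrative names:

from yacs.config import CfgNode

def params():
    # a component's own defaults, as its params() classmethod would return them
    return CfgNode(dict(batch_size=16, num_workers=2))

_C = CfgNode()
_C.collator = params()        # mount the component defaults under a namespace
cfg = _C.clone()              # clone so callers cannot mutate the shared defaults
cfg.merge_from_list(['collator.batch_size', 32])  # CLI-style override
cfg.freeze()                  # make the config read-only
print(cfg.collator.batch_size)  # 32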
@@ -0,0 +1,680 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the U2 ST model trainer and tester."""
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode

from deepspeech.io.collator_st import KaldiPrePorocessedCollator
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
from deepspeech.io.collator_st import TripletSpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.dataset import TripletManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class U2STTrainer(Trainer):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # training config
        default = CfgNode(
            dict(
                n_epoch=50,  # train epochs
                log_interval=100,  # steps
                accum_grad=1,  # accumulate gradients over this many steps
                global_grad_clip=5.0,  # the global norm clip
            ))
        default.optim = 'adam'
        default.optim_conf = CfgNode(
            dict(
                lr=5e-4,  # learning rate
                weight_decay=1e-6,  # the coeff of weight decay
            ))
        default.scheduler = 'warmuplr'
        default.scheduler_conf = CfgNode(
            dict(
                warmup_steps=25000,
                lr_decay=1.0,  # learning rate decay
            ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self, config, args):
        super().__init__(config, args)

    def train_batch(self, batch_index, batch_data, msg):
        train_conf = self.config.training
        start = time.time()
        utt, audio, audio_len, text, text_len = batch_data
        if isinstance(text, list) and isinstance(text_len, list):
            # joint training with ASR: text targets come as a pair
            # [translation, transcription]
            text, text_transcript = text
            text_len, text_transcript_len = text_len
            loss, st_loss, attention_loss, ctc_loss = self.model(
                audio, audio_len, text, text_len, text_transcript,
                text_transcript_len)
        else:
            loss, st_loss, attention_loss, ctc_loss = self.model(
                audio, audio_len, text, text_len)
        # loss div by `batch_size * accum_grad`
        loss /= train_conf.accum_grad
        loss.backward()
        layer_tools.print_grads(self.model, print_func=None)

        losses_np = {'loss': float(loss) * train_conf.accum_grad}
        losses_np['st_loss'] = float(st_loss)
        if attention_loss:
            losses_np['att_loss'] = float(attention_loss)
        if ctc_loss:
            losses_np['ctc_loss'] = float(ctc_loss)

        if (batch_index + 1) % train_conf.accum_grad == 0:
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.lr_scheduler.step()
            self.iteration += 1

        iteration_time = time.time() - start

        if (batch_index + 1) % train_conf.log_interval == 0:
            msg += "train time: {:>.3f}s, ".format(iteration_time)
            msg += "batch size: {}, ".format(self.config.collator.batch_size)
            msg += "accum: {}, ".format(train_conf.accum_grad)
            msg += ', '.join('{}: {:>.6f}'.format(k, v)
                             for k, v in losses_np.items())
            logger.info(msg)

            if dist.get_rank() == 0 and self.visualizer:
                losses_np_v = losses_np.copy()
                losses_np_v.update({"lr": self.lr_scheduler()})
                self.visualizer.add_scalars("step", losses_np_v,
                                            self.iteration - 1)
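
Note on the accumulation logic above: the loss is pre-divided by accum_grad so that gradients summed over the micro-batches match one large-batch gradient, and the optimizer only steps every accum_grad batches. A minimal standalone sketch of the pattern (names are illustrative):

accum_grad = 4
for batch_index, batch in enumerate(loader):
    loss = compute_loss(model, batch) / accum_grad  # pre-scale each micro-batch
    loss.backward()                                 # gradients accumulate in place
    if (batch_index + 1) % accum_grad == 0:
        optimizer.step()        # apply the accumulated gradient
        optimizer.clear_grad()  # reset for the next accumulation window
        lr_scheduler.step()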

    @paddle.no_grad()
    def valid(self):
        self.model.eval()
        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            utt, audio, audio_len, text, text_len = batch
            if isinstance(text, list) and isinstance(text_len, list):
                text, text_transcript = text
                text_len, text_transcript_len = text_len
                loss, st_loss, attention_loss, ctc_loss = self.model(
                    audio, audio_len, text, text_len, text_transcript,
                    text_transcript_len)
            else:
                loss, st_loss, attention_loss, ctc_loss = self.model(
                    audio, audio_len, text, text_len)
            if paddle.isfinite(loss):
                num_utts = batch[1].shape[0]
                num_seen_utts += num_utts
                total_loss += float(st_loss) * num_utts
                valid_losses['val_loss'].append(float(st_loss))
                if attention_loss:
                    valid_losses['val_att_loss'].append(float(attention_loss))
                if ctc_loss:
                    valid_losses['val_ctc_loss'].append(float(ctc_loss))

            if (i + 1) % self.config.training.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_st_loss'] = total_loss / num_seen_utts

                # logging
                msg = f"Valid: Rank: {dist.get_rank()}, "
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)

        logger.info('Rank {} Val info st_val_loss {}'.format(
            dist.get_rank(), total_loss / num_seen_utts))
        return total_loss, num_seen_utts

    def train(self):
        """The training process, controlled by step."""
        # !!!IMPORTANT!!!
        # Try to export the model by script; if that fails, refine
        # the code to satisfy the script-export requirements:
        # script_model = paddle.jit.to_static(self.model)
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)

        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. epoch 0
            self.save(tag='init')

        self.lr_scheduler.step(self.iteration)
        if self.parallel:
            self.train_loader.batch_sampler.set_epoch(self.epoch)

        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
            self.model.train()
            try:
                data_start_time = time.time()
                for batch_index, batch in enumerate(self.train_loader):
                    dataload_time = time.time() - data_start_time
                    msg = "Train: Rank: {}, ".format(dist.get_rank())
                    msg += "epoch: {}, ".format(self.epoch)
                    msg += "step: {}, ".format(self.iteration)
                    msg += "batch: {}/{}, ".format(batch_index + 1,
                                                   len(self.train_loader))
                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                    msg += "data time: {:>.3f}s, ".format(dataload_time)
                    self.train_batch(batch_index, batch, msg)
                    data_start_time = time.time()
            except Exception as e:
                logger.error(e)
                raise e

            total_loss, num_seen_utts = self.valid()
            if dist.get_world_size() > 1:
                num_seen_utts = paddle.to_tensor(num_seen_utts)
                # the default reduce op of all_reduce is SUM
                dist.all_reduce(num_seen_utts)
                total_loss = paddle.to_tensor(total_loss)
                dist.all_reduce(total_loss)
                cv_loss = total_loss / num_seen_utts
                cv_loss = float(cv_loss)
            else:
                cv_loss = total_loss / num_seen_utts

            logger.info(
                'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
            if self.visualizer:
                self.visualizer.add_scalars(
                    'epoch', {'cv_loss': cv_loss,
                              'lr': self.lr_scheduler()}, self.epoch)
            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
            self.new_epoch()

    def setup_dataloader(self):
        config = self.config.clone()
        config.defrost()
        config.collator.keep_transcription_text = False

        # train/valid dataset, return token ids
        Dataset = TripletManifestDataset if config.model.model_conf.asr_weight > 0. else ManifestDataset
        config.data.manifest = config.data.train_manifest
        train_dataset = Dataset.from_config(config)

        config.data.manifest = config.data.dev_manifest
        dev_dataset = Dataset.from_config(config)

        if config.collator.raw_wav:
            if config.model.model_conf.asr_weight > 0.:
                Collator = TripletSpeechCollator
                TestCollator = SpeechCollator
            else:
                TestCollator = Collator = SpeechCollator
            # The MTL (joint ST+ASR) loader is not yet implemented for raw_wav.
        else:
            if config.model.model_conf.asr_weight > 0.:
                Collator = TripletKaldiPrePorocessedCollator
                TestCollator = KaldiPrePorocessedCollator
            else:
                TestCollator = Collator = KaldiPrePorocessedCollator

        collate_fn_train = Collator.from_config(config)

        config.collator.augmentation_config = ""
        collate_fn_dev = Collator.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.collator.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.collator.batch_size,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.collator.num_workers, )
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.collator.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)

        # test dataset, return raw text
        config.data.manifest = config.data.test_manifest
        # Filtering test examples drops some utterances, but it keeps the test
        # setup consistent with training and allows a larger batch size
        # (saving evaluation time), so keep the filters for now.
        # config.data.min_input_len = 0.0  # second
        # config.data.max_input_len = float('inf')  # second
        # config.data.min_output_len = 0.0  # tokens
        # config.data.max_output_len = float('inf')  # tokens
        # config.data.min_output_input_ratio = 0.00
        # config.data.max_output_input_ratio = float('inf')
        test_dataset = ManifestDataset.from_config(config)
        # return text ord id
        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=TestCollator.from_config(config))
        # return text token id
        config.collator.keep_transcription_text = False
        self.align_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=TestCollator.from_config(config))
        logger.info("Setup train/valid/test/align Dataloader!")

    def setup_model(self):
        config = self.config
        model_conf = config.model
        model_conf.defrost()
        model_conf.input_dim = self.train_loader.collate_fn.feature_size
        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
        model_conf.freeze()
        model = U2STModel.from_config(model_conf)

        if self.parallel:
            model = paddle.DataParallel(model)

        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        train_config = config.training
        optim_type = train_config.optim
        optim_conf = train_config.optim_conf
        scheduler_type = train_config.scheduler
        scheduler_conf = train_config.scheduler_conf

        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)

        if scheduler_type == 'expdecaylr':
            lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                learning_rate=optim_conf.lr,
                gamma=scheduler_conf.lr_decay,
                verbose=False)
        elif scheduler_type == 'warmuplr':
            lr_scheduler = WarmupLR(
                learning_rate=optim_conf.lr,
                warmup_steps=scheduler_conf.warmup_steps,
                verbose=False)
        elif scheduler_type == 'noam':
            lr_scheduler = paddle.optimizer.lr.NoamDecay(
                learning_rate=optim_conf.lr,
                d_model=model_conf.encoder_conf.output_size,
                warmup_steps=scheduler_conf.warmup_steps,
                verbose=False)
        else:
            raise ValueError(f"Unsupported scheduler: {scheduler_type}")

        if optim_type == 'adam':
            optimizer = paddle.optimizer.Adam(
                learning_rate=lr_scheduler,
                parameters=model.parameters(),
                weight_decay=weight_decay,
                grad_clip=grad_clip)
        else:
            raise ValueError(f"Unsupported optim: {optim_type}")

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")

class U2STTester(U2STTrainer):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # decoding config
        default = CfgNode(
            dict(
                alpha=2.5,  # Coef of LM for beam search.
                beta=0.3,  # Coef of WC for beam search.
                cutoff_prob=1.0,  # Cutoff probability for pruning.
                cutoff_top_n=40,  # Cutoff number for pruning.
                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
                decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
                # 'ctc_prefix_beam_search', 'attention_rescoring'
                error_rate_type='bleu',  # Error rate type for evaluation. Options: 'bleu', 'char-bleu'
                num_proc_bsearch=8,  # Number of CPUs for beam search.
                beam_size=10,  # Beam search width.
                batch_size=16,  # decoding batch size
                ctc_weight=0.0,  # ctc weight for attention-rescoring decode mode.
                decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
                # <0: for decoding, use full chunk.
                # >0: for decoding, use fixed chunk size as set.
                # 0: used for training, prohibited here.
                num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
                simulate_streaming=False,  # simulate streaming inference. Defaults to False.
            ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self, config, args):
        super().__init__(config, args)

    def ordid2token(self, texts, texts_len):
        """Convert ord()-encoded id sequences back to text via chr()."""
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()
            ids = text[:n]
            trans.append(''.join([chr(i) for i in ids]))
        return trans
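
Note: ordid2token undoes the collator's keep_transcription_text encoding, in which raw characters ride through integer tensors as Unicode code points:

text = "how are you"
ids = [ord(c) for c in text]                 # characters -> code points
assert ''.join(chr(i) for i in ids) == text  # lossless round-trip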

    def compute_translation_metrics(self,
                                    utts,
                                    audio,
                                    audio_len,
                                    texts,
                                    texts_len,
                                    bleu_func,
                                    fout=None):
        cfg = self.config.decoding
        len_refs, num_ins = 0, 0

        start_time = time.time()
        text_feature = self.test_loader.collate_fn.text_feature

        refs = [
            "".join(chr(t) for t in text[:text_len])
            for text, text_len in zip(texts, texts_len)
        ]
        hyps = self.model.decode(
            audio,
            audio_len,
            text_feature=text_feature,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch,
            ctc_weight=cfg.ctc_weight,
            decoding_chunk_size=cfg.decoding_chunk_size,
            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
            simulate_streaming=cfg.simulate_streaming)
        decode_time = time.time() - start_time

        for utt, target, result in zip(utts, refs, hyps):
            len_refs += len(target.split())
            num_ins += 1
            if fout:
                fout.write(utt + " " + result + "\n")
            logger.info("\nReference: %s\nHypothesis: %s" % (target, result))
            logger.info("One example BLEU = %s" %
                        (bleu_func([result], [[target]]).prec_str))

        return dict(
            hyps=hyps,
            refs=refs,
            bleu=bleu_func(hyps, [refs]).score,
            len_refs=len_refs,
            num_ins=num_ins,  # num examples
            num_frames=audio_len.sum().numpy().item(),
            decode_time=decode_time)
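
Note: bleu_func is corpus-level, so the reference argument is a list of reference streams — hence the [refs] and [[target]] wrapping above. A minimal sketch of the assumed interface, using sacrebleu (which deepspeech.utils.bleu_score presumably wraps):

import sacrebleu

hyps = ["the cat sat on the mat"]
refs = ["the cat sat on the mat"]
result = sacrebleu.corpus_bleu(hyps, [refs])  # one reference stream
print(result.score)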

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        assert self.args.result_file
        self.model.eval()
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

        cfg = self.config.decoding
        bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu

        stride_ms = self.test_loader.collate_fn.stride_ms
        hyps, refs = [], []
        len_refs, num_ins = 0, 0
        num_frames = 0.0
        num_time = 0.0
        with open(self.args.result_file, 'w') as fout:
            for i, batch in enumerate(self.test_loader):
                metrics = self.compute_translation_metrics(
                    *batch, bleu_func=bleu_func, fout=fout)
                hyps += metrics['hyps']
                refs += metrics['refs']
                bleu = metrics['bleu']
                num_frames += metrics['num_frames']
                num_time += metrics["decode_time"]
                len_refs += metrics['len_refs']
                num_ins += metrics['num_ins']
                # decode seconds / audio seconds (stride_ms is in ms)
                rtf = num_time / (num_frames * stride_ms / 1000.0)
                logger.info("RTF: %f, BLEU (%d) = %f" % (rtf, num_ins, bleu))

        rtf = num_time / (num_frames * stride_ms / 1000.0)
        msg = "Test: "
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
        msg += "RTF: {}, ".format(rtf)
        msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs])))
        logger.info(msg)
        bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu'
        err_type_str = "BLEU"
        with open(bleu_meta_path, 'w') as f:
            data = json.dumps({
                "epoch": self.epoch,
                "step": self.iteration,
                "rtf": rtf,
                err_type_str: bleu_func(hyps, [refs]).score,
                "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
                "process_hour": num_time / 3600.0,
                "num_examples": num_ins,
                "decode_method": self.config.decoding.decoding_method,
            })
            f.write(data + '\n')
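
Note: RTF (real-time factor) is decoding time divided by audio duration; values below 1.0 mean faster than real time. A sanity-check sketch with the units used above (frame stride in milliseconds, decode time in seconds):

num_frames = 360000                    # 1 hour of audio at a 10 ms frame stride
stride_ms = 10.0
decode_seconds = 1800.0                # half an hour of compute
audio_seconds = num_frames * stride_ms / 1000.0
print(decode_seconds / audio_seconds)  # 0.5 -> twice as fast as real time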

    def run_test(self):
        self.resume_or_scratch()
        try:
            self.test()
        except KeyboardInterrupt:
            sys.exit(-1)

    @paddle.no_grad()
    def align(self):
        if self.config.decoding.batch_size > 1:
            logger.fatal('alignment mode must run with batch_size == 1')
            sys.exit(1)

        # result_file must end with .align
        assert self.args.result_file and self.args.result_file.endswith(
            '.align')

        self.model.eval()
        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")

        stride_ms = self.align_loader.collate_fn.stride_ms
        token_dict = self.align_loader.collate_fn.vocab_list
        with open(self.args.result_file, 'w') as fout:
            # one example per batch
            for i, batch in enumerate(self.align_loader):
                key, feat, feats_length, target, target_length = batch

                # 1. Encoder
                encoder_out, encoder_mask = self.model._forward_encoder(
                    feat, feats_length)  # (B, maxlen, encoder_dim)
                maxlen = encoder_out.size(1)
                ctc_probs = self.model.ctc.log_softmax(
                    encoder_out)  # (1, maxlen, vocab_size)

                # 2. alignment
                ctc_probs = ctc_probs.squeeze(0)
                target = target.squeeze(0)
                alignment = ctc_utils.forced_align(ctc_probs, target)
                logger.info("align ids: {} {}".format(key[0], alignment))
                fout.write('{} {}\n'.format(key[0], alignment))

                # 3. generate praat files
                # segment alignment
                align_segs = text_grid.segment_alignment(alignment)
                logger.info("align tokens: {} {}".format(key[0], align_segs))
                # IntervalTier, List["start end token\n"]
                subsample = utility.get_subsample(self.config)
                tierformat = text_grid.align_to_tierformat(
                    align_segs, subsample, token_dict)
                # write tier
                align_output_path = os.path.join(
                    os.path.dirname(self.args.result_file), "align")
                tier_path = os.path.join(align_output_path, key[0] + ".tier")
                with open(tier_path, 'w') as f:
                    f.writelines(tierformat)
                # write textgrid
                textgrid_path = os.path.join(align_output_path,
                                             key[0] + ".TextGrid")
                second_per_frame = 1. / (1000. /
                                         stride_ms)  # 25ms window, 10ms stride
                second_per_example = (
                    len(alignment) + 1) * subsample * second_per_frame
                text_grid.generate_textgrid(
                    maxtime=second_per_example,
                    intervals=tierformat,
                    output=textgrid_path)

    def run_align(self):
        self.resume_or_scratch()
        try:
            self.align()
        except KeyboardInterrupt:
            sys.exit(-1)

    def load_inferspec(self):
        """Infer model and input spec.

        Returns:
            nn.Layer: inference model
            List[paddle.static.InputSpec]: input spec.
        """
        from deepspeech.models.u2 import U2InferModel
        infer_model = U2InferModel.from_pretrained(self.test_loader,
                                                   self.config.model.clone(),
                                                   self.args.checkpoint_path)
        feat_dim = self.test_loader.collate_fn.feature_size
        input_spec = [
            paddle.static.InputSpec(shape=[1, None, feat_dim],
                                    dtype='float32'),  # audio, [B,T,D]
            paddle.static.InputSpec(shape=[1],
                                    dtype='int64'),  # audio_length, [B]
        ]
        return infer_model, input_spec

    def export(self):
        infer_model, input_spec = self.load_inferspec()
        assert isinstance(input_spec, list), type(input_spec)
        infer_model.eval()
        static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
        logger.info(f"Export code: {static_model.forward.code}")
        paddle.jit.save(static_model, self.args.export_path)
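
Note: the exported static graph can be reloaded for inference without the Python model definition — a minimal sketch, with a hypothetical export path:

import paddle

model = paddle.jit.load('exp/u2_st/export.jit')  # hypothetical --export_path value
model.eval()
# logits = model(audio, audio_len)  # shapes as in the InputSpec list above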

    def run_export(self):
        try:
            self.export()
        except KeyboardInterrupt:
            sys.exit(-1)

    def setup(self):
        """Setup the experiment."""
        paddle.set_device(self.args.device)

        self.setup_output_dir()
        self.setup_checkpointer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    def setup_output_dir(self):
        """Create a directory used for output."""
        # output dir
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(
                self.args.checkpoint_path).expanduser().parent.parent
            output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir
@@ -0,0 +1,666 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile
from collections import namedtuple
from typing import Optional
from typing import Tuple

import kaldiio
import numpy as np
from yacs.config import CfgNode

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log

__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]

logger = Log(__name__).getlog()

# The namedtuple must be module-level (global) so it can be pickled.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


class SpeechCollator():
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'specgram_type' in config.collator
        assert 'n_fft' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            mean_std_filepath=config.collator.mean_std_filepath,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            specgram_type=config.collator.specgram_type,
            feat_dim=config.collator.feat_dim,
            delta_delta=config.collator.delta_delta,
            stride_ms=config.collator.stride_ms,
            window_ms=config.collator.window_ms,
            n_fft=config.collator.n_fft,
            max_freq=config.collator.max_freq,
            target_sample_rate=config.collator.target_sample_rate,
            use_dB_normalization=config.collator.use_dB_normalization,
            target_dB=config.collator.target_dB,
            dither=config.collator.dither,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean/std statistics file path (a *.npy file).
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): fft points for rfft. Defaults to None.
            max_freq (int, optional): max cut freq. Defaults to None.
            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, used by 'mfcc' and 'fbank'. Defaults to None.
            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' and 'mfcc'. Defaults to False.
            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            random_seed (int, optional): seed for the random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (used outside training),
                skip tokenization and keep the raw text; if False, text is converted
                to token ids. Defaults to False.

        Does augmentations.
        Pads audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map of its tarinfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from the tar file
        and cache the tar file info for the next reading request.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])
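
Note: _subfile_from_tar expects manifest audio paths of the form tar:<tar_path>#<member_name>; the split above recovers the archive path and the member name:

path = "tar:/data/ted_en_zh/audio.tar#utt_0001.wav"  # illustrative path
tarpath, filename = path.split(':', 1)[1].split('#', 1)
assert (tarpath, filename) == ("/data/ted_en_zh/audio.tar", "utt_0001.wav")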

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part

    def __call__(self, batch):
        """Batch examples.

        Args:
            batch ([List]): batch is (audio, text)
                audio (np.ndarray) shape (D, T)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            audio, text = self.process_utterance(audio, text)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # otherwise text is a string, converted to unicode ord values
            tokens = []
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
                tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_texts = pad_sequence(
            texts, padding_value=IGNORE_ID).astype(np.int64)
        text_lens = np.array(text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, padded_texts, text_lens
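
Note: plugged into a DataLoader, a batch from this collator unpacks as below — a minimal usage sketch (the dataset and collator instances are assumed):

from paddle.io import DataLoader

loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
for utts, audio, audio_len, text, text_len in loader:
    # audio: float32 [B, Tmax, D]; audio_len: int64 [B]
    # text: int64 [B, Umax] padded with IGNORE_ID (ord() code points when
    # keep_transcription_text=True); text_len: int64 [B]
    break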

    @property
    def manifest(self):
        return self._manifest

    @property
    def vocab_size(self):
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._speech_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._speech_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._speech_featurizer.text_feature

    @property
    def feature_size(self):
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        return self._speech_featurizer.stride_ms


class TripletSpeechCollator(SpeechCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, which could be token ids or text.
        :rtype: tuple of (2darray, list, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        transcript_part = self._speech_featurizer._text_featurizer.featurize(
            transcript)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part, transcript_part

    def __call__(self, batch):
        """Batch examples.

        Args:
            batch ([List]): batch is (audio, text)
                audio (np.ndarray) shape (D, T)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(audio, text, audio_lens, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []

        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # otherwise text is a string, converted to unicode ord values
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                tokens[idx] = tokens[idx] if isinstance(
                    tokens[idx], np.ndarray) else np.array(
                        tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)
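
Note: the triplet batch packs the two targets as pairs, as consumed by the joint ST+ASR branch of U2STTrainer.train_batch; it unpacks as:

utts, audio, audio_len, (st_ids, asr_ids), (st_lens, asr_lens) = batch
# st_ids / asr_ids: int64 [B, Umax] translation / transcription token ids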
|
||||||
|
|
||||||
|
|
||||||
|
class KaldiPrePorocessedCollator(SpeechCollator):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                feat_dim=0,
                stride_ms=10.0,
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'vocab_filepath' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            feat_dim=config.collator.feat_dim,
            stride_ms=config.collator.stride_ms,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(self,
                 aug_file,
                 vocab_filepath,
                 spm_model_prefix,
                 random_seed=0,
                 unit_type="char",
                 feat_dim=0,
                 stride_ms=10.0,
                 keep_transcription_text=True):
        """Collator for Kaldi pre-processed features.

        Args:
            aug_file (file object): augmentation config (JSON) file object.
            vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, needed when `unit_type` is spm.
            random_seed (int, optional): seed for the random generator. Defaults to 0.
            unit_type (str): token unit type, e.g. char, word, spm.
            feat_dim (int): feature dimension of the pre-processed features.
            stride_ms (float): frame stride in milliseconds.
            keep_transcription_text (bool, optional): if True (e.g. outside
                training), no tokenization is done and text stays a raw string;
                if False, text is converted to token ids. Defaults to True.

        Applies augmentations and pads audio features with zeros so they share
        the same shape (or a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text
        self._feat_dim = feat_dim
        self._stride_ms = stride_ms

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
                                               spm_model_prefix)

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        specgram = kaldiio.load_mat(audio_file)
        specgram = specgram.transpose([1, 0])
        assert specgram.shape[
            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[0])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)

        specgram = specgram.transpose([1, 0])
        if self._keep_transcription_text:
            return specgram, translation
        else:
            text_ids = self._text_featurizer.featurize(translation)
            return specgram, text_ids

    @property
    def manifest(self):
        return self._manifest

    @property
    def vocab_size(self):
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._text_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._text_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._text_featurizer

    @property
    def feature_size(self):
        return self._feat_dim

    @property
    def stride_ms(self):
        return self._stride_ms


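For reference, a minimal sketch of driving `from_config` above (the values are illustrative; the `collator` field names mirror the YAML configs added later in this PR):

    from yacs.config import CfgNode

    config = CfgNode()
    config.collator = CfgNode(
        dict(
            augmentation_config="",  # empty string -> '{}' StringIO fallback above
            random_seed=0,
            unit_type="spm",
            vocab_filepath="data/vocab.txt",
            spm_model_prefix="data/bpe_unigram_8000",
            feat_dim=80,
            stride_ms=10.0,
            keep_transcription_text=False))
    collator = KaldiPrePorocessedCollator.from_config(config)
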
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and transcription parts,
                 where translation and transcription parts could be token ids or text.
        :rtype: tuple of (2darray, (list, list))
        """
        specgram = kaldiio.load_mat(audio_file)
        specgram = specgram.transpose([1, 0])
        assert specgram.shape[
            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[0])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)

        specgram = specgram.transpose([1, 0])
        if self._keep_transcription_text:
            return specgram, translation, transcript
        else:
            translation_text_ids = self._text_featurizer.featurize(translation)
            transcript_text_ids = self._text_featurizer.featurize(transcript)
            return specgram, translation_text_ids, transcript_text_ids

    def __call__(self, batch):
        """batch examples

        Args:
            batch (List[Tuple]): each item is (utt, audio, translation, transcription)
                audio (np.ndarray): shape (T, D) after featurization
                translation (List[int] or str): shape (U,)
                transcription (List[int] or str): shape (V,)

        Returns:
            tuple(utts, audio, audio_lens, texts, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                translation_text : (B, Umax)
                translation_text_lens: (B)
                transcription_text : (B, Vmax)
                transcription_text_lens: (B)
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []

        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids
            # else text is string, convert to unicode ord
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                tokens[idx] = tokens[idx] if isinstance(
                    tokens[idx], np.ndarray) else np.array(
                        tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)
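To make the batching contract above concrete, here is a self-contained numpy sketch of the audio padding step (text padding with IGNORE_ID works the same way; all names and shapes below are illustrative):

    import numpy as np

    # toy batch of two utterances: features are [T, D], texts are token-id arrays
    batch = [
        ("utt1", np.zeros((5, 80), np.float32), np.array([3, 7]), np.array([2])),
        ("utt2", np.zeros((9, 80), np.float32), np.array([4]), np.array([5, 6, 8])),
    ]
    Tmax = max(audio.shape[0] for _, audio, _, _ in batch)
    padded_audios = np.stack([
        np.pad(audio, ((0, Tmax - audio.shape[0]), (0, 0)))  # zero-pad time axis
        for _, audio, _, _ in batch
    ])
    print(padded_audios.shape)  # (2, 9, 80), i.e. (B, Tmax, D) as documented above
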
@ -0,0 +1,734 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ST Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys
import time
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode

from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.loss import LabelSmoothingLoss
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add

__all__ = ["U2STModel", "U2STInferModel"]

logger = Log(__name__).getlog()


class U2STBaseModel(nn.Module):
    """CTC-Attention hybrid Encoder-Decoder model"""

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # network architecture
        default = CfgNode()
        # allow add new item when merge_with_file
        default.cmvn_file = ""
        default.cmvn_file_type = "json"
        default.input_dim = 0
        default.output_dim = 0
        # encoder related
        default.encoder = 'transformer'
        default.encoder_conf = CfgNode(
            dict(
                output_size=256,  # dimension of attention
                attention_heads=4,
                linear_units=2048,  # the number of units of position-wise feed forward
                num_blocks=12,  # the number of encoder blocks
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                attention_dropout_rate=0.0,
                input_layer='conv2d',  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
                normalize_before=True,
                # use_cnn_module=True,
                # cnn_module_kernel=15,
                # activation_type='swish',
                # pos_enc_layer_type='rel_pos',
                # selfattention_layer_type='rel_selfattn',
            ))
        # decoder related
        default.decoder = 'transformer'
        default.decoder_conf = CfgNode(
            dict(
                attention_heads=4,
                linear_units=2048,
                num_blocks=6,
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                self_attention_dropout_rate=0.0,
                src_attention_dropout_rate=0.0, ))
        # hybrid CTC/attention
        default.model_conf = CfgNode(
            dict(
                asr_weight=0.0,
                ctc_weight=0.0,
                lsm_weight=0.1,  # label smoothing option
                length_normalized_loss=False, ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 vocab_size: int,
                 encoder: TransformerEncoder,
                 st_decoder: TransformerDecoder,
                 decoder: TransformerDecoder=None,
                 ctc: CTCDecoder=None,
                 ctc_weight: float=0.0,
                 asr_weight: float=0.0,
                 ignore_id: int=IGNORE_ID,
                 lsm_weight: float=0.0,
                 length_normalized_loss: bool=False):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.asr_weight = asr_weight

        self.encoder = encoder
        self.st_decoder = st_decoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss, )

    def forward(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            asr_text: paddle.Tensor=None,
            asr_text_lengths: paddle.Tensor=None,
    ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[
            paddle.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: translation targets, (Batch, Length)
            text_lengths: (Batch,)
            asr_text: optional transcription targets for joint ASR training, (Batch, Length)
            asr_text_lengths: (Batch,)
        Returns:
            total_loss, st_loss, asr_attention_loss, asr_ctc_loss
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        start = time.time()
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_time = time.time() - start
        #logger.debug(f"encoder time: {encoder_time}")
        #TODO(Hui Zhang): sum not support bool type
        #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
        encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
            1)  #[B, 1, T] -> [B]

        # 2a. ST-decoder branch
        start = time.time()
        loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text,
                                             text_lengths)
        decoder_time = time.time() - start

        loss_asr_att = None
        loss_asr_ctc = None
        # 2b. ASR Attention-decoder branch
        if self.asr_weight > 0.:
            if self.ctc_weight != 1.0:
                start = time.time()
                loss_asr_att, acc_att = self._calc_att_loss(
                    encoder_out, encoder_mask, asr_text, asr_text_lengths)
                decoder_time = time.time() - start

            # 2c. CTC branch
            if self.ctc_weight != 0.0:
                start = time.time()
                loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text,
                                        asr_text_lengths)
                ctc_time = time.time() - start

            if loss_asr_ctc is None:
                loss_asr = loss_asr_att
            elif loss_asr_att is None:
                loss_asr = loss_asr_ctc
            else:
                loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight
                                                             ) * loss_asr_att
            loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
        else:
            loss = loss_st
        return loss, loss_st, loss_asr_att, loss_asr_ctc

    def _calc_st_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc ST decoder attention loss.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]

        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
                                         ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id, )
        return loss_att, acc_att

    def _calc_att_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc ASR decoder attention loss.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]

        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
                                      ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id, )
        return loss_att, acc_att

    def _forward_encoder(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Encoder pass.

        Args:
            speech (paddle.Tensor): [B, Tmax, D]
            speech_lengths (paddle.Tensor): [B]
            decoding_chunk_size (int, optional): chunk size. Defaults to -1.
            num_decoding_left_chunks (int, optional): number of left chunks. Defaults to -1.
            simulate_streaming (bool, optional): streaming or not. Defaults to False.

        Returns:
            Tuple[paddle.Tensor, paddle.Tensor]:
                encoder hiddens (B, Tmax, D),
                encoder hiddens mask (B, 1, Tmax).
        """
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask

    def translate(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int=10,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False, ) -> paddle.Tensor:
        """ Apply beam search on attention decoder
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_lengths (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            paddle.Tensor: decoding result, (batch, max_result_len)
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.place
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)
        running_size = batch_size * beam_size
        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
        encoder_mask = encoder_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(running_size, 1,
                                     maxlen)  # (B*N, 1, max_len)

        hyps = paddle.ones(
            [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
        # log scale score
        scores = paddle.to_tensor(
            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
        scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
            device)  # (B*N, 1)
        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
        cache: Optional[List[paddle.Tensor]] = None
        # 2. Decoder forward step by step
        for i in range(1, maxlen + 1):
            # Stop if all batch and all beam produce eos
            # TODO(Hui Zhang): if end_flag.sum() == running_size:
            if end_flag.cast(paddle.int64).sum() == running_size:
                break

            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
            # logp: (B*N, vocab)
            logp, cache = self.st_decoder.forward_one_step(
                encoder_out, encoder_mask, hyps, hyps_mask, cache)

            # 2.2 First beam prune: select topk best prob at current time
            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)

            # 2.3 Second beam prune: select topk score with history
            scores = scores + top_k_logp  # (B*N, N), broadcast add
            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
            scores = scores.view(-1, 1)  # (B*N, 1)

            # 2.4 Compute base index in top_k_index,
            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
            # then find offset_k_index in top_k_index
            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
                1, beam_size)  # (B, N)
            base_k_index = base_k_index * beam_size * beam_size
            best_k_index = base_k_index.view(-1) + offset_k_index.view(
                -1)  # (B*N)

            # 2.5 Update best hyps
            best_k_pred = paddle.index_select(
                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
            best_hyps_index = best_k_index // beam_size
            last_best_k_hyps = paddle.index_select(
                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
            hyps = paddle.cat(
                (last_best_k_hyps, best_k_pred.view(-1, 1)),
                dim=1)  # (B*N, i+1)

            # 2.6 Update end flag
            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)

        # 3. Select best of best
        scores = scores.view(batch_size, beam_size)
        # TODO: length normalization
        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
        best_hyps_index = best_index + paddle.arange(
            batch_size, dtype=paddle.long) * beam_size
        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
        best_hyps = best_hyps[:, 1:]
        return best_hyps

    @jit.export
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    @jit.export
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    @jit.export
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    @jit.export
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos

    @jit.export
    def forward_encoder_chunk(
            self,
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            subsampling_cache: Optional[paddle.Tensor]=None,
            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
            paddle.Tensor]]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.
        Args:
            xs (paddle.Tensor): chunk input
            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
            elayers_output_cache (Optional[List[paddle.Tensor]]):
                transformer/conformer encoder layers output cache
            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
                cnn cache
        Returns:
            paddle.Tensor: output, it ranges from time 0 to current chunk.
            paddle.Tensor: subsampling cache
            List[paddle.Tensor]: attention cache
            List[paddle.Tensor]: conformer cnn cache
        """
        return self.encoder.forward_chunk(
            xs, offset, required_cache_size, subsampling_cache,
            elayers_output_cache, conformer_cnn_cache)

    @jit.export
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc
        Args:
            xs (paddle.Tensor): encoder output
        Returns:
            paddle.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)

    @jit.export
    def forward_attention_decoder(
            self,
            hyps: paddle.Tensor,
            hyps_lens: paddle.Tensor,
            encoder_out: paddle.Tensor, ) -> paddle.Tensor:
        """ Export interface for c++ call, forward decoder with multiple
            hypothesis from ctc prefix beam search and one encoder output
        Args:
            hyps (paddle.Tensor): hyps from ctc prefix beam search, already
                pad sos at the beginning, (B, T)
            hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
        Returns:
            paddle.Tensor: decoder output, (B, L)
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        # (B, 1, T)
        encoder_mask = paddle.ones(
            [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
        # (num_hyps, max_hyps_len, vocab_size)
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                      hyps_lens)
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
        return decoder_out

    @paddle.no_grad()
    def decode(self,
               feats: paddle.Tensor,
               feats_lengths: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
               lang_model_path: str,
               beam_alpha: float,
               beam_beta: float,
               beam_size: int,
               cutoff_prob: float,
               cutoff_top_n: int,
               num_processes: int,
               ctc_weight: float=0.0,
               decoding_chunk_size: int=-1,
               num_decoding_left_chunks: int=-1,
               simulate_streaming: bool=False):
        """u2 decoding.

        Args:
            feats (Tensor): audio features, (B, T, D)
            feats_lengths (Tensor): (B)
            text_feature (TextFeaturizer): text feature object.
            decoding_method (str): decoding mode, e.g.
                'fullsentence',
                'simultaneous'
            lang_model_path (str): lm path.
            beam_alpha (float): lm weight.
            beam_beta (float): length penalty.
            beam_size (int): beam size for search
            cutoff_prob (float): for prune.
            cutoff_top_n (int): for prune.
            num_processes (int): number of processes for beam search.
            ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
            decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here.
            num_decoding_left_chunks (int, optional):
                number of left chunks for decoding. Defaults to -1.
            simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.

        Raises:
            ValueError: when the decoding_method is not supported.

        Returns:
            List[List[int]]: transcripts.
        """
        batch_size = feats.size(0)

        if decoding_method == 'fullsentence':
            hyps = self.translate(
                feats,
                feats_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming)
            hyps = [hyp.tolist() for hyp in hyps]
        else:
            raise ValueError(f"Not support decoding method: {decoding_method}")

        res = [text_feature.defeaturize(hyp) for hyp in hyps]
        return res


class U2STModel(U2STBaseModel):
    def __init__(self, configs: dict):
        vocab_size, encoder, decoder = U2STModel._init_from_config(configs)

        if isinstance(decoder, Tuple):
            st_decoder, asr_decoder, ctc = decoder
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=st_decoder,
                decoder=asr_decoder,
                ctc=ctc,
                **configs['model_conf'])
        else:
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=decoder,
                **configs['model_conf'])

    @classmethod
    def _init_from_config(cls, configs: dict):
        """init sub module for model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raised when the encoder type is not supported.

        Returns:
            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
            global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean, dtype=paddle.float),
                paddle.to_tensor(istd, dtype=paddle.float))
        else:
            global_cmvn = None

        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        assert input_dim != 0, input_dim
        assert vocab_size != 0, vocab_size

        encoder_type = configs.get('encoder', 'transformer')
        logger.info(f"U2 Encoder type: {encoder_type}")
        if encoder_type == 'transformer':
            encoder = TransformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"not support encoder type:{encoder_type}")

        st_decoder = TransformerDecoder(vocab_size,
                                        encoder.output_size(),
                                        **configs['decoder_conf'])

        asr_weight = configs['model_conf']['asr_weight']
        logger.info(f"ASR Joint Training Weight: {asr_weight}")

        if asr_weight > 0.:
            decoder = TransformerDecoder(vocab_size,
                                         encoder.output_size(),
                                         **configs['decoder_conf'])
            ctc = CTCDecoder(
                odim=vocab_size,
                enc_n_units=encoder.output_size(),
                blank_id=0,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True)  # sum / batch_size

            return vocab_size, encoder, (st_decoder, decoder, ctc)
        else:
            return vocab_size, encoder, st_decoder

    @classmethod
    def from_config(cls, configs: dict):
        """init model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raised when the encoder type is not supported.

        Returns:
            nn.Layer: U2STModel
        """
        model = cls(configs)
        return model

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a U2STModel from a pretrained checkpoint.

        Args:
            dataloader (paddle.io.DataLoader): used to read the feature size and vocab size.
            config (yacs.config.CfgNode): model configs
            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name

        Returns:
            U2STModel: The model built from the pretrained result.
        """
        config.defrost()
        config.input_dim = dataloader.collate_fn.feature_size
        config.output_dim = dataloader.collate_fn.vocab_size
        config.freeze()
        model = cls.from_config(config)

        if checkpoint_path:
            infos = checkpoint.load_parameters(
                model, checkpoint_path=checkpoint_path)
            logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model


class U2STInferModel(U2STModel):
    def __init__(self, configs: dict):
        super().__init__(configs)

    def forward(self,
                feats,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False):
        """export model function

        Args:
            feats (Tensor): [B, T, D]
            feats_lengths (Tensor): [B]

        Returns:
            List[List[int]]: best path result
        """
        return self.translate(
            feats,
            feats_lengths,
            decoding_chunk_size=decoding_chunk_size,
            num_decoding_left_chunks=num_decoding_left_chunks,
            simulate_streaming=simulate_streaming)
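The index bookkeeping in step 2.4 of `translate` above is easy to misread, so here is a hedged, self-contained numpy sketch of the same arithmetic (toy sizes, random scores; all names are illustrative):

    import numpy as np

    B, N, V = 2, 3, 5  # batch, beam, vocab (toy sizes)
    rng = np.random.default_rng(0)
    logp = rng.random((B * N, V))  # per-step log-probs of each live hypothesis

    top_k_index = np.argsort(-logp, axis=-1)[:, :N]  # (B*N, N) candidate tokens
    scores = rng.random((B * N, 1)) + np.sort(logp, axis=-1)[:, ::-1][:, :N]
    scores = scores.reshape(B, N * N)  # all N*N expansions per batch item
    offset_k_index = np.argsort(-scores, axis=-1)[:, :N]  # (B, N) kept expansions

    base_k_index = np.arange(B).reshape(-1, 1).repeat(N, axis=1) * N * N
    best_k_index = base_k_index.reshape(-1) + offset_k_index.reshape(-1)  # (B*N,)

    best_k_pred = top_k_index.reshape(-1)[best_k_index]  # next token per kept beam
    best_hyps_index = best_k_index // N  # parent row in the (B*N,) hypothesis list
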
@ -0,0 +1,53 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate BLEU at different levels,
e.g. word-level bleu and char-level bleu.
"""
import numpy as np
import sacrebleu

__all__ = ['bleu', 'char_bleu']


def bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in word-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference length is zero.
    """
    return sacrebleu.corpus_bleu(hypothesis, reference)


def char_bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text in char-level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference number is zero.
    """
    hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
    reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
                 for ref in reference]
    return sacrebleu.corpus_bleu(hypothesis, reference)
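A quick hedged usage example (toy sentences; `sacrebleu.corpus_bleu` expects a list of hypothesis strings and a list of reference streams, so refs[i][j] is the i-th reference for the j-th hypothesis):

    import sacrebleu

    hyps = ['今天 天气 很 好']
    refs = [['今天 的 天气 真 好']]  # a single reference stream
    print(sacrebleu.corpus_bleu(hyps, refs).score)  # corpus BLEU as a float
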
@ -0,0 +1,6 @@
*.tar.gz.*
manifest.*
*.md
EN-ZH/
train-split/
test-segment/
@ -0,0 +1,114 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Ted-En-Zh speech translation dataset

Create manifest files from the split dataset.
dev set: tst2010, test set: tst2015
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os

import soundfile

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--src_dir",
    default="",
    type=str,
    help="Directory of the split Kaldi-style data. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()


def create_manifest(data_dir, manifest_path_prefix):
    print("Creating manifest %s ..." % manifest_path_prefix)
    json_lines = []

    data_types_infos = [
        ('train', 'train-split/train-segment', 'En-Zh/train.en-zh'),
        ('dev', 'test-segment/tst2010', 'En-Zh/tst2010.en-zh'),
        ('test', 'test-segment/tst2015', 'En-Zh/tst2015.en-zh')
    ]
    for data_info in data_types_infos:
        dtype, audio_relative_dir, text_relative_path = data_info
        del json_lines[:]
        total_sec = 0.0
        total_text = 0.0
        total_num = 0

        text_path = os.path.join(data_dir, text_relative_path)
        audio_dir = os.path.join(data_dir, audio_relative_dir)

        for line in codecs.open(text_path, 'r', 'utf-8', errors='ignore'):
            line = line.strip()
            if len(line) < 1:
                continue
            audio_id, transcription, translation = line.split('\t')
            utt = audio_id.split('.')[0]

            audio_path = os.path.join(audio_dir, audio_id)
            if os.path.exists(audio_path):
                # skip files too small to hold valid audio
                if os.path.getsize(audio_path) < 30000:
                    continue
                audio_data, samplerate = soundfile.read(audio_path)
                duration = float(len(audio_data) / samplerate)
                json_lines.append(
                    json.dumps(
                        {
                            'utt': utt,
                            'feat': audio_path,
                            'feat_shape': (duration, ),  # second
                            'text': " ".join(translation.split()),
                            'text1': " ".join(transcription.split())
                        },
                        ensure_ascii=False))

                total_sec += duration
                total_text += len(translation.split())
                total_num += 1
                if not total_num % 1000:
                    print(dtype, 'Processed:', total_num)

        manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
            for line in json_lines:
                fout.write(line + '\n')


def prepare_dataset(src_dir, manifest_path=None):
    """create manifest file."""
    if os.path.isdir(manifest_path):
        manifest_path = os.path.join(manifest_path, 'manifest')
    if manifest_path:
        create_manifest(src_dir, manifest_path)


def main():
    if args.src_dir.startswith('~'):
        args.src_dir = os.path.expanduser(args.src_dir)

    prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)

    print("manifest prepare done!")


if __name__ == '__main__':
    main()
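For illustration, each line of the resulting manifest.*.raw files is a JSON object like the following (the path and values are made up):

    {"utt": "ted_0001_0", "feat": "/data/TED-En-Zh/train-split/train-segment/ted_0001_0.wav", "feat_shape": [3.52], "text": "这 是 一个 例子", "text1": "this is an example"}
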
@ -0,0 +1,109 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train.tiny
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test
  min_input_len: 0.5  # second
  max_input_len: 3000.0  # second
  min_output_len: 0.0  # tokens
  max_output_len: 400.0  # tokens
  min_output_input_ratio: 0.01
  max_output_input_ratio: 20.0

collator:
  vocab_filepath: data/vocab.txt
  unit_type: 'spm'
  spm_model_prefix: data/bpe_unigram_8000
  mean_std_filepath: ""
  # augmentation_config: conf/augmentation.json
  batch_size: 10
  raw_wav: True  # use raw_wav or kaldi feature
  specgram_type: fbank  # linear, mfcc, fbank
  feat_dim: 80
  delta_delta: False
  dither: 1.0
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 25.0
  use_dB_normalization: True
  target_dB: -20
  random_seed: 0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 2


# network architecture
model:
  cmvn_file: "data/mean_std.json"
  cmvn_file_type: "json"
  # encoder related
  encoder: transformer
  encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true

  # decoder related
  decoder: transformer
  decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

  # hybrid CTC/attention
  model_conf:
    asr_weight: 0.0
    ctc_weight: 0.0
    lsm_weight: 0.1  # label smoothing option
    length_normalized_loss: false


training:
  n_epoch: 120
  accum_grad: 2
  global_grad_clip: 5.0
  optim: adam
  optim_conf:
    lr: 0.004
    weight_decay: 1e-06
  scheduler: warmuplr  # pytorch v1.1.0+ required
  scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
  log_interval: 5
  checkpoint:
    kbest_n: 50
    latest_n: 5


decoding:
  batch_size: 5
  error_rate_type: char-bleu
  decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
  alpha: 2.5
  beta: 0.3
  beam_size: 10
  cutoff_prob: 1.0
  cutoff_top_n: 0
  num_proc_bsearch: 8
  ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
  decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
      # <0: for decoding, use full chunk.
      # >0: for decoding, use fixed chunk size as set.
      # 0: used for training, it's prohibited here.
  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
  simulate_streaming: False  # simulate streaming inference. Defaults to False.
@ -0,0 +1,111 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test
  min_input_len: 0.5  # second
  max_input_len: 3000.0  # second
  min_output_len: 0.0  # tokens
  max_output_len: 400.0  # tokens
  min_output_input_ratio: 0.01
  max_output_input_ratio: 20.0

collator:
  vocab_filepath: data/vocab.txt
  unit_type: 'spm'
  spm_model_prefix: data/bpe_unigram_8000
  mean_std_filepath: ""
  # augmentation_config: conf/augmentation.json
  batch_size: 10
  raw_wav: True  # use raw_wav or kaldi feature
  specgram_type: fbank  # linear, mfcc, fbank
  feat_dim: 80
  delta_delta: False
  dither: 1.0
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 25.0
  use_dB_normalization: True
  target_dB: -20
  random_seed: 0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 2


# network architecture
model:
  cmvn_file: "data/mean_std.json"
  cmvn_file_type: "json"
  # encoder related
  encoder: transformer
  encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true

  # decoder related
  decoder: transformer
  decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

  # hybrid CTC/attention
  model_conf:
    asr_weight: 0.5
    ctc_weight: 0.3
    lsm_weight: 0.1  # label smoothing option
    length_normalized_loss: false


training:
  n_epoch: 120
  accum_grad: 2
  global_grad_clip: 5.0
  optim: adam
  optim_conf:
    lr: 2.5
    weight_decay: 1e-06
  scheduler: noam
  scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
  log_interval: 5
  checkpoint:
    kbest_n: 50
    latest_n: 5


decoding:
  batch_size: 5
  error_rate_type: char-bleu
  decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
  alpha: 2.5
  beta: 0.3
  beam_size: 10
  cutoff_prob: 1.0
  cutoff_top_n: 0
  num_proc_bsearch: 8
  ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
  decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
      # <0: for decoding, use full chunk.
      # >0: for decoding, use fixed chunk size as set.
      # 0: used for training, it's prohibited here.
  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
  simulate_streaming: False  # simulate streaming inference. Defaults to False.
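Why `lr: 2.5` here versus `lr: 0.004` in the plain transformer config: with `scheduler: noam` the value acts as a scale on the Noam warmup curve rather than a literal learning rate. A hedged sketch of that schedule (the exact constants used by the repo's scheduler may differ):

    def noam_lr(step, scale=2.5, d_model=256, warmup_steps=25000):
        # rises roughly linearly for warmup_steps, then decays as step ** -0.5
        return scale * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
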
@ -0,0 +1,111 @@
#!/bin/bash

stage=-1
stop_stage=100

# bpemode (unigram or bpe)
nbpe=8000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
DATA_DIR=


source ${MAIN_ROOT}/utils/parse_options.sh


mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

if [ ! -d ${DATA_DIR} ]; then
    echo "Error: Dataset is not available. Please download and unzip the dataset"
    echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0"
    echo "The tree of the directory should be:"
    echo "."
    echo "|-- En-Zh"
    echo "|-- test-segment"
    echo "    |-- tst2010"
    echo "    |-- ..."
    echo "|-- train-split"
    echo "    |-- train-segment"
    echo "|-- README.md"

    exit 1
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # generate manifests
    python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \
    --manifest_prefix="data/manifest" \
    --src_dir="${DATA_DIR}"

    echo "Complete raw data pre-process."
fi


if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # build vocabulary
    python3 ${MAIN_ROOT}/utils/build_vocab.py \
    --unit_type "spm" \
    --spm_vocab_size=${nbpe} \
    --spm_mode ${bpemode} \
    --spm_model_prefix ${bpeprefix} \
    --vocab_path="data/vocab.txt" \
    --text_keys 'text' 'text1' \
    --manifest_paths="data/manifest.train.raw"

    if [ $? -ne 0 ]; then
        echo "Build vocabulary failed. Terminated."
        exit 1
    fi
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # compute mean and stddev for normalizer
    num_workers=$(nproc)
    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
    --manifest_path="data/manifest.train.raw" \
    --num_samples=-1 \
    --specgram_type="fbank" \
    --feat_dim=80 \
    --delta_delta=false \
    --sample_rate=16000 \
    --stride_ms=10.0 \
    --window_ms=25.0 \
    --use_dB_normalization=False \
    --num_workers=${num_workers} \
    --output_path="data/mean_std.json"

    if [ $? -ne 0 ]; then
        echo "Compute mean and stddev failed. Terminated."
        exit 1
    fi
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    for set in train dev test; do
    {
        python3 ${MAIN_ROOT}/utils/format_triplet_data.py \
        --feat_type "raw" \
        --cmvn_path "data/mean_std.json" \
        --unit_type "spm" \
        --spm_model_prefix ${bpeprefix} \
        --vocab_path="data/vocab.txt" \
        --manifest_path="data/manifest.${set}.raw" \
        --output_path="data/manifest.${set}"

        if [ $? -ne 0 ]; then
            echo "Format manifest failed. Terminated."
            exit 1
        fi
    } &
    done
    wait
fi

echo "Ted En-Zh Data preparation done."
exit 0
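Typical invocation from the recipe directory (the dataset path is illustrative; `MAIN_ROOT` comes from sourcing path.sh first, and run.sh below drives the same flags):

    source path.sh
    bash local/data.sh --DATA_DIR /path/to/TED-En-Zh
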
@ -0,0 +1,35 @@
#!/usr/bin/env bash

if [ $# != 2 ]; then
    echo "usage: ${0} config_path ckpt_path_prefix"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

device=gpu
if [ ${ngpu} == 0 ]; then
    device=cpu
fi
config_path=$1
ckpt_prefix=$2

for type in fullsentence; do
    echo "decoding ${type}"
    batch_size=32
    python3 -u ${BIN_DIR}/test.py \
    --device ${device} \
    --nproc 1 \
    --config ${config_path} \
    --result_file ${ckpt_prefix}.${type}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done

exit 0
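Example invocation, following the checkpoint layout used by run.sh below (BIN_DIR must be set by sourcing path.sh first):

CUDA_VISIBLE_DEVICES=0 ./local/test.sh conf/transformer_joint_noam.yaml exp/transformer_joint_noam/checkpoints/avg_5
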
@ -0,0 +1,33 @@
#!/bin/bash

if [ $# != 2 ]; then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
ckpt_name=$2

device=gpu
if [ ${ngpu} == 0 ]; then
    device=cpu
fi
echo "using ${device}..."

mkdir -p exp

python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi

exit 0
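Example invocation with four visible GPUs; --nproc follows the CUDA_VISIBLE_DEVICES count:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh conf/transformer_joint_noam.yaml transformer_joint_noam
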
@ -0,0 +1,14 @@
export MAIN_ROOT=${PWD}/../../

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=u2_st
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
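run.sh sources this file; when running the local/*.sh scripts by hand, source it first so MAIN_ROOT and BIN_DIR are defined:

source path.sh
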
@ -0,0 +1,40 @@
#!/bin/bash
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/transformer_joint_noam.yaml
avg_num=5
data_path=./TED-En-Zh # path to unzipped data
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    bash ./local/data.sh --DATA_DIR ${data_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `exp` dir
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # avg n best model
    ../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
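Because parse_options.sh maps the variables above to command-line flags, individual stages can be (re)run selectively, e.g.:

bash run.sh --stage 2 --stop_stage 3
bash run.sh --stage 0 --stop_stage 0 --data_path /path/to/TED-En-Zh
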
@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Format manifest with more metadata."""
import argparse
import functools
import json

from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str,
        'examples/librispeech/data/mean_std.json',
        "Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str,
        'examples/librispeech/data/vocab.txt',
        "Filepath of the vocabulary.")
add_arg('manifest_paths', str,
        None,
        "Filepaths of manifests for building vocabulary. "
        "You can provide multiple manifest files.",
        nargs='+',
        required=True)
# bpe
add_arg('spm_model_prefix', str, None,
        "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only needed when `unit_type` is spm")
add_arg('output_path', str, None, "Filepath of the formatted manifest.", required=True)
# yapf: enable
args = parser.parse_args()


def main():
    print_arguments(args, globals())
    fout = open(args.output_path, 'w', encoding='utf-8')

    # get feat dim
    mean, std = load_cmvn(args.cmvn_path, filetype='json')
    feat_dim = mean.shape[0]  # (D)
    print(f"Feature dim: {feat_dim}")

    text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
    vocab_size = text_feature.vocab_size
    print(f"Vocab size: {vocab_size}")

    count = 0
    for manifest_path in args.manifest_paths:
        manifest_jsons = read_manifest(manifest_path)
        for line_json in manifest_jsons:
            # text: translation text, text1: transcript text.
            # Currently only a joint vocabulary is supported;
            # separate vocabularies will be added later.
            line = line_json['text']
            tokens = text_feature.tokenize(line)
            tokenids = text_feature.featurize(line)
            line_json['token'] = tokens
            line_json['token_id'] = tokenids
            line_json['token_shape'] = (len(tokenids), vocab_size)
            line = line_json['text1']
            tokens = text_feature.tokenize(line)
            tokenids = text_feature.featurize(line)
            line_json['token1'] = tokens
            line_json['token_id1'] = tokenids
            line_json['token_shape1'] = (len(tokenids), vocab_size)
            feat_shape = line_json['feat_shape']
            assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
            if args.feat_type == 'raw':
                feat_shape.append(feat_dim)
            else:  # kaldi
                raise NotImplementedError('Kaldi features are not supported yet!')
            fout.write(json.dumps(line_json) + '\n')
            count += 1

    print(f"Examples number: {count}")
    fout.close()


if __name__ == '__main__':
    main()
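A hypothetical formatted manifest line (all values illustrative; the token*, token_id*, and token_shape* fields are exactly those written above, with the vocabulary size taken from data/vocab.txt):

{"feat_shape": [10.38, 80], "text": "这是一个例子。", "token": ["▁这是", "一个", "例子", "。"], "token_id": [1052, 320, 771, 12], "token_shape": [4, 8000], "text1": "this is an example", "token1": ["▁this", "▁is", "▁an", "▁example"], "token_id1": [90, 21, 54, 433], "token_shape1": [4, 8000]}

For feat_type "raw", the feature dimension from mean_std.json (80 here) is appended to feat_shape.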