diff --git a/deepspeech/exps/u2_st/__init__.py b/deepspeech/exps/u2_st/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/deepspeech/exps/u2_st/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/exps/u2_st/bin/export.py b/deepspeech/exps/u2_st/bin/export.py new file mode 100644 index 00000000..f566ba5b --- /dev/null +++ b/deepspeech/exps/u2_st/bin/export.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Export for U2 model.""" +from deepspeech.exps.u2_st.config import get_cfg_defaults +from deepspeech.exps.u2_st.model import U2STTester as Tester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_export() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/u2_st/bin/test.py b/deepspeech/exps/u2_st/bin/test.py new file mode 100644 index 00000000..d66c7a26 --- /dev/null +++ b/deepspeech/exps/u2_st/bin/test.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for U2 model.""" +import cProfile + +from deepspeech.exps.u2_st.config import get_cfg_defaults +from deepspeech.exps.u2_st.model import U2STTester as Tester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + +# TODO(hui zhang): dynamic load + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats('test.profile') diff --git a/deepspeech/exps/u2_st/bin/train.py b/deepspeech/exps/u2_st/bin/train.py new file mode 100644 index 00000000..86a0f000 --- /dev/null +++ b/deepspeech/exps/u2_st/bin/train.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer for U2 model.""" +import cProfile +import os + +from paddle import distributed as dist + +from deepspeech.exps.u2_st.config import get_cfg_defaults +from deepspeech.exps.u2_st.model import U2STTrainer as Trainer +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.device == "gpu" and args.nprocs > 1: + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats(os.path.join(args.output, 'train.profile')) diff --git a/deepspeech/exps/u2_st/config.py b/deepspeech/exps/u2_st/config.py new file mode 100644 index 00000000..b1b7b357 --- /dev/null +++ b/deepspeech/exps/u2_st/config.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from yacs.config import CfgNode + +from deepspeech.exps.u2_st.model import U2STTester +from deepspeech.exps.u2_st.model import U2STTrainer +from deepspeech.io.collator_st import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.models.u2_st import U2STModel + +_C = CfgNode() + +_C.data = ManifestDataset.params() + +_C.collator = SpeechCollator.params() + +_C.model = U2STModel.params() + +_C.training = U2STTrainer.params() + +_C.decoding = U2STTester.params() + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + config.set_new_allowed(True) + return config diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py new file mode 100644 index 00000000..867d1899 --- /dev/null +++ b/deepspeech/exps/u2_st/model.py @@ -0,0 +1,680 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains U2 model.""" +import json +import os +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from yacs.config import CfgNode + +from deepspeech.io.collator_st import KaldiPrePorocessedCollator +from deepspeech.io.collator_st import SpeechCollator +from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator +from deepspeech.io.collator_st import TripletSpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.dataset import TripletManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.u2_st import U2STModel +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.scheduler import WarmupLR +from deepspeech.training.trainer import Trainer +from deepspeech.utils import bleu_score +from deepspeech.utils import ctc_utils +from deepspeech.utils import error_rate +from deepspeech.utils import layer_tools +from deepspeech.utils import mp_tools +from deepspeech.utils import text_grid +from deepspeech.utils import utility +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class U2STTrainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + n_epoch=50, # train epochs + log_interval=100, # steps + accum_grad=1, # accum grad by # steps + global_grad_clip=5.0, # the global norm clip + )) + default.optim = 'adam' + default.optim_conf = CfgNode( + dict( + lr=5e-4, # learning rate + weight_decay=1e-6, # the coeff of weight decay + )) + default.scheduler = 'warmuplr' + default.scheduler_conf = CfgNode( + dict( + warmup_steps=25000, + lr_decay=1.0, # learning rate decay + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def train_batch(self, batch_index, batch_data, msg): + train_conf = self.config.training + start = time.time() + utt, audio, audio_len, text, text_len = batch_data + if isinstance(text, list) and isinstance(text_len, list): + # joint training with ASR. Two decoding texts [translation, transcription] + text, text_transcript = text + text_len, text_transcript_len = text_len + loss, st_loss, attention_loss, ctc_loss = self.model( + audio, audio_len, text, text_len, text_transcript, + text_transcript_len) + else: + loss, st_loss, attention_loss, ctc_loss = self.model( + audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` + loss /= train_conf.accum_grad + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + losses_np = {'loss': float(loss) * train_conf.accum_grad} + losses_np['st_loss'] = float(st_loss) + if attention_loss: + losses_np['att_loss'] = float(attention_loss) + if ctc_loss: + losses_np['ctc_loss'] = float(ctc_loss) + + if (batch_index + 1) % train_conf.accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.lr_scheduler.step() + self.iteration += 1 + + iteration_time = time.time() - start + + if (batch_index + 1) % train_conf.log_interval == 0: + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config.collator.batch_size) + msg += "accum: {}, ".format(train_conf.accum_grad) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + logger.info(msg) + + if dist.get_rank() == 0 and self.visualizer: + losses_np_v = losses_np.copy() + losses_np_v.update({"lr": self.lr_scheduler()}) + self.visualizer.add_scalars("step", losses_np_v, + self.iteration - 1) + + @paddle.no_grad() + def valid(self): + self.model.eval() + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): + utt, audio, audio_len, text, text_len = batch + if isinstance(text, list) and isinstance(text_len, list): + text, text_transcript = text + text_len, text_transcript_len = text_len + loss, st_loss, attention_loss, ctc_loss = self.model( + audio, audio_len, text, text_len, text_transcript, + text_transcript_len) + else: + loss, st_loss, attention_loss, ctc_loss = self.model( + audio, audio_len, text, text_len) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(st_loss) * num_utts + valid_losses['val_loss'].append(float(st_loss)) + if attention_loss: + valid_losses['val_att_loss'].append(float(attention_loss)) + if ctc_loss: + valid_losses['val_ctc_loss'].append(float(ctc_loss)) + + if (i + 1) % self.config.training.log_interval == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump['val_history_st_loss'] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_dump.items()) + logger.info(msg) + + logger.info('Rank {} Val info st_val_loss {}'.format( + dist.get_rank(), total_loss / num_seen_utts)) + return total_loss, num_seen_utts + + def train(self): + """The training process control by step.""" + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # script_model = paddle.jit.to_static(self.model) + # script_model_path = str(self.checkpoint_dir / 'init') + # paddle.jit.save(script_model, script_model_path) + + from_scratch = self.resume_or_scratch() + if from_scratch: + # save init model, i.e. 0 epoch + self.save(tag='init') + + self.lr_scheduler.step(self.iteration) + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + while self.epoch < self.config.training.n_epoch: + self.model.train() + try: + data_start_time = time.time() + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts + + logger.info( + 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) + if self.visualizer: + self.visualizer.add_scalars( + 'epoch', {'cv_loss': cv_loss, + 'lr': self.lr_scheduler()}, self.epoch) + self.save(tag=self.epoch, infos={'val_loss': cv_loss}) + self.new_epoch() + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + # train/valid dataset, return token ids + Dataset = TripletManifestDataset if config.model.model_conf.asr_weight > 0. else ManifestDataset + config.data.manifest = config.data.train_manifest + train_dataset = Dataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = Dataset.from_config(config) + + if config.collator.raw_wav: + if config.model.model_conf.asr_weight > 0.: + Collator = TripletSpeechCollator + TestCollator = SpeechCollator + else: + TestCollator = Collator = SpeechCollator + # Not yet implement the mtl loader for raw_wav. + else: + if config.model.model_conf.asr_weight > 0.: + Collator = TripletKaldiPrePorocessedCollator + TestCollator = KaldiPrePorocessedCollator + else: + TestCollator = Collator = KaldiPrePorocessedCollator + + collate_fn_train = Collator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = Collator.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + + # test dataset, return raw text + config.data.manifest = config.data.test_manifest + # filter test examples, will cause less examples, but no mismatch with training + # and can use large batch size , save training time, so filter test egs now. + # config.data.min_input_len = 0.0 # second + # config.data.max_input_len = float('inf') # second + # config.data.min_output_len = 0.0 # tokens + # config.data.max_output_len = float('inf') # tokens + # config.data.min_output_input_ratio = 0.00 + # config.data.max_output_input_ratio = float('inf') + test_dataset = ManifestDataset.from_config(config) + # return text ord id + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=TestCollator.from_config(config)) + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=TestCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") + + def setup_model(self): + config = self.config + model_conf = config.model + model_conf.defrost() + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + model_conf.freeze() + model = U2STModel.from_config(model_conf) + + if self.parallel: + model = paddle.DataParallel(model) + + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + + grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip) + weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay) + + if scheduler_type == 'expdecaylr': + lr_scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=optim_conf.lr, + gamma=scheduler_conf.lr_decay, + verbose=False) + elif scheduler_type == 'warmuplr': + lr_scheduler = WarmupLR( + learning_rate=optim_conf.lr, + warmup_steps=scheduler_conf.warmup_steps, + verbose=False) + elif scheduler_type == 'noam': + lr_scheduler = paddle.optimizer.lr.NoamDecay( + learning_rate=optim_conf.lr, + d_model=model_conf.encoder_conf.output_size, + warmup_steps=scheduler_conf.warmup_steps, + verbose=False) + else: + raise ValueError(f"Not support scheduler: {scheduler_type}") + + if optim_type == 'adam': + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=weight_decay, + grad_clip=grad_clip) + else: + raise ValueError(f"Not support optim: {optim_type}") + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + +class U2STTester(U2STTrainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # decoding config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search', + # 'ctc_prefix_beam_search', 'attention_rescoring' + error_rate_type='bleu', # Error rate type for evaluation. Options `bleu`, 'char_bleu' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=10, # Beam search width. + batch_size=16, # decoding batch size + ctc_weight=0.0, # ctc weight for attention rescoring decode mode. + decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. + simulate_streaming=False, # simulate streaming inference. Defaults to False. + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_translation_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + bleu_func, + fout=None): + cfg = self.config.decoding + len_refs, num_ins = 0, 0 + + start_time = time.time() + text_feature = self.test_loader.collate_fn.text_feature + + refs = [ + "".join(chr(t) for t in text[:text_len]) + for text, text_len in zip(texts, texts_len) + ] + # from IPython import embed + # import os + # embed() + # os._exit(0) + hyps = self.model.decode( + audio, + audio_len, + text_feature=text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + decode_time = time.time() - start_time + + for utt, target, result in zip(utts, refs, hyps): + len_refs += len(target.split()) + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) + logger.info("One example BLEU = %s" % + (bleu_func([result], [[target]]).prec_str)) + + return dict( + hyps=hyps, + refs=refs, + bleu=bleu_func(hyps, [refs]).score, + len_refs=len_refs, + num_ins=num_ins, # num examples + num_frames=audio_len.sum().numpy().item(), + decode_time=decode_time) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + assert self.args.result_file + self.model.eval() + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + + cfg = self.config.decoding + bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu + + stride_ms = self.test_loader.collate_fn.stride_ms + hyps, refs = [], [] + len_refs, num_ins = 0, 0 + num_frames = 0.0 + num_time = 0.0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + metrics = self.compute_translation_metrics( + *batch, bleu_func=bleu_func, fout=fout) + hyps += metrics['hyps'] + refs += metrics['refs'] + bleu = metrics['bleu'] + num_frames += metrics['num_frames'] + num_time += metrics["decode_time"] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + rtf = num_time / (num_frames * stride_ms) + logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu)) + + rtf = num_time / (num_frames * stride_ms) + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "RTF: {}, ".format(rtf) + msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs]))) + logger.info(msg) + bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu' + err_type_str = "BLEU" + with open(bleu_meta_path, 'w') as f: + data = json.dumps({ + "epoch": + self.epoch, + "step": + self.iteration, + "rtf": + rtf, + err_type_str: + bleu_func(hyps, [refs]).score, + "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0, + "process_hour": + num_time / 1000.0 / 3600.0, + "num_examples": + num_ins, + "decode_method": + self.config.decoding.decoding_method, + }) + f.write(data + '\n') + + def run_test(self): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + sys.exit(-1) + + @paddle.no_grad() + def align(self): + if self.config.decoding.batch_size > 1: + logger.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + # xxx.align + assert self.args.result_file and self.args.result_file.endswith( + '.align') + + self.model.eval() + logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") + + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list + with open(self.args.result_file, 'w') as fout: + # one example in batch + for i, batch in enumerate(self.align_loader): + key, feat, feats_length, target, target_length = batch + + # 1. Encoder + encoder_out, encoder_mask = self.model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + + # 2. alignment + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = ctc_utils.forced_align(ctc_probs, target) + logger.info("align ids", key[0], alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + # 3. gen praat + # segment alignment + align_segs = text_grid.segment_alignment(alignment) + logger.info("align tokens", key[0], align_segs) + # IntervalTier, List["start end token\n"] + subsample = utility.get_subsample(self.config) + tierformat = text_grid.align_to_tierformat( + align_segs, subsample, token_dict) + # write tier + align_output_path = os.path.join( + os.path.dirname(self.args.result_file), "align") + tier_path = os.path.join(align_output_path, key[0] + ".tier") + with open(tier_path, 'w') as f: + f.writelines(tierformat) + # write textgrid + textgrid_path = os.path.join(align_output_path, + key[0] + ".TextGrid") + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame + text_grid.generate_textgrid( + maxtime=second_per_example, + intervals=tierformat, + output=textgrid_path) + + def run_align(self): + self.resume_or_scratch() + try: + self.align() + except KeyboardInterrupt: + sys.exit(-1) + + def load_inferspec(self): + """infer model and input spec. + + Returns: + nn.Layer: inference model + List[paddle.static.InputSpec]: input spec. + """ + from deepspeech.models.u2 import U2InferModel + infer_model = U2InferModel.from_pretrained(self.test_loader, + self.config.model.clone(), + self.args.checkpoint_path) + feat_dim = self.test_loader.collate_fn.feature_size + input_spec = [ + paddle.static.InputSpec(shape=[1, None, feat_dim], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[1], + dtype='int64'), # audio_length, [B] + ] + return infer_model, input_spec + + def export(self): + infer_model, input_spec = self.load_inferspec() + assert isinstance(input_spec, list), type(input_spec) + infer_model.eval() + static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) + logger.info(f"Export code: {static_model.forward.code}") + paddle.jit.save(static_model, self.args.export_path) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + sys.exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device(self.args.device) + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir diff --git a/deepspeech/io/collator_st.py b/deepspeech/io/collator_st.py new file mode 100644 index 00000000..34933312 --- /dev/null +++ b/deepspeech/io/collator_st.py @@ -0,0 +1,666 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +from collections import namedtuple +from typing import Optional +from typing import Tuple + +import kaldiio +import numpy as np +from yacs.config import CfgNode + +from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline +from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer +from deepspeech.frontend.normalizer import FeatureNormalizer +from deepspeech.frontend.speech import SpeechSegment +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.io.utility import pad_sequence +from deepspeech.utils.log import Log + +__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"] + +logger = Log(__name__).getlog() + +# namedtupe need global for pickle. +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +class SpeechCollator(): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + augmentation_config="", + random_seed=0, + mean_std_filepath="", + unit_type="char", + vocab_filepath="", + spm_model_prefix="", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, # feature dither + keep_transcription_text=False)) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + @classmethod + def from_config(cls, config): + """Build a SpeechCollator object from a config. + + Args: + config (yacs.config.CfgNode): configs object. + + Returns: + SpeechCollator: collator object. + """ + assert 'augmentation_config' in config.collator + assert 'keep_transcription_text' in config.collator + assert 'mean_std_filepath' in config.collator + assert 'vocab_filepath' in config.collator + assert 'specgram_type' in config.collator + assert 'n_fft' in config.collator + assert config.collator + + if isinstance(config.collator.augmentation_config, (str, bytes)): + if config.collator.augmentation_config: + aug_file = io.open( + config.collator.augmentation_config, + mode='r', + encoding='utf8') + else: + aug_file = io.StringIO(initial_value='{}', newline='') + else: + aug_file = config.collator.augmentation_config + assert isinstance(aug_file, io.StringIO) + + speech_collator = cls( + aug_file=aug_file, + random_seed=0, + mean_std_filepath=config.collator.mean_std_filepath, + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + specgram_type=config.collator.specgram_type, + feat_dim=config.collator.feat_dim, + delta_delta=config.collator.delta_delta, + stride_ms=config.collator.stride_ms, + window_ms=config.collator.window_ms, + n_fft=config.collator.n_fft, + max_freq=config.collator.max_freq, + target_sample_rate=config.collator.target_sample_rate, + use_dB_normalization=config.collator.use_dB_normalization, + target_dB=config.collator.target_dB, + dither=config.collator.dither, + keep_transcription_text=config.collator.keep_transcription_text) + return speech_collator + + def __init__( + self, + aug_file, + mean_std_filepath, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, + keep_transcription_text=True): + """SpeechCollator Collator + + Args: + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. + mean_std_filepath (str): mean and std file path, which suffix is *.npy + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. + augmentation_config (str, optional): augmentation json str. Defaults to '{}'. + stride_ms (float, optional): stride size in ms. Defaults to 10.0. + window_ms (float, optional): window size in ms. Defaults to 20.0. + n_fft (int, optional): fft points for rfft. Defaults to None. + max_freq (int, optional): max cut freq. Defaults to None. + target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. + specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. + feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. + delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. + use_dB_normalization (bool, optional): do dB normalization. Defaults to True. + target_dB (int, optional): target dB. Defaults to -20. + random_seed (int, optional): for random generator. Defaults to 0. + keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + if ``keep_transcription_text`` is False, text is token ids else is raw string. + + Do augmentations + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. + """ + self._keep_transcription_text = keep_transcription_text + + self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=aug_file.read(), random_seed=random_seed) + + self._normalizer = FeatureNormalizer( + mean_std_filepath) if mean_std_filepath else None + + self._stride_ms = stride_ms + self._target_sample_rate = target_sample_rate + + self._speech_featurizer = SpeechFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, + specgram_type=specgram_type, + feat_dim=feat_dim, + delta_delta=delta_delta, + stride_ms=stride_ms, + window_ms=window_ms, + n_fft=n_fft, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB, + dither=dither) + + def _parse_tar(self, file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + def _subfile_from_tar(self, file): + """Get subfile object from tar. + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + if 'tar2info' not in self._local_data.__dict__: + self._local_data.tar2info = {} + if 'tar2object' not in self._local_data.__dict__: + self._local_data.tar2object = {} + if tarpath not in self._local_data.tar2info: + object, infoes = self._parse_tar(tarpath) + self._local_data.tar2info[tarpath] = infoes + self._local_data.tar2object[tarpath] = object + return self._local_data.tar2object[tarpath].extractfile( + self._local_data.tar2info[tarpath][filename]) + + def process_utterance(self, audio_file, translation): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param translation: translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), translation) + else: + speech_segment = SpeechSegment.from_file(audio_file, translation) + + # audio augment + self._augmentation_pipeline.transform_audio(speech_segment) + + specgram, translation_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + if self._normalizer: + specgram = self._normalizer.apply(specgram) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + specgram = specgram.transpose([1, 0]) + return specgram, translation_part + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + text (List[int] or str): shape (U,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + text : (B, Umax) + text_lens: (B) + """ + audios = [] + audio_lens = [] + texts = [] + text_lens = [] + utts = [] + for utt, audio, text in batch: + audio, text = self.process_utterance(audio, text) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [] + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens = [ord(t) for t in text] + else: + tokens = text # token ids + tokens = tokens if isinstance(tokens, np.ndarray) else np.array( + tokens, dtype=np.int64) + texts.append(tokens) + text_lens.append(tokens.shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_texts = pad_sequence( + texts, padding_value=IGNORE_ID).astype(np.int64) + text_lens = np.array(text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, padded_texts, text_lens + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + return self._speech_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._speech_featurizer.vocab_dict + + @property + def text_feature(self): + return self._speech_featurizer.text_feature + + @property + def feature_size(self): + return self._speech_featurizer.feature_size + + @property + def stride_ms(self): + return self._speech_featurizer.stride_ms + + +class TripletSpeechCollator(SpeechCollator): + def process_utterance(self, audio_file, translation, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param translation: translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), translation) + else: + speech_segment = SpeechSegment.from_file(audio_file, translation) + + # audio augment + self._augmentation_pipeline.transform_audio(speech_segment) + + specgram, translation_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + transcript_part = self._speech_featurizer._text_featurizer.featurize( + transcript) + if self._normalizer: + specgram = self._normalizer.apply(specgram) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + specgram = specgram.transpose([1, 0]) + return specgram, translation_part, transcript_part + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + text (List[int] or str): shape (U,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + text : (B, Umax) + text_lens: (B) + """ + audios = [] + audio_lens = [] + translation_text = [] + translation_text_lens = [] + transcription_text = [] + transcription_text_lens = [] + + utts = [] + for utt, audio, translation, transcription in batch: + audio, translation, transcription = self.process_utterance( + audio, translation, transcription) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [[], []] + for idx, text in enumerate([translation, transcription]): + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens[idx] = [ord(t) for t in text] + else: + tokens[idx] = text # token ids + tokens[idx] = tokens[idx] if isinstance( + tokens[idx], np.ndarray) else np.array( + tokens[idx], dtype=np.int64) + translation_text.append(tokens[0]) + translation_text_lens.append(tokens[0].shape[0]) + transcription_text.append(tokens[1]) + transcription_text_lens.append(tokens[1].shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_translation = pad_sequence( + translation_text, padding_value=IGNORE_ID).astype(np.int64) + translation_lens = np.array(translation_text_lens).astype(np.int64) + padded_transcription = pad_sequence( + transcription_text, padding_value=IGNORE_ID).astype(np.int64) + transcription_lens = np.array(transcription_text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, ( + padded_translation, padded_transcription), (translation_lens, + transcription_lens) + + +class KaldiPrePorocessedCollator(SpeechCollator): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + augmentation_config="", + random_seed=0, + unit_type="char", + vocab_filepath="", + spm_model_prefix="", + feat_dim=0, + stride_ms=10.0, + keep_transcription_text=False)) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + @classmethod + def from_config(cls, config): + """Build a SpeechCollator object from a config. + + Args: + config (yacs.config.CfgNode): configs object. + + Returns: + SpeechCollator: collator object. + """ + assert 'augmentation_config' in config.collator + assert 'keep_transcription_text' in config.collator + assert 'vocab_filepath' in config.collator + assert config.collator + + if isinstance(config.collator.augmentation_config, (str, bytes)): + if config.collator.augmentation_config: + aug_file = io.open( + config.collator.augmentation_config, + mode='r', + encoding='utf8') + else: + aug_file = io.StringIO(initial_value='{}', newline='') + else: + aug_file = config.collator.augmentation_config + assert isinstance(aug_file, io.StringIO) + + speech_collator = cls( + aug_file=aug_file, + random_seed=0, + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + feat_dim=config.collator.feat_dim, + stride_ms=config.collator.stride_ms, + keep_transcription_text=config.collator.keep_transcription_text) + return speech_collator + + def __init__(self, + aug_file, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + feat_dim=0, + stride_ms=10.0, + keep_transcription_text=True): + """SpeechCollator Collator + + Args: + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. + augmentation_config (str, optional): augmentation json str. Defaults to '{}'. + random_seed (int, optional): for random generator. Defaults to 0. + keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + if ``keep_transcription_text`` is False, text is token ids else is raw string. + + Do augmentations + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. + """ + self._keep_transcription_text = keep_transcription_text + self._feat_dim = feat_dim + self._stride_ms = stride_ms + + self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=aug_file.read(), random_seed=random_seed) + + self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, + spm_model_prefix) + + def process_utterance(self, audio_file, translation): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of kaldi processed feature. + :type audio_file: str | file + :param translation: Translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + specgram = kaldiio.load_mat(audio_file) + specgram = specgram.transpose([1, 0]) + assert specgram.shape[ + 0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[0]) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + + specgram = specgram.transpose([1, 0]) + if self._keep_transcription_text: + return specgram, translation + else: + text_ids = self._text_featurizer.featurize(translation) + return specgram, text_ids + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._text_featurizer.vocab_size + + @property + def vocab_list(self): + return self._text_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._text_featurizer.vocab_dict + + @property + def text_feature(self): + return self._text_featurizer + + @property + def feature_size(self): + return self._feat_dim + + @property + def stride_ms(self): + return self._stride_ms + + +class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): + def process_utterance(self, audio_file, translation, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of kali processed feature. + :type audio_file: str | file + :param translation: Translation text. + :type translation: str + :param transcript: Transcription text. + :type transcript: str + :return: Tuple of audio feature tensor and data of translation and transcription parts, + where translation and transcription parts could be token ids or text. + :rtype: tuple of (2darray, (list, list)) + """ + specgram = kaldiio.load_mat(audio_file) + specgram = specgram.transpose([1, 0]) + assert specgram.shape[ + 0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[0]) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + + specgram = specgram.transpose([1, 0]) + if self._keep_transcription_text: + return specgram, translation, transcript + else: + translation_text_ids = self._text_featurizer.featurize(translation) + transcript_text_ids = self._text_featurizer.featurize(transcript) + return specgram, translation_text_ids, transcript_text_ids + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + translation (List[int] or str): shape (U,) + transcription (List[int] or str): shape (V,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + translation_text : (B, Umax) + translation_text_lens: (B) + transcription_text : (B, Vmax) + transcription_text_lens: (B) + """ + audios = [] + audio_lens = [] + translation_text = [] + translation_text_lens = [] + transcription_text = [] + transcription_text_lens = [] + + utts = [] + for utt, audio, translation, transcription in batch: + audio, translation, transcription = self.process_utterance( + audio, translation, transcription) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [[], []] + for idx, text in enumerate([translation, transcription]): + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens[idx] = [ord(t) for t in text] + else: + tokens[idx] = text # token ids + tokens[idx] = tokens[idx] if isinstance( + tokens[idx], np.ndarray) else np.array( + tokens[idx], dtype=np.int64) + translation_text.append(tokens[0]) + translation_text_lens.append(tokens[0].shape[0]) + transcription_text.append(tokens[1]) + transcription_text_lens.append(tokens[1].shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_translation = pad_sequence( + translation_text, padding_value=IGNORE_ID).astype(np.int64) + translation_lens = np.array(translation_text_lens).astype(np.int64) + padded_transcription = pad_sequence( + transcription_text, padding_value=IGNORE_ID).astype(np.int64) + transcription_lens = np.array(transcription_text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, ( + padded_translation, padded_transcription), (translation_lens, + transcription_lens) diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 3fc4e988..ac7be1f9 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -19,9 +19,7 @@ from yacs.config import CfgNode from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log -__all__ = [ - "ManifestDataset", -] +__all__ = ["ManifestDataset", "TripletManifestDataset"] logger = Log(__name__).getlog() @@ -105,3 +103,16 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] return instance["utt"], instance["feat"], instance["text"] + + +class TripletManifestDataset(ManifestDataset): + """ + For Joint Training of Speech Translation and ASR. + text: translation, + text1: transcript. + """ + + def __getitem__(self, idx): + instance = self._manifest[idx] + return instance["utt"], instance["feat"], instance["text"], instance[ + "text1"] diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py new file mode 100644 index 00000000..5eea139b --- /dev/null +++ b/deepspeech/models/u2_st.py @@ -0,0 +1,734 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""U2 ASR Model +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +(https://arxiv.org/pdf/2012.05481.pdf) +""" +import sys +import time +from collections import defaultdict +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import paddle +from paddle import jit +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.frontend.utility import load_cmvn +from deepspeech.modules.cmvn import GlobalCMVN +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.modules.decoder import TransformerDecoder +from deepspeech.modules.encoder import ConformerEncoder +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.loss import LabelSmoothingLoss +from deepspeech.modules.mask import make_pad_mask +from deepspeech.modules.mask import mask_finished_preds +from deepspeech.modules.mask import mask_finished_scores +from deepspeech.modules.mask import subsequent_mask +from deepspeech.utils import checkpoint +from deepspeech.utils import layer_tools +from deepspeech.utils.ctc_utils import remove_duplicates_and_blank +from deepspeech.utils.log import Log +from deepspeech.utils.tensor_utils import add_sos_eos +from deepspeech.utils.tensor_utils import pad_sequence +from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import log_add + +__all__ = ["U2STModel", "U2STInferModel"] + +logger = Log(__name__).getlog() + + +class U2STBaseModel(nn.Module): + """CTC-Attention hybrid Encoder-Decoder model""" + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # network architecture + default = CfgNode() + # allow add new item when merge_with_file + default.cmvn_file = "" + default.cmvn_file_type = "json" + default.input_dim = 0 + default.output_dim = 0 + # encoder related + default.encoder = 'transformer' + default.encoder_conf = CfgNode( + dict( + output_size=256, # dimension of attention + attention_heads=4, + linear_units=2048, # the number of units of position-wise feed forward + num_blocks=12, # the number of encoder blocks + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before=True, + # use_cnn_module=True, + # cnn_module_kernel=15, + # activation_type='swish', + # pos_enc_layer_type='rel_pos', + # selfattention_layer_type='rel_selfattn', + )) + # decoder related + default.decoder = 'transformer' + default.decoder_conf = CfgNode( + dict( + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + self_attention_dropout_rate=0.0, + src_attention_dropout_rate=0.0, )) + # hybrid CTC/attention + default.model_conf = CfgNode( + dict( + asr_weight=0.0, + ctc_weight=0.0, + lsm_weight=0.1, # label smoothing option + length_normalized_loss=False, )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + vocab_size: int, + encoder: TransformerEncoder, + st_decoder: TransformerDecoder, + decoder: TransformerDecoder=None, + ctc: CTCDecoder=None, + ctc_weight: float=0.0, + asr_weight: float=0.0, + ignore_id: int=IGNORE_ID, + lsm_weight: float=0.0, + length_normalized_loss: bool=False): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.asr_weight = asr_weight + + self.encoder = encoder + self.st_decoder = st_decoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, ) + + def forward( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + asr_text: paddle.Tensor=None, + asr_text_lengths: paddle.Tensor=None, + ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + Returns: + total_loss, attention_loss, ctc_loss + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + start = time.time() + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_time = time.time() - start + #logger.debug(f"encoder time: {encoder_time}") + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] + + # 2a. ST-decoder branch + start = time.time() + loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text, + text_lengths) + decoder_time = time.time() - start + + loss_asr_att = None + loss_asr_ctc = None + # 2b. ASR Attention-decoder branch + if self.asr_weight > 0.: + if self.ctc_weight != 1.0: + start = time.time() + loss_asr_att, acc_att = self._calc_att_loss( + encoder_out, encoder_mask, asr_text, asr_text_lengths) + decoder_time = time.time() - start + + # 2c. CTC branch + if self.ctc_weight != 0.0: + start = time.time() + loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text, + asr_text_lengths) + ctc_time = time.time() - start + + if loss_asr_ctc is None: + loss_asr = loss_asr_att + elif loss_asr_att is None: + loss_asr = loss_asr_ctc + else: + loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight + ) * loss_asr_att + loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st + else: + loss = loss_st + return loss, loss_st, loss_asr_att, loss_asr_ctc + + def _calc_st_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _calc_att_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Encoder pass. + + Args: + speech (paddle.Tensor): [B, Tmax, D] + speech_lengths (paddle.Tensor): [B] + decoding_chunk_size (int, optional): chuck size. Defaults to -1. + num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1. + simulate_streaming (bool, optional): streaming or not. Defaults to False. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + encoder hiddens (B, Tmax, D), + encoder hiddens mask (B, 1, Tmax). + """ + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def translate( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int=10, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> paddle.Tensor: + """ Apply beam search on attention decoder + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + paddle.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.place + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = paddle.ones( + [running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1) + # log scale score + scores = paddle.to_tensor( + [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float) + scores = scores.to(device).repeat(batch_size).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) + cache: Optional[List[paddle.Tensor]] = None + # 2. Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: + break + + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.st_decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + + # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = paddle.arange(batch_size).view(-1, 1).repeat( + 1, beam_size) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = paddle.index_select( + top_k_index.view(-1), index=best_k_index, axis=0) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = paddle.index_select( + hyps, index=best_hyps_index, axis=0) # (B*N, i) + hyps = paddle.cat( + (last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_index = paddle.argmax(scores, axis=-1).long() # (B) + best_hyps_index = best_index + paddle.arange( + batch_size, dtype=paddle.long) * beam_size + best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0) + best_hyps = best_hyps[:, 1:] + return best_hyps + + @jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @jit.export + def forward_encoder_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[paddle.Tensor]=None, + elayers_output_cache: Optional[List[paddle.Tensor]]=None, + conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ + paddle.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + Args: + xs (paddle.Tensor): chunk input + subsampling_cache (Optional[paddle.Tensor]): subsampling cache + elayers_output_cache (Optional[List[paddle.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer + cnn cache + Returns: + paddle.Tensor: output, it ranges from time 0 to current chunk. + paddle.Tensor: subsampling cache + List[paddle.Tensor]: attention cache + List[paddle.Tensor]: conformer cnn cache + """ + return self.encoder.forward_chunk( + xs, offset, required_cache_size, subsampling_cache, + elayers_output_cache, conformer_cnn_cache) + + @jit.export + def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (paddle.Tensor): encoder output + Returns: + paddle.Tensor: activation before ctc + """ + return self.ctc.log_softmax(xs) + + @jit.export + def forward_attention_decoder( + self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, ) -> paddle.Tensor: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (paddle.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining, (B, T) + hyps_lens (paddle.Tensor): length of each hyp in hyps, (B) + encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D) + Returns: + paddle.Tensor: decoder output, (B, L) + """ + assert encoder_out.size(0) == 1 + num_hyps = hyps.size(0) + assert hyps_lens.size(0) == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + # (B, 1, T) + encoder_mask = paddle.ones( + [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + # (num_hyps, max_hyps_len, vocab_size) + decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, + hyps_lens) + decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1) + return decoder_out + + @paddle.no_grad() + def decode(self, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + text_feature: Dict[str, int], + decoding_method: str, + lang_model_path: str, + beam_alpha: float, + beam_beta: float, + beam_size: int, + cutoff_prob: float, + cutoff_top_n: int, + num_processes: int, + ctc_weight: float=0.0, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False): + """u2 decoding. + + Args: + feats (Tenosr): audio features, (B, T, D) + feats_lengths (Tenosr): (B) + text_feature (TextFeaturizer): text feature object. + decoding_method (str): decoding mode, e.g. + 'fullsentence', + 'simultaneous' + lang_model_path (str): lm path. + beam_alpha (float): lm weight. + beam_beta (float): length penalty. + beam_size (int): beam size for search + cutoff_prob (float): for prune. + cutoff_top_n (int): for prune. + num_processes (int): + ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. + decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here. + num_decoding_left_chunks (int, optional): + number of left chunks for decoding. Defaults to -1. + simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. + + Raises: + ValueError: when not support decoding_method. + + Returns: + List[List[int]]: transcripts. + """ + batch_size = feats.size(0) + + if decoding_method == 'fullsentence': + hyps = self.translate( + feats, + feats_lengths, + beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + else: + raise ValueError(f"Not support decoding method: {decoding_method}") + + res = [text_feature.defeaturize(hyp) for hyp in hyps] + return res + + +class U2STModel(U2STBaseModel): + def __init__(self, configs: dict): + vocab_size, encoder, decoder = U2STModel._init_from_config(configs) + + if isinstance(decoder, Tuple): + st_decoder, asr_decoder, ctc = decoder + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=st_decoder, + decoder=asr_decoder, + ctc=ctc, + **configs['model_conf']) + else: + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=decoder, + **configs['model_conf']) + + @classmethod + def _init_from_config(cls, configs: dict): + """init sub module for model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc + """ + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], + configs['cmvn_file_type']) + global_cmvn = GlobalCMVN( + paddle.to_tensor(mean, dtype=paddle.float), + paddle.to_tensor(istd, dtype=paddle.float)) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + assert input_dim != 0, input_dim + assert vocab_size != 0, vocab_size + + encoder_type = configs.get('encoder', 'transformer') + logger.info(f"U2 Encoder type: {encoder_type}") + if encoder_type == 'transformer': + encoder = TransformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'conformer': + encoder = ConformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + else: + raise ValueError(f"not support encoder type:{encoder_type}") + + st_decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + + asr_weight = configs['model_conf']['asr_weight'] + logger.info(f"ASR Joint Training Weight: {asr_weight}") + + if asr_weight > 0.: + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + + return vocab_size, encoder, (st_decoder, decoder, ctc) + else: + return vocab_size, encoder, st_decoder + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: U2STModel + """ + model = cls(configs) + return model + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + + Args: + dataloader (paddle.io.DataLoader): not used. + config (yacs.config.CfgNode): model configs + checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name + + Returns: + DeepSpeech2Model: The model built from pretrained result. + """ + config.defrost() + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + config.freeze() + model = cls.from_config(config) + + if checkpoint_path: + infos = checkpoint.load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + +class U2STInferModel(U2STModel): + def __init__(self, configs: dict): + super().__init__(configs) + + def forward(self, + feats, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False): + """export model function + + Args: + feats (Tensor): [B, T, D] + feats_lengths (Tensor): [B] + + Returns: + List[List[int]]: best path result + """ + return self.translate( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py new file mode 100644 index 00000000..580fbf61 --- /dev/null +++ b/deepspeech/utils/bleu_score.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides functions to calculate bleu score in different level. +e.g. wer for word-level, cer for char-level. +""" +import numpy as np +import sacrebleu + +__all__ = ['bleu', 'char_bleu'] + + +def bleu(hypothesis, reference): + """Calculate BLEU. BLEU compares reference text and + hypothesis text in word-level using scarebleu. + + + + :param reference: The reference sentences. + :type reference: list[list[str]] + :param hypothesis: The hypothesis sentence. + :type hypothesis: list[str] + :raises ValueError: If the reference length is zero. + """ + + return sacrebleu.corpus_bleu(hypothesis, reference) + +def char_bleu(hypothesis, reference): + """Calculate BLEU. BLEU compares reference text and + hypothesis text in char-level using scarebleu. + + + + :param reference: The reference sentences. + :type reference: list[list[str]] + :param hypothesis: The hypothesis sentence. + :type hypothesis: list[str] + :raises ValueError: If the reference number is zero. + """ + hypothesis =[' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis] + reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref ]for ref in reference ] + + return sacrebleu.corpus_bleu(hypothesis, reference) \ No newline at end of file diff --git a/examples/dataset/ted_en_zh/.gitignore b/examples/dataset/ted_en_zh/.gitignore new file mode 100644 index 00000000..ad6ab64a --- /dev/null +++ b/examples/dataset/ted_en_zh/.gitignore @@ -0,0 +1,6 @@ +*.tar.gz.* +manifest.* +*.md +EN-ZH/ +train-split/ +test-segment/ \ No newline at end of file diff --git a/examples/dataset/ted_en_zh/ted_en_zh.py b/examples/dataset/ted_en_zh/ted_en_zh.py new file mode 100644 index 00000000..08f15119 --- /dev/null +++ b/examples/dataset/ted_en_zh/ted_en_zh.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare Ted-En-Zh speech translation dataset + +Create manifest files from splited datased. +dev set: tst2010, test set: tst2015 +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os + +import soundfile + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--src_dir", + default="", + type=str, + help="Directory to kaldi splited data. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + + data_types_infos = [('train', 'train-split/train-segment', 'En-Zh/train.en-zh'), + ('dev', 'test-segment/tst2010', 'En-Zh/tst2010.en-zh'), + ('test', 'test-segment/tst2015', 'En-Zh/tst2015.en-zh')] + for data_info in data_types_infos: + dtype, audio_relative_dir, text_relative_path = data_info + del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + + text_path = os.path.join(data_dir, text_relative_path) + audio_dir = os.path.join(data_dir, audio_relative_dir) + + for line in codecs.open(text_path, 'r', 'utf-8', errors='ignore'): + line = line.strip() + if len(line) < 1: + continue + audio_id, trancription, translation = line.split('\t') + utt = audio_id.split('.')[0] + + audio_path = os.path.join(audio_dir, audio_id) + if os.path.exists(audio_path): + if os.path.getsize(audio_path) < 30000: + continue + audio_data, samplerate = soundfile.read(audio_path) + duration = float(len(audio_data) / samplerate) + json_lines.append( + json.dumps( + { + 'utt': utt, + 'feat': audio_path, + 'feat_shape': (duration, ), # second + 'text': " ".join(translation.split()), + 'text1': " ".join(trancription.split()) + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(translation.split()) + total_num += 1 + if not total_num % 1000: + print(dtype, 'Processed:', total_num) + + manifest_path = manifest_path_prefix + '.' + dtype + '.raw' + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + +def prepare_dataset(src_dir, manifest_path=None): + """create manifest file.""" + if os.path.isdir(manifest_path): + manifest_path = os.path.join(manifest_path, 'manifest') + if manifest_path: + create_manifest(src_dir, manifest_path) + + +def main(): + if args.src_dir.startswith('~'): + args.src_dir = os.path.expanduser(args.src_dir) + + prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix) + + print("manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/ted_en_zh/conf/transformer.yaml b/examples/ted_en_zh/conf/transformer.yaml new file mode 100644 index 00000000..10a3e7f5 --- /dev/null +++ b/examples/ted_en_zh/conf/transformer.yaml @@ -0,0 +1,109 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train.tiny + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.5 # second + max_input_len: 3000.0 # second + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.01 + max_output_input_ratio: 20.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: data/bpe_unigram_8000 + mean_std_filepath: "" + # augmentation_config: conf/augmentation.json + batch_size: 10 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + +# network architecture +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + asr_weight: 0.0 + ctc_weight: 0.0 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +training: + n_epoch: 120 + accum_grad: 2 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 0.004 + weight_decay: 1e-06 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 5 + checkpoint: + kbest_n: 50 + latest_n: 5 + + +decoding: + batch_size: 5 + error_rate_type: char-bleu + decoding_method: fullsentence # 'fullsentence', 'simultaneous' + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. diff --git a/examples/ted_en_zh/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/conf/transformer_joint_noam.yaml new file mode 100644 index 00000000..ba384f8c --- /dev/null +++ b/examples/ted_en_zh/conf/transformer_joint_noam.yaml @@ -0,0 +1,111 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.5 # second + max_input_len: 3000.0 # second + min_output_len: 0.0 # tokens + max_output_len: 400.0 # tokens + min_output_input_ratio: 0.01 + max_output_input_ratio: 20.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: data/bpe_unigram_8000 + mean_std_filepath: "" + # augmentation_config: conf/augmentation.json + batch_size: 10 + raw_wav: True # use raw_wav or kaldi feature + specgram_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + + +# network architecture +model: + cmvn_file: "data/mean_std.json" + cmvn_file_type: "json" + # encoder related + encoder: transformer + encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + asr_weight: 0.5 + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +training: + n_epoch: 120 + accum_grad: 2 + global_grad_clip: 5.0 + optim: adam + optim_conf: + lr: 2.5 + weight_decay: 1e-06 + scheduler: noam + scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 + log_interval: 5 + checkpoint: + kbest_n: 50 + latest_n: 5 + + +decoding: + batch_size: 5 + error_rate_type: char-bleu + decoding_method: fullsentence # 'fullsentence', 'simultaneous' + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. + + diff --git a/examples/ted_en_zh/local/data.sh b/examples/ted_en_zh/local/data.sh new file mode 100755 index 00000000..0a5c58aa --- /dev/null +++ b/examples/ted_en_zh/local/data.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +# bpemode (unigram or bpe) +nbpe=8000 +bpemode=unigram +bpeprefix="data/bpe_${bpemode}_${nbpe}" +DATA_DIR= + + +source ${MAIN_ROOT}/utils/parse_options.sh + + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + +if [ ! -d ${SOURCE_DIR} ]; then + echo "Error: Dataset is not avaiable. Please download and unzip the dataset" + echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0" + echo "The tree of the directory should be:" + echo "." + echo "|-- En-Zh" + echo "|-- test-segment" + echo " |-- tst2010" + echo " |-- ..." + echo "|-- train-split" + echo " |-- train-segment" + echo "|-- README.md" + + exit 1 +fi + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # generate manifests + python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \ + --manifest_prefix="data/manifest" \ + --src_dir="${DATA_DIR}" + + echo "Complete raw data pre-process." +fi + + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # build vocabulary + python3 ${MAIN_ROOT}/utils/build_vocab.py \ + --unit_type "spm" \ + --spm_vocab_size=${nbpe} \ + --spm_mode ${bpemode} \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --text_keys 'text' 'text1' \ + --manifest_paths="data/manifest.train.raw" + + + if [ $? -ne 0 ]; then + echo "Build vocabulary failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=-1 \ + --specgram_type="fbank" \ + --feat_dim=80 \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=25.0 \ + --use_dB_normalization=False \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test; do + { + python3 ${MAIN_ROOT}/utils/format_triplet_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "spm" \ + --spm_model_prefix ${bpeprefix} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "Ted En-Zh Data preparation done." +exit 0 diff --git a/examples/ted_en_zh/local/test.sh b/examples/ted_en_zh/local/test.sh new file mode 100755 index 00000000..802bb13c --- /dev/null +++ b/examples/ted_en_zh/local/test.sh @@ -0,0 +1,35 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +for type in fullsentence; do + echo "decoding ${type}" + batch_size=32 + python3 -u ${BIN_DIR}/test.py \ + --device ${device} \ + --nproc 1 \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done + +exit 0 diff --git a/examples/ted_en_zh/local/train.sh b/examples/ted_en_zh/local/train.sh new file mode 100755 index 00000000..f3eb98da --- /dev/null +++ b/examples/ted_en_zh/local/train.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_name=$2 + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +echo "using ${device}..." + +mkdir -p exp + +python3 -u ${BIN_DIR}/train.py \ +--device ${device} \ +--nproc ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} + +if [ $? -ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 diff --git a/examples/ted_en_zh/path.sh b/examples/ted_en_zh/path.sh new file mode 100644 index 00000000..881a5b91 --- /dev/null +++ b/examples/ted_en_zh/path.sh @@ -0,0 +1,14 @@ +export MAIN_ROOT=${PWD}/../../ + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + + +MODEL=u2_st +export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin diff --git a/examples/ted_en_zh/run.sh b/examples/ted_en_zh/run.sh new file mode 100755 index 00000000..89048f3d --- /dev/null +++ b/examples/ted_en_zh/run.sh @@ -0,0 +1,40 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/transformer_joint_noam.yaml +avg_num=5 +data_path=./TED-En-Zh # path to unzipped data +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + bash ./local/data.sh --DATA_DIR ${data_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + ../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # export ckpt avg_n + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit +fi diff --git a/utils/build_vocab.py b/utils/build_vocab.py index 76092b25..151d52f8 100755 --- a/utils/build_vocab.py +++ b/utils/build_vocab.py @@ -44,6 +44,11 @@ add_arg('manifest_paths', str, "You can provide multiple manifest files.", nargs='+', required=True) +add_arg('text_keys', str, + 'text', + "keys of the text in manifest for building vocabulary. " + "You can provide multiple k.", + nargs='+') # bpe add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") @@ -58,10 +63,10 @@ def count_manifest(counter, text_feature, manifest_path): line = text_feature.tokenize(line_json['text']) counter.update(line) -def dump_text_manifest(fileobj, manifest_path): +def dump_text_manifest(fileobj, manifest_path, key='text'): manifest_jsons = read_manifest(manifest_path) for line_json in manifest_jsons: - fileobj.write(line_json['text'] + "\n") + fileobj.write(line_json[key] + "\n") def main(): print_arguments(args, globals()) @@ -78,7 +83,9 @@ def main(): fp = tempfile.NamedTemporaryFile(mode='w', delete=False) for manifest_path in args.manifest_paths: - dump_text_manifest(fp, manifest_path) + text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys + for text_key in text_keys: + dump_text_manifest(fp, manifest_path, key=text_key) fp.close() # train spm.SentencePieceTrainer.Train( diff --git a/utils/format_triplet_data.py b/utils/format_triplet_data.py new file mode 100755 index 00000000..f3dd7ca4 --- /dev/null +++ b/utils/format_triplet_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""format manifest with more metadata.""" +import argparse +import functools +import json + +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer +from deepspeech.frontend.utility import load_cmvn +from deepspeech.frontend.utility import read_manifest +from deepspeech.utils.utility import add_arguments +from deepspeech.utils.utility import print_arguments + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi") +add_arg('cmvn_path', str, + 'examples/librispeech/data/mean_std.json', + "Filepath of cmvn.") +add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") +add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath of the vocabulary.") +add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) +# bpe +add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") +add_arg('output_path', str, None, "filepath of formated manifest.", required=True) +# yapf: disable +args = parser.parse_args() + + +def main(): + print_arguments(args, globals()) + fout = open(args.output_path, 'w', encoding='utf-8') + + # get feat dim + mean, std = load_cmvn(args.cmvn_path, filetype='json') + feat_dim = mean.shape[0] #(D) + print(f"Feature dim: {feat_dim}") + + text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) + vocab_size = text_feature.vocab_size + print(f"Vocab size: {vocab_size}") + + count = 0 + for manifest_path in args.manifest_paths: + manifest_jsons = read_manifest(manifest_path) + for line_json in manifest_jsons: + # text: translation text, text1: transcript text. + # Currently only support joint-vocab, will add separate vocabs setting. + line = line_json['text'] + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + line_json['token'] = tokens + line_json['token_id'] = tokenids + line_json['token_shape'] = (len(tokenids), vocab_size) + line = line_json['text1'] + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + line_json['token1'] = tokens + line_json['token_id1'] = tokenids + line_json['token_shape1'] = (len(tokenids), vocab_size) + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + if args.feat_type == 'raw': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') + fout.write(json.dumps(line_json) + '\n') + count += 1 + + print(f"Examples number: {count}") + fout.close() + + +if __name__ == '__main__': + main()