diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 6b23a9852..2f0e752f8 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -575,7 +575,7 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align( + ctc_utils.ctc_align(self.config, self.model, self.align_loader, self.config.decoding.batch_size, self.config.collator.stride_ms, self.vocab_list, self.args.result_file) diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index 357a39b91..6c4365b86 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -528,7 +528,7 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align( + ctc_utils.ctc_align(self.config, self.model, self.align_loader, self.config.decoding.batch_size, self.config.collator.stride_ms, self.vocab_list, self.args.result_file) diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index f458216e3..9141b3613 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -543,10 +543,10 @@ class U2STTester(U2STTrainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align( + ctc_utils.ctc_align(self.config, self.model, self.align_loader, self.config.decoding.batch_size, - self.align_loader.collate_fn.stride_ms, - self.align_loader.collate_fn.vocab_list, self.args.result_file) + self.config.collator.stride_ms, + self.vocab_list, self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/utils/ctc_utils.py b/paddlespeech/s2t/utils/ctc_utils.py index e005e5d28..f5822e5dd 100644 --- a/paddlespeech/s2t/utils/ctc_utils.py +++ b/paddlespeech/s2t/utils/ctc_utils.py @@ -13,7 +13,7 @@ # limitations under the License. # Modified from wenet(https://github.com/wenet-e2e/wenet) from typing import List - +from pathlib import Path import numpy as np import paddle @@ -139,26 +139,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, return output_alignment -def ctc_align(model, dataloader, batch_size, stride_ms, token_dict, +def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict, result_file): """ctc alignment. Args: + config (cfgNode): config model (nn.Layer): U2 Model. dataloader (io.DataLoader): dataloader. batch_size (int): decoding batchsize. stride_ms (int): audio feature stride in ms unit. token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '']. - result_file (str): alignment output file, e.g. xxx.align. + result_file (str): alignment output file, e.g. /path/to/xxx.align. """ if batch_size > 1: logger.fatal('alignment mode must be running with batch_size == 1') sys.exit(1) - assert result_file and result_file.endswith('.align') model.eval() - + # conv subsampling rate + subsample = utility.get_subsample(config) logger.info(f"Align Total Examples: {len(dataloader.dataset)}") with open(result_file, 'w') as fout: @@ -187,13 +188,11 @@ def ctc_align(model, dataloader, batch_size, stride_ms, token_dict, logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] - subsample = utility.get_subsample(self.config) - tierformat = text_grid.align_to_tierformat(align_segs, subsample, token_dict) # write tier - align_output_path = Path(self.args.result_file).parent / "align" + align_output_path = Path(result_file).parent / "align" align_output_path.mkdir(parents=True, exist_ok=True) tier_path = align_output_path / (key[0] + ".tier") with tier_path.open('w') as f: