diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 3ebbbe7a..92320dac 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -16,7 +16,6 @@ import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
-from pathlib import Path
 from typing import Optional
 
 import jsonlines
@@ -386,6 +385,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             logger.info(msg)
         self.autolog.report()
 
+    @paddle.no_grad()
     def export(self):
         if self.args.model_type == 'offline':
             infer_model = DeepSpeech2InferModel.from_pretrained(
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index beb91d5d..0976ec1a 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -14,12 +14,10 @@
 """Contains U2 model."""
 import json
 import os
-import sys
 import time
 from collections import defaultdict
 from collections import OrderedDict
 from contextlib import nullcontext
-from pathlib import Path
 from typing import Optional
 
 import jsonlines
@@ -44,8 +42,6 @@ from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
-from deepspeech.utils import text_grid
-from deepspeech.utils import utility
 from deepspeech.utils.log import Log
 from deepspeech.utils.utility import UpdateConfig
 
@@ -553,62 +549,11 @@ class U2Tester(U2Trainer):
 
     @paddle.no_grad()
     def align(self):
-        if self.config.decoding.batch_size > 1:
-            logger.fatal('alignment mode must be running with batch_size == 1')
-            sys.exit(1)
-
-        # xxx.align
-        assert self.args.result_file and self.args.result_file.endswith(
-            '.align')
-
-        self.model.eval()
-        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
-
-        stride_ms = self.align_loader.collate_fn.stride_ms
-        token_dict = self.align_loader.collate_fn.vocab_list
-        with open(self.args.result_file, 'w') as fout:
-            # one example in batch
-            for i, batch in enumerate(self.align_loader):
-                key, feat, feats_length, target, target_length = batch
-
-                # 1. Encoder
-                encoder_out, encoder_mask = self.model._forward_encoder(
-                    feat, feats_length)  # (B, maxlen, encoder_dim)
-                maxlen = encoder_out.shape[1]
-                ctc_probs = self.model.ctc.log_softmax(
-                    encoder_out)  # (1, maxlen, vocab_size)
-
-                # 2. alignment
-                ctc_probs = ctc_probs.squeeze(0)
-                target = target.squeeze(0)
-                alignment = ctc_utils.forced_align(ctc_probs, target)
-                logger.info(f"align ids: {key[0]} {alignment}")
-                fout.write('{} {}\n'.format(key[0], alignment))
-
-                # 3. gen praat
-                # segment alignment
-                align_segs = text_grid.segment_alignment(alignment)
-                logger.info(f"align tokens: {key[0]}, {align_segs}")
-                # IntervalTier, List["start end token\n"]
-                subsample = utility.get_subsample(self.config)
-                tierformat = text_grid.align_to_tierformat(
-                    align_segs, subsample, token_dict)
-                # write tier
-                align_output_path = Path(self.args.result_file).parent / "align"
-                align_output_path.mkdir(parents=True, exist_ok=True)
-                tier_path = align_output_path / (key[0] + ".tier")
-                with tier_path.open('w') as f:
-                    f.writelines(tierformat)
-                # write textgrid
-                textgrid_path = align_output_path / (key[0] + ".TextGrid")
-                second_per_frame = 1. / (1000. /
-                                         stride_ms)  # 25ms window, 10ms stride
-                second_per_example = (
-                    len(alignment) + 1) * subsample * second_per_frame
-                text_grid.generate_textgrid(
-                    maxtime=second_per_example,
-                    intervals=tierformat,
-                    output=str(textgrid_path))
+        ctc_utils.ctc_align(
+            self.config, self.model, self.align_loader,
+            self.config.decoding.batch_size,
+            self.align_loader.collate_fn.stride_ms,
+            self.align_loader.collate_fn.vocab_list, self.args.result_file)
 
     def load_inferspec(self):
         """infer model and input spec.
@@ -630,6 +575,7 @@ class U2Tester(U2Trainer):
         ]
         return infer_model, input_spec
 
+    @paddle.no_grad()
     def export(self):
         infer_model, input_spec = self.load_inferspec()
         assert isinstance(input_spec, list), type(input_spec)
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 48950fc8..0151e208 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -14,11 +14,9 @@
 """Contains U2 model."""
 import json
 import os
-import sys
 import time
 from collections import defaultdict
 from contextlib import nullcontext
-from pathlib import Path
 from typing import Optional
 
 import jsonlines
@@ -39,8 +37,6 @@ from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
-from deepspeech.utils import text_grid
-from deepspeech.utils import utility
 from deepspeech.utils.log import Log
 from deepspeech.utils.utility import UpdateConfig
 
@@ -527,62 +523,11 @@ class U2Tester(U2Trainer):
 
     @paddle.no_grad()
     def align(self):
-        if self.config.decoding.batch_size > 1:
-            logger.fatal('alignment mode must be running with batch_size == 1')
-            sys.exit(1)
-
-        # xxx.align
-        assert self.args.result_file and self.args.result_file.endswith(
-            '.align')
-
-        self.model.eval()
-        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
-
-        stride_ms = self.align_loader.collate_fn.stride_ms
-        token_dict = self.align_loader.collate_fn.vocab_list
-        with open(self.args.result_file, 'w') as fout:
-            # one example in batch
-            for i, batch in enumerate(self.align_loader):
-                key, feat, feats_length, target, target_length = batch
-
-                # 1. Encoder
-                encoder_out, encoder_mask = self.model._forward_encoder(
-                    feat, feats_length)  # (B, maxlen, encoder_dim)
-                maxlen = encoder_out.shape[1]
-                ctc_probs = self.model.ctc.log_softmax(
-                    encoder_out)  # (1, maxlen, vocab_size)
-
-                # 2. alignment
-                ctc_probs = ctc_probs.squeeze(0)
-                target = target.squeeze(0)
-                alignment = ctc_utils.forced_align(ctc_probs, target)
-                logger.info(f"align ids: {key[0]} {alignment}")
-                fout.write('{} {}\n'.format(key[0], alignment))
-
-                # 3. gen praat
-                # segment alignment
-                align_segs = text_grid.segment_alignment(alignment)
-                logger.info(f"align tokens: {key[0]}, {align_segs}")
-                # IntervalTier, List["start end token\n"]
-                subsample = utility.get_subsample(self.config)
-                tierformat = text_grid.align_to_tierformat(
-                    align_segs, subsample, token_dict)
-                # write tier
-                align_output_path = Path(self.args.result_file).parent / "align"
-                align_output_path.mkdir(parents=True, exist_ok=True)
-                tier_path = align_output_path / (key[0] + ".tier")
-                with tier_path.open('w') as f:
-                    f.writelines(tierformat)
-                # write textgrid
-                textgrid_path = align_output_path / (key[0] + ".TextGrid")
-                second_per_frame = 1. / (1000. /
-                                         stride_ms)  # 25ms window, 10ms stride
-                second_per_example = (
-                    len(alignment) + 1) * subsample * second_per_frame
-                text_grid.generate_textgrid(
-                    maxtime=second_per_example,
-                    intervals=tierformat,
-                    output=str(textgrid_path))
+        ctc_utils.ctc_align(
+            self.config, self.model, self.align_loader,
+            self.config.decoding.batch_size,
+            self.align_loader.collate_fn.stride_ms,
+            self.align_loader.collate_fn.vocab_list, self.args.result_file)
 
     def load_inferspec(self):
         """infer model and input spec.
@@ -604,6 +549,7 @@ class U2Tester(U2Trainer):
         ]
         return infer_model, input_spec
 
+    @paddle.no_grad()
     def export(self):
         infer_model, input_spec = self.load_inferspec()
         assert isinstance(input_spec, list), type(input_spec)
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index 2d228d29..c5df44c6 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -14,11 +14,9 @@
 """Contains U2 model."""
 import json
 import os
-import sys
 import time
 from collections import defaultdict
 from contextlib import nullcontext
-from pathlib import Path
 from typing import Optional
 
 import jsonlines
@@ -42,8 +40,6 @@ from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
-from deepspeech.utils import text_grid
-from deepspeech.utils import utility
 from deepspeech.utils.log import Log
 from deepspeech.utils.utility import UpdateConfig
 
@@ -547,62 +543,11 @@ class U2STTester(U2STTrainer):
 
     @paddle.no_grad()
     def align(self):
-        if self.config.decoding.batch_size > 1:
-            logger.fatal('alignment mode must be running with batch_size == 1')
-            sys.exit(1)
-
-        # xxx.align
-        assert self.args.result_file and self.args.result_file.endswith(
-            '.align')
-
-        self.model.eval()
-        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
-
-        stride_ms = self.align_loader.collate_fn.stride_ms
-        token_dict = self.align_loader.collate_fn.vocab_list
-        with open(self.args.result_file, 'w') as fout:
-            # one example in batch
-            for i, batch in enumerate(self.align_loader):
-                key, feat, feats_length, target, target_length = batch
-
-                # 1. Encoder
-                encoder_out, encoder_mask = self.model._forward_encoder(
-                    feat, feats_length)  # (B, maxlen, encoder_dim)
-                maxlen = encoder_out.shape[1]
-                ctc_probs = self.model.ctc.log_softmax(
-                    encoder_out)  # (1, maxlen, vocab_size)
-
-                # 2. alignment
-                ctc_probs = ctc_probs.squeeze(0)
-                target = target.squeeze(0)
-                alignment = ctc_utils.forced_align(ctc_probs, target)
-                logger.info(f"align ids: {key[0]} {alignment}")
-                fout.write('{} {}\n'.format(key[0], alignment))
-
-                # 3. gen praat
-                # segment alignment
-                align_segs = text_grid.segment_alignment(alignment)
-                logger.info(f"align tokens: {key[0]}, {align_segs}")
-                # IntervalTier, List["start end token\n"]
-                subsample = utility.get_subsample(self.config)
-                tierformat = text_grid.align_to_tierformat(
-                    align_segs, subsample, token_dict)
-                # write tier
-                align_output_path = Path(self.args.result_file).parent / "align"
-                align_output_path.mkdir(parents=True, exist_ok=True)
-                tier_path = align_output_path / (key[0] + ".tier")
-                with tier_path.open('w') as f:
-                    f.writelines(tierformat)
-                # write textgrid
-                textgrid_path = align_output_path / (key[0] + ".TextGrid")
-                second_per_frame = 1. / (1000. /
-                                         stride_ms)  # 25ms window, 10ms stride
-                second_per_example = (
-                    len(alignment) + 1) * subsample * second_per_frame
-                text_grid.generate_textgrid(
-                    maxtime=second_per_example,
-                    intervals=tierformat,
-                    output=str(textgrid_path))
+        ctc_utils.ctc_align(
+            self.config, self.model, self.align_loader,
+            self.config.decoding.batch_size,
+            self.align_loader.collate_fn.stride_ms,
+            self.align_loader.collate_fn.vocab_list, self.args.result_file)
 
     def load_inferspec(self):
         """infer model and input spec.
@@ -624,6 +569,7 @@ class U2STTester(U2STTrainer):
         ]
         return infer_model, input_spec
 
+    @paddle.no_grad()
     def export(self):
         infer_model, input_spec = self.load_inferspec()
         assert isinstance(input_spec, list), type(input_spec)
diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py
index 70d99e6c..7e8629c2 100644
--- a/deepspeech/utils/ctc_utils.py
+++ b/deepspeech/utils/ctc_utils.py
@@ -14,8 +14,12 @@
+import sys
+from pathlib import Path
 from typing import List
 
 import numpy as np
 import paddle
 
+from deepspeech.utils import text_grid
+from deepspeech.utils import utility
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
@@ -134,3 +138,76 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
         output_alignment.append(y_insert_blank[state_seq[t, 0]])
 
     return output_alignment
+
+
+def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
+              result_file):
+    """ctc alignment.
+
+    Args:
+        config (CfgNode): experiment config, used to query the subsample rate.
+        model (nn.Layer): U2 Model.
+        dataloader (io.DataLoader): dataloader.
+        batch_size (int): decoding batchsize.
+        stride_ms (int): audio feature stride in ms unit.
+        token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
+        result_file (str): alignment output file, e.g. xxx.align.
+    """
+    if batch_size > 1:
+        logger.fatal('alignment mode must be running with batch_size == 1')
+        sys.exit(1)
+
+    assert result_file and result_file.endswith('.align')
+
+    model.eval()
+
+    logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
+
+    with open(result_file, 'w') as fout:
+        # one example in batch
+        for i, batch in enumerate(dataloader):
+            key, feat, feats_length, target, target_length = batch
+
+            # 1. Encoder
+            encoder_out, encoder_mask = model._forward_encoder(
+                feat, feats_length)  # (B, maxlen, encoder_dim)
+            maxlen = encoder_out.shape[1]
+            ctc_probs = model.ctc.log_softmax(
+                encoder_out)  # (1, maxlen, vocab_size)
+
+            # 2. alignment
+            ctc_probs = ctc_probs.squeeze(0)
+            target = target.squeeze(0)
+            alignment = forced_align(ctc_probs, target)
+
+            logger.info(f"align ids: {key[0]} {alignment}")
+            fout.write('{} {}\n'.format(key[0], alignment))
+
+            # 3. gen praat
+            # segment alignment
+            align_segs = text_grid.segment_alignment(alignment)
+            logger.info(f"align tokens: {key[0]}, {align_segs}")
+
+            # IntervalTier, List["start end token\n"]
+            subsample = utility.get_subsample(config)
+
+            tierformat = text_grid.align_to_tierformat(align_segs, subsample,
+                                                       token_dict)
+
+            # write tier
+            align_output_path = Path(result_file).parent / "align"
+            align_output_path.mkdir(parents=True, exist_ok=True)
+            tier_path = align_output_path / (key[0] + ".tier")
+            with tier_path.open('w') as f:
+                f.writelines(tierformat)
+
+            # write textgrid
+            textgrid_path = align_output_path / (key[0] + ".TextGrid")
+            second_per_frame = 1. / (1000. /
+                                     stride_ms)  # 25ms window, 10ms stride
+            second_per_example = (
+                len(alignment) + 1) * subsample * second_per_frame
+            text_grid.generate_textgrid(
+                maxtime=second_per_example,
+                intervals=tierformat,
+                output=str(textgrid_path))
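
Usage note (reviewer aid, not part of the patch): with the shared helper in place, every tester's align() reduces to a single call. A minimal sketch of a standalone invocation, assuming `config`, `model`, and `align_loader` are already built by the experiment code exactly as in the testers above; the output path here is made up:

    from deepspeech.utils import ctc_utils

    # Sketch only: mirrors the call sites introduced by this patch.
    ctc_utils.ctc_align(
        config,  # full experiment config; the helper reads the subsample rate
        model,  # trained U2 model; ctc_align() switches it to eval mode
        align_loader,  # dataloader over the alignment set
        config.decoding.batch_size,  # must be 1, enforced inside the helper
        align_loader.collate_fn.stride_ms,  # feature stride in milliseconds
        align_loader.collate_fn.vocab_list,  # token id -> token string
        "exp/test.align")  # result file; must end with '.align'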