|
|
|
@ -14,12 +14,10 @@
|
|
|
|
|
"""Contains U2 model."""
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import jsonlines
|
|
|
|
@ -44,8 +42,6 @@ from deepspeech.utils import ctc_utils
|
|
|
|
|
from deepspeech.utils import error_rate
|
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
|
from deepspeech.utils import text_grid
|
|
|
|
|
from deepspeech.utils import utility
|
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
from deepspeech.utils.utility import UpdateConfig
|
|
|
|
|
|
|
|
|
@ -553,62 +549,10 @@ class U2Tester(U2Trainer):
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def align(self):
|
|
|
|
|
if self.config.decoding.batch_size > 1:
|
|
|
|
|
logger.fatal('alignment mode must be running with batch_size == 1')
|
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
# xxx.align
|
|
|
|
|
assert self.args.result_file and self.args.result_file.endswith(
|
|
|
|
|
'.align')
|
|
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
|
|
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
stride_ms = self.align_loader.collate_fn.stride_ms
|
|
|
|
|
token_dict = self.align_loader.collate_fn.vocab_list
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
# one example in batch
|
|
|
|
|
for i, batch in enumerate(self.align_loader):
|
|
|
|
|
key, feat, feats_length, target, target_length = batch
|
|
|
|
|
|
|
|
|
|
# 1. Encoder
|
|
|
|
|
encoder_out, encoder_mask = self.model._forward_encoder(
|
|
|
|
|
feat, feats_length) # (B, maxlen, encoder_dim)
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(
|
|
|
|
|
encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
|
|
|
|
|
|
# 2. alignment
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
target = target.squeeze(0)
|
|
|
|
|
alignment = ctc_utils.forced_align(ctc_probs, target)
|
|
|
|
|
logger.info(f"align ids: {key[0]} {alignment}")
|
|
|
|
|
fout.write('{} {}\n'.format(key[0], alignment))
|
|
|
|
|
|
|
|
|
|
# 3. gen praat
|
|
|
|
|
# segment alignment
|
|
|
|
|
align_segs = text_grid.segment_alignment(alignment)
|
|
|
|
|
logger.info(f"align tokens: {key[0]}, {align_segs}")
|
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
|
subsample = utility.get_subsample(self.config)
|
|
|
|
|
tierformat = text_grid.align_to_tierformat(
|
|
|
|
|
align_segs, subsample, token_dict)
|
|
|
|
|
# write tier
|
|
|
|
|
align_output_path = Path(self.args.result_file).parent / "align"
|
|
|
|
|
align_output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
tier_path = align_output_path / (key[0] + ".tier")
|
|
|
|
|
with tier_path.open('w') as f:
|
|
|
|
|
f.writelines(tierformat)
|
|
|
|
|
# write textgrid
|
|
|
|
|
textgrid_path = align_output_path / (key[0] + ".TextGrid")
|
|
|
|
|
second_per_frame = 1. / (1000. /
|
|
|
|
|
stride_ms) # 25ms window, 10ms stride
|
|
|
|
|
second_per_example = (
|
|
|
|
|
len(alignment) + 1) * subsample * second_per_frame
|
|
|
|
|
text_grid.generate_textgrid(
|
|
|
|
|
maxtime=second_per_example,
|
|
|
|
|
intervals=tierformat,
|
|
|
|
|
output=str(textgrid_path))
|
|
|
|
|
ctc_utils.ctc_align(
|
|
|
|
|
self.model, self.align_loader, self.config.decoding.batch_size,
|
|
|
|
|
self.align_loader.collate_fn.stride_ms,
|
|
|
|
|
self.align_loader.collate_fn.vocab_list, self.args.result_file)
|
|
|
|
|
|
|
|
|
|
def load_inferspec(self):
|
|
|
|
|
"""infer model and input spec.
|
|
|
|
@ -630,6 +574,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
]
|
|
|
|
|
return infer_model, input_spec
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def export(self):
|
|
|
|
|
infer_model, input_spec = self.load_inferspec()
|
|
|
|
|
assert isinstance(input_spec, list), type(input_spec)
|
|
|
|
|