fix ctc align

pull/1012/head
Hui Zhang 3 years ago
parent 69055698a2
commit 69bccb4f02

@@ -575,7 +575,7 @@ class U2Tester(U2Trainer):
     @paddle.no_grad()
     def align(self):
-        ctc_utils.ctc_align(
+        ctc_utils.ctc_align(self.config,
             self.model, self.align_loader, self.config.decoding.batch_size,
             self.config.collator.stride_ms,
             self.vocab_list, self.args.result_file)

@@ -528,7 +528,7 @@ class U2Tester(U2Trainer):
     @paddle.no_grad()
     def align(self):
-        ctc_utils.ctc_align(
+        ctc_utils.ctc_align(self.config,
             self.model, self.align_loader, self.config.decoding.batch_size,
             self.config.collator.stride_ms,
             self.vocab_list, self.args.result_file)

@@ -543,10 +543,10 @@ class U2STTester(U2STTrainer):
     @paddle.no_grad()
     def align(self):
-        ctc_utils.ctc_align(
+        ctc_utils.ctc_align(self.config,
             self.model, self.align_loader, self.config.decoding.batch_size,
-            self.align_loader.collate_fn.stride_ms,
-            self.align_loader.collate_fn.vocab_list, self.args.result_file)
+            self.config.collator.stride_ms,
+            self.vocab_list, self.args.result_file)

     def load_inferspec(self):
         """infer model and input spec.

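Note: after this change all three testers share the same call shape. A minimal sketch of the fixed call site, in Python, assuming the attribute names shown in the diff (self.config, self.align_loader, self.vocab_list); the AlignableTester wrapper below is hypothetical, not a class in the repo:

# Hypothetical wrapper; attribute names are taken from the diff above.
class AlignableTester:
    @paddle.no_grad()
    def align(self):
        ctc_utils.ctc_align(
            self.config,                       # passed so the helper can read the conv subsampling rate
            self.model, self.align_loader,
            self.config.decoding.batch_size,   # alignment only runs with batch_size == 1
            self.config.collator.stride_ms,    # feature stride in milliseconds
            self.vocab_list,                   # token_dict
            self.args.result_file)             # *.align output path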
@@ -13,7 +13,7 @@
 # limitations under the License.
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 from typing import List
+from pathlib import Path

 import numpy as np
 import paddle
@@ -139,26 +139,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
     return output_alignment


-def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
+def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
               result_file):
     """ctc alignment.

     Args:
+        config (cfgNode): config
         model (nn.Layer): U2 Model.
         dataloader (io.DataLoader): dataloader.
         batch_size (int): decoding batchsize.
         stride_ms (int): audio feature stride in ms unit.
         token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
-        result_file (str): alignment output file, e.g. xxx.align.
+        result_file (str): alignment output file, e.g. /path/to/xxx.align.
     """
     if batch_size > 1:
         logger.fatal('alignment mode must be running with batch_size == 1')
         sys.exit(1)

     assert result_file and result_file.endswith('.align')

     model.eval()
+    # conv subsampling rate
+    subsample = utility.get_subsample(config)

     logger.info(f"Align Total Examples: {len(dataloader.dataset)}")

     with open(result_file, 'w') as fout:
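The subsampling rate is now derived once from the config instead of through self inside the loop. As a rough sketch only (the config keys and mapping below are assumptions, not the repo's exact utility.get_subsample), a wenet-style lookup maps the encoder front-end to its frame-rate reduction:

# Sketch of a subsample lookup; key layout and values are assumptions.
def get_subsample(config):
    input_layer = config["model"]["encoder_conf"]["input_layer"]  # assumed config key
    # conv2d front-ends downsample the feature frame rate; a linear front-end keeps it unchanged
    return {"linear": 1, "conv2d": 4, "conv2d6": 6, "conv2d8": 8}[input_layer]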
@@ -187,13 +188,11 @@ def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
             logger.info(f"align tokens: {key[0]}, {align_segs}")

             # IntervalTier, List["start end token\n"]
-            subsample = utility.get_subsample(self.config)
             tierformat = text_grid.align_to_tierformat(align_segs, subsample,
                                                        token_dict)

             # write tier
-            align_output_path = Path(self.args.result_file).parent / "align"
+            align_output_path = Path(result_file).parent / "align"
             align_output_path.mkdir(parents=True, exist_ok=True)
             tier_path = align_output_path / (key[0] + ".tier")
             with tier_path.open('w') as f:

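The aligned segments come out of forced_align in subsampled-frame units, which is why both subsample and stride_ms matter for timing. A small illustration of that arithmetic (not the repo's text_grid code; align_segs is assumed to hold (start_frame, end_frame, token) tuples):

# Illustration only: convert subsampled frame indices to millisecond timestamps.
def segs_to_ms(align_segs, subsample, stride_ms):
    # each subsampled frame covers subsample * stride_ms milliseconds of audio
    return [(start * subsample * stride_ms, end * subsample * stride_ms, token)
            for start, end, token in align_segs]

# e.g. subsample=4, stride_ms=10: frame index 25 corresponds to 1000 ms into the utterance.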