|
|
@ -13,7 +13,7 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
|
@ -139,26 +139,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
|
|
|
|
return output_alignment
|
|
|
|
return output_alignment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
|
|
|
|
def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
|
|
|
|
result_file):
|
|
|
|
result_file):
|
|
|
|
"""ctc alignment.
|
|
|
|
"""ctc alignment.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
|
|
|
|
config (cfgNode): config
|
|
|
|
model (nn.Layer): U2 Model.
|
|
|
|
model (nn.Layer): U2 Model.
|
|
|
|
dataloader (io.DataLoader): dataloader.
|
|
|
|
dataloader (io.DataLoader): dataloader.
|
|
|
|
batch_size (int): decoding batchsize.
|
|
|
|
batch_size (int): decoding batchsize.
|
|
|
|
stride_ms (int): audio feature stride in ms unit.
|
|
|
|
stride_ms (int): audio feature stride in ms unit.
|
|
|
|
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
|
|
|
|
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
|
|
|
|
result_file (str): alignment output file, e.g. xxx.align.
|
|
|
|
result_file (str): alignment output file, e.g. /path/to/xxx.align.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if batch_size > 1:
|
|
|
|
if batch_size > 1:
|
|
|
|
logger.fatal('alignment mode must be running with batch_size == 1')
|
|
|
|
logger.fatal('alignment mode must be running with batch_size == 1')
|
|
|
|
sys.exit(1)
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
assert result_file and result_file.endswith('.align')
|
|
|
|
assert result_file and result_file.endswith('.align')
|
|
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
# conv subsampling rate
|
|
|
|
|
|
|
|
subsample = utility.get_subsample(config)
|
|
|
|
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
|
|
|
|
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
|
|
|
|
|
|
|
|
|
|
|
|
with open(result_file, 'w') as fout:
|
|
|
|
with open(result_file, 'w') as fout:
|
|
|
@ -187,13 +188,11 @@ def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
|
|
|
|
logger.info(f"align tokens: {key[0]}, {align_segs}")
|
|
|
|
logger.info(f"align tokens: {key[0]}, {align_segs}")
|
|
|
|
|
|
|
|
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
subsample = utility.get_subsample(self.config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
|
|
|
|
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
|
|
|
|
token_dict)
|
|
|
|
token_dict)
|
|
|
|
|
|
|
|
|
|
|
|
# write tier
|
|
|
|
# write tier
|
|
|
|
align_output_path = Path(self.args.result_file).parent / "align"
|
|
|
|
align_output_path = Path(result_file).parent / "align"
|
|
|
|
align_output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
align_output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
tier_path = align_output_path / (key[0] + ".tier")
|
|
|
|
tier_path = align_output_path / (key[0] + ".tier")
|
|
|
|
with tier_path.open('w') as f:
|
|
|
|
with tier_path.open('w') as f:
|
|
|
|