|
|
|
@ -545,9 +545,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
self.model.eval()
|
|
|
|
|
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
stride_ms = self.config.collater.stride_ms
|
|
|
|
|
token_dict = self.args.char_list
|
|
|
|
|
|
|
|
|
|
stride_ms = self.align_loader.collate_fn.stride_ms
|
|
|
|
|
token_dict = self.align_loader.collate_fn.vocab_list
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
# one example in batch
|
|
|
|
|
for i, batch in enumerate(self.align_loader):
|
|
|
|
@ -564,26 +563,25 @@ class U2Tester(U2Trainer):
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
target = target.squeeze(0)
|
|
|
|
|
alignment = ctc_utils.forced_align(ctc_probs, target)
|
|
|
|
|
logger.info("align ids", key[0], alignment)
|
|
|
|
|
logger.info(f"align ids: {key[0]} {alignment}")
|
|
|
|
|
fout.write('{} {}\n'.format(key[0], alignment))
|
|
|
|
|
|
|
|
|
|
# 3. gen praat
|
|
|
|
|
# segment alignment
|
|
|
|
|
align_segs = text_grid.segment_alignment(alignment)
|
|
|
|
|
logger.info("align tokens", key[0], align_segs)
|
|
|
|
|
logger.info(f"align tokens: {key[0]}, {align_segs}")
|
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
|
subsample = utility.get_subsample(self.config)
|
|
|
|
|
tierformat = text_grid.align_to_tierformat(
|
|
|
|
|
align_segs, subsample, token_dict)
|
|
|
|
|
# write tier
|
|
|
|
|
align_output_path = os.path.join(
|
|
|
|
|
os.path.dirname(self.args.result_file), "align")
|
|
|
|
|
tier_path = os.path.join(align_output_path, key[0] + ".tier")
|
|
|
|
|
with open(tier_path, 'w') as f:
|
|
|
|
|
align_output_path = Path(self.args.result_file).parent / "align"
|
|
|
|
|
align_output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
tier_path = align_output_path / (key[0] + ".tier")
|
|
|
|
|
with tier_path.open('w') as f:
|
|
|
|
|
f.writelines(tierformat)
|
|
|
|
|
# write textgrid
|
|
|
|
|
textgrid_path = os.path.join(align_output_path,
|
|
|
|
|
key[0] + ".TextGrid")
|
|
|
|
|
textgrid_path = align_output_path / (key[0] + ".TextGrid")
|
|
|
|
|
second_per_frame = 1. / (1000. /
|
|
|
|
|
stride_ms) # 25ms window, 10ms stride
|
|
|
|
|
second_per_example = (
|
|
|
|
@ -591,7 +589,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
text_grid.generate_textgrid(
|
|
|
|
|
maxtime=second_per_example,
|
|
|
|
|
intervals=tierformat,
|
|
|
|
|
output=textgrid_path)
|
|
|
|
|
output=str(textgrid_path))
|
|
|
|
|
|
|
|
|
|
def run_align(self):
|
|
|
|
|
self.resume_or_scratch()
|
|
|
|
|