|
|
|
@ -355,7 +355,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
decoding_chunk_size=-1, # decoding chunk size. Defaults to -1.
|
|
|
|
|
# <0: for decoding, use full chunk.
|
|
|
|
|
# >0: for decoding, use fixed chunk size as set.
|
|
|
|
|
# 0: used for training, it's prohibited here.
|
|
|
|
|
# 0: used for training, it's prohibited here.
|
|
|
|
|
num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1.
|
|
|
|
|
simulate_streaming=False, # simulate streaming inference. Defaults to False.
|
|
|
|
|
))
|
|
|
|
@ -512,11 +512,13 @@ class U2Tester(U2Trainer):
|
|
|
|
|
self.model.eval()
|
|
|
|
|
logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
stride_ms = self.test_loader.dataset.stride_ms
|
|
|
|
|
token_dict = self.test_loader.dataset.vocab_list
|
|
|
|
|
stride_ms = self.test_loader.collate_fn.stride_ms
|
|
|
|
|
token_dict = self.test_loader.collate_fn.vocab_list
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
# one example in batch
|
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
|
key, feat, feats_length, target, target_length = batch
|
|
|
|
|
|
|
|
|
|
# 1. Encoder
|
|
|
|
|
encoder_out, encoder_mask = self.model._forward_encoder(
|
|
|
|
|
feat, feats_length) # (B, maxlen, encoder_dim)
|
|
|
|
@ -529,28 +531,31 @@ class U2Tester(U2Trainer):
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
target = target.squeeze(0)
|
|
|
|
|
alignment = ctc_utils.forced_align(ctc_probs, target)
|
|
|
|
|
print(alignment)
|
|
|
|
|
print(key[0], alignment)
|
|
|
|
|
fout.write('{} {}\n'.format(key[0], alignment))
|
|
|
|
|
|
|
|
|
|
# 3. gen praat
|
|
|
|
|
# segment alignment
|
|
|
|
|
align_segs = text_grid.segment_alignment(alignment)
|
|
|
|
|
print(align_segs)
|
|
|
|
|
print(key[0], align_segs)
|
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
|
subsample = get_subsample(self.config)
|
|
|
|
|
tierformat = text_grid.align_to_tierformat(
|
|
|
|
|
align_segs, subsample, token_dict)
|
|
|
|
|
# write tier
|
|
|
|
|
tier_path = os.path.join(
|
|
|
|
|
os.path.dirname(self.args.result_file), key[0] + ".tier")
|
|
|
|
|
with open(tier_path, 'w') as f:
|
|
|
|
|
f.writelines(tierformat)
|
|
|
|
|
|
|
|
|
|
# write textgrid
|
|
|
|
|
textgrid_path = os.path.join(
|
|
|
|
|
os.path.dirname(self.args.result_file), key[0] + ".TextGrid")
|
|
|
|
|
second_per_frame = 1. / (1000. / stride_ms
|
|
|
|
|
) # 25ms window, 10ms stride
|
|
|
|
|
second_per_frame = 1. / (1000. /
|
|
|
|
|
stride_ms) # 25ms window, 10ms stride
|
|
|
|
|
second_per_example = (
|
|
|
|
|
len(alignment) + 1) * subsample * second_per_frame
|
|
|
|
|
text_grid.generate_textgrid(
|
|
|
|
|
maxtime=(len(alignment) + 1) * subsample * second_per_frame,
|
|
|
|
|
maxtime=second_per_example,
|
|
|
|
|
lines=tierformat,
|
|
|
|
|
output=textgrid_path)
|
|
|
|
|
|
|
|
|
|