|
|
@ -39,6 +39,7 @@ from deepspeech.utils import error_rate
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils import mp_tools
|
|
|
|
from deepspeech.utils import text_grid
|
|
|
|
from deepspeech.utils import text_grid
|
|
|
|
|
|
|
|
from deepspeech.utils import utility
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
@ -280,7 +281,15 @@ class U2Trainer(Trainer):
|
|
|
|
shuffle=False,
|
|
|
|
shuffle=False,
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=SpeechCollator.from_config(config))
|
|
|
|
collate_fn=SpeechCollator.from_config(config))
|
|
|
|
logger.info("Setup train/valid/test Dataloader!")
|
|
|
|
# return text token id
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
|
|
|
|
self.align_loader = DataLoader(
|
|
|
|
|
|
|
|
test_dataset,
|
|
|
|
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
|
|
|
|
shuffle=False,
|
|
|
|
|
|
|
|
drop_last=False,
|
|
|
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config))
|
|
|
|
|
|
|
|
logger.info("Setup train/valid/test/align Dataloader!")
|
|
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
def setup_model(self):
|
|
|
|
config = self.config
|
|
|
|
config = self.config
|
|
|
@ -507,16 +516,17 @@ class U2Tester(U2Trainer):
|
|
|
|
sys.exit(1)
|
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
# xxx.align
|
|
|
|
# xxx.align
|
|
|
|
assert self.args.result_file
|
|
|
|
assert self.args.result_file and self.args.result_file.endswith(
|
|
|
|
|
|
|
|
'.align')
|
|
|
|
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
|
self.model.eval()
|
|
|
|
logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
|
|
stride_ms = self.test_loader.collate_fn.stride_ms
|
|
|
|
stride_ms = self.align_loader.collate_fn.stride_ms
|
|
|
|
token_dict = self.test_loader.collate_fn.vocab_list
|
|
|
|
token_dict = self.align_loader.collate_fn.vocab_list
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
# one example in batch
|
|
|
|
# one example in batch
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
for i, batch in enumerate(self.align_loader):
|
|
|
|
key, feat, feats_length, target, target_length = batch
|
|
|
|
key, feat, feats_length, target, target_length = batch
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Encoder
|
|
|
|
# 1. Encoder
|
|
|
@ -527,36 +537,36 @@ class U2Tester(U2Trainer):
|
|
|
|
encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. alignment
|
|
|
|
# 2. alignment
|
|
|
|
# print(ctc_probs.size(1))
|
|
|
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
target = target.squeeze(0)
|
|
|
|
target = target.squeeze(0)
|
|
|
|
alignment = ctc_utils.forced_align(ctc_probs, target)
|
|
|
|
alignment = ctc_utils.forced_align(ctc_probs, target)
|
|
|
|
print(kye[0], alignment)
|
|
|
|
logger.info("align ids", key[0], alignment)
|
|
|
|
fout.write('{} {}\n'.format(key[0], alignment))
|
|
|
|
fout.write('{} {}\n'.format(key[0], alignment))
|
|
|
|
|
|
|
|
|
|
|
|
# 3. gen praat
|
|
|
|
# 3. gen praat
|
|
|
|
# segment alignment
|
|
|
|
# segment alignment
|
|
|
|
align_segs = text_grid.segment_alignment(alignment)
|
|
|
|
align_segs = text_grid.segment_alignment(alignment)
|
|
|
|
print(kye[0], align_segs)
|
|
|
|
logger.info("align tokens", key[0], align_segs)
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
# IntervalTier, List["start end token\n"]
|
|
|
|
subsample = get_subsample(self.config)
|
|
|
|
subsample = utility.get_subsample(self.config)
|
|
|
|
tierformat = text_grid.align_to_tierformat(
|
|
|
|
tierformat = text_grid.align_to_tierformat(
|
|
|
|
align_segs, subsample, token_dict)
|
|
|
|
align_segs, subsample, token_dict)
|
|
|
|
# write tier
|
|
|
|
# write tier
|
|
|
|
tier_path = os.path.join(
|
|
|
|
align_output_path = os.path.join(
|
|
|
|
os.path.dirname(args.result_file), key[0] + ".tier")
|
|
|
|
os.path.dirname(self.args.result_file), "align")
|
|
|
|
|
|
|
|
tier_path = os.path.join(align_output_path, key[0] + ".tier")
|
|
|
|
with open(tier_path, 'w') as f:
|
|
|
|
with open(tier_path, 'w') as f:
|
|
|
|
f.writelines(tierformat)
|
|
|
|
f.writelines(tierformat)
|
|
|
|
# write textgrid
|
|
|
|
# write textgrid
|
|
|
|
textgrid_path = s.path.join(
|
|
|
|
textgrid_path = os.path.join(align_output_path,
|
|
|
|
os.path.dirname(args.result_file), key[0] + ".TextGrid")
|
|
|
|
key[0] + ".TextGrid")
|
|
|
|
second_per_frame = 1. / (1000. /
|
|
|
|
second_per_frame = 1. / (1000. /
|
|
|
|
stride_ms) # 25ms window, 10ms stride
|
|
|
|
stride_ms) # 25ms window, 10ms stride
|
|
|
|
second_per_example = (
|
|
|
|
second_per_example = (
|
|
|
|
len(alignment) + 1) * subsample * second_per_frame
|
|
|
|
len(alignment) + 1) * subsample * second_per_frame
|
|
|
|
text_grid.generate_textgrid(
|
|
|
|
text_grid.generate_textgrid(
|
|
|
|
maxtime=second_per_example,
|
|
|
|
maxtime=second_per_example,
|
|
|
|
lines=tierformat,
|
|
|
|
intervals=tierformat,
|
|
|
|
output=textgrid_path)
|
|
|
|
output=textgrid_path)
|
|
|
|
|
|
|
|
|
|
|
|
def run_align(self):
|
|
|
|
def run_align(self):
|
|
|
|