extract ctc_align into ctc_utils

pull/879/head
Hui Zhang 3 years ago
parent 8b45c3e65e
commit dd06472432

@ -16,7 +16,6 @@ import os
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines import jsonlines
@ -386,6 +385,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(msg) logger.info(msg)
self.autolog.report() self.autolog.report()
@paddle.no_grad()
def export(self): def export(self):
if self.args.model_type == 'offline': if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained( infer_model = DeepSpeech2InferModel.from_pretrained(

@ -14,12 +14,10 @@
"""Contains U2 model.""" """Contains U2 model."""
import json import json
import os import os
import sys
import time import time
from collections import defaultdict from collections import defaultdict
from collections import OrderedDict from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines import jsonlines
@ -44,8 +42,6 @@ from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
@ -553,62 +549,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: ctc_utils.ctc_align(
logger.fatal('alignment mode must be running with batch_size == 1') self.model, self.align_loader, self.config.decoding.batch_size,
sys.exit(1) self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -630,6 +574,7 @@ class U2Tester(U2Trainer):
] ]
return infer_model, input_spec return infer_model, input_spec
@paddle.no_grad()
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec) assert isinstance(input_spec, list), type(input_spec)

@ -14,11 +14,9 @@
"""Contains U2 model.""" """Contains U2 model."""
import json import json
import os import os
import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines import jsonlines
@ -39,8 +37,6 @@ from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
@ -527,62 +523,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: ctc_utils.ctc_align(
logger.fatal('alignment mode must be running with batch_size == 1') self.model, self.align_loader, self.config.decoding.batch_size,
sys.exit(1) self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -604,6 +548,7 @@ class U2Tester(U2Trainer):
] ]
return infer_model, input_spec return infer_model, input_spec
@paddle.no_grad()
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec) assert isinstance(input_spec, list), type(input_spec)

@ -14,11 +14,9 @@
"""Contains U2 model.""" """Contains U2 model."""
import json import json
import os import os
import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines import jsonlines
@ -42,8 +40,6 @@ from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
@ -547,62 +543,10 @@ class U2STTester(U2STTrainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: ctc_utils.ctc_align(
logger.fatal('alignment mode must be running with batch_size == 1') self.model, self.align_loader, self.config.decoding.batch_size,
sys.exit(1) self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -624,6 +568,7 @@ class U2STTester(U2STTrainer):
] ]
return infer_model, input_spec return infer_model, input_spec
@paddle.no_grad()
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec) assert isinstance(input_spec, list), type(input_spec)

@ -16,6 +16,8 @@ from typing import List
import numpy as np import numpy as np
import paddle import paddle
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -134,3 +136,85 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
output_alignment.append(y_insert_blank[state_seq[t, 0]]) output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment return output_alignment
# ctc_align(
# self.model,
# self.align_loader,
# self.config.decoding.batch_size,
# self.align_loader.collate_fn.stride_ms,
# self.align_loader.collate_fn.vocab_list,
# self.args.result_file,
# )
def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
result_file):
"""ctc alignment.
Args:
model (nn.Layer): U2 Model.
dataloader (io.DataLoader): dataloader.
batch_size (int): decoding batchsize.
stride_ms (int): audio feature stride in ms unit.
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
result_file (str): alignment output file, e.g. xxx.align.
"""
if batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
assert result_file and result_file.endswith('.align')
model.eval()
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
with open(result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(dataloader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.shape[1]
ctc_probs = model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = forced_align(ctc_probs, target)
logger.info(f"align ids: {key[0]} {alignment}")
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = align_output_path / (key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=str(textgrid_path))

Loading…
Cancel
Save