add align code

pull/629/head
Hui Zhang 3 years ago
parent 9cc750bf29
commit 30aba26693

@ -34,9 +34,11 @@ from deepspeech.models.u2 import U2Model
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
@ -483,6 +485,67 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file
self.model.eval()
logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.dataset.stride_ms
token_dict = self.test_loader.dataset.vocab_list
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
# print(ctc_probs.size(1))
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
print(alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
print(align_segs)
# IntervalTier, List["start end token\n"]
subsample = get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
tier_path = os.path.join(
os.path.dirname(args.result_file), key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
textgrid_path = s.path.join(
os.path.dirname(args.result_file), key[0] + ".TextGrid")
second_per_frame = 1. / (1000. / stride_ms
) # 25ms window, 10ms stride
text_grid.generate_textgrid(
maxtime=(len(alignment) + 1) * subsample * second_per_frame,
lines=tierformat,
output=textgrid_path)
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.

@ -46,7 +46,7 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
return new_hyp
def insert_blank(label: np.ndarray, blank_id: int=0):
def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"""Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-"
@ -67,7 +67,7 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list:
blank_id=0) -> List[int]:
"""ctc forced alignment.
https://distill.pub/2017/ctc/
@ -77,7 +77,7 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
List[int]: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)

@ -0,0 +1,125 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid
def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence. e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: segment aligment id sequence. e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert alignment to a praat format, which is a doing phonetics
# by computer and helps analyzing alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
return align_segs
def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval.
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
return tierformat
def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duartion.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average duration per {name}: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))
Loading…
Cancel
Save