Merge pull request #629 from PaddlePaddle/align

ctc alignment
pull/695/head
Hui Zhang 4 years ago committed by GitHub
commit 43b52082c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -278,7 +281,15 @@ class U2Trainer(Trainer):
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator.from_config(config)) collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test Dataloader!") # return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config
@ -498,6 +509,73 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt: except KeyboardInterrupt:
sys.exit(-1) sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
# xxx.align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.

@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
new_hyp: List[int] = [] new_hyp: List[int] = []
cur = 0 cur = 0
while cur < len(hyp): while cur < len(hyp):
# add non-blank into new_hyp
if hyp[cur] != blank_id: if hyp[cur] != blank_id:
new_hyp.append(hyp[cur]) new_hyp.append(hyp[cur])
# skip repeat label
prev = cur prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]: while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1 cur += 1
return new_hyp return new_hyp
def insert_blank(label: np.ndarray, blank_id: int=0): def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"""Insert blank token between every two label token. """Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-" "abcdefg" -> "-a-b-c-d-e-f-g-"
Args: Args:
label ([np.ndarray]): label ids, (L). label ([np.ndarray]): label ids, List[int], (L).
blank_id (int, optional): blank id. Defaults to 0. blank_id (int, optional): blank id. Defaults to 0.
Returns: Returns:
@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
label = np.expand_dims(label, 1) #[L, 1] label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2] label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L] label = label.reshape(-1) #[2L], -l-l-l
label = np.append(label, label[0]) #[2L + 1] label = np.append(label, label[0]) #[2L + 1], -l-l-l-
return label return label
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list: blank_id=0) -> List[int]:
"""ctc forced alignment. """ctc forced alignment.
https://distill.pub/2017/ctc/ https://distill.pub/2017/ctc/
@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y (paddle.Tensor): label id sequence tensor, 1d tensor (L) y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index blank_id (int): blank symbol index
Returns: Returns:
paddle.Tensor: best alignment result, (T). List[int]: best alignment result, (T).
""" """
y_insert_blank = insert_blank(y, blank_id) y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros( log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
state_path = (paddle.zeros( state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path ) # state path, Tuple((T, 2L+1))
# init start state # init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)): for t in range(1, ctc_probs.size(0)): # T
for s in range(len(y_insert_blank)): for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]: s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor( candidates = paddle.to_tensor(
@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2], log_alpha[t - 1, s - 2],
]) ])
prev_state = [s, s - 1, s - 2] prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
y_insert_blank[s]] log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)] state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) # TODO(Hui Zhang): zeros not support paddle.int16
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
candidates = paddle.to_tensor([ candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb log_alpha[-1, len(y_insert_blank) - 1], # Sb

@ -0,0 +1,127 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid
def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: token align, segment aligment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert alignment to a praat format, which is a doing phonetics
# by computer and helps analyzing alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
return align_segs
def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
return tierformat
def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duartion.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))

@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
a_max = max(args) a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args)) lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp return a_max + lsp
def get_subsample(config):
"""Subsample rate from config.
Args:
config (yacs.config.CfgNode): yaml config
Returns:
int: subsample rate.
"""
input_layer = config["model"]["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefxi})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -30,10 +30,15 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n # test ckpt avg_n
CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n # export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi fi

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefxi})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n # export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi fi

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ngpu == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefxi})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n # export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi fi

@ -19,7 +19,7 @@ kenlm.done:
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50 apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
cd kenlm && python setup.py install source venv/bin/activate; cd kenlm && python setup.py install
touch kenlm.done touch kenlm.done
sox.done: sox.done:

Loading…
Cancel
Save