diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 05a37b21b..dd62f537e 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR from deepspeech.training.trainer import Trainer +from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools +from deepspeech.utils import text_grid +from deepspeech.utils import utility from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -278,7 +281,15 @@ class U2Trainer(Trainer): shuffle=False, drop_last=False, collate_fn=SpeechCollator.from_config(config)) - logger.info("Setup train/valid/test Dataloader!") + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self): config = self.config @@ -353,7 +364,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. + # 0: used for training, it's prohibited here. num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. simulate_streaming=False, # simulate streaming inference. Defaults to False. )) @@ -498,6 +509,73 @@ class U2Tester(U2Trainer): except KeyboardInterrupt: sys.exit(-1) + @paddle.no_grad() + def align(self): + if self.config.decoding.batch_size > 1: + logger.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + # xxx.align + assert self.args.result_file and self.args.result_file.endswith( + '.align') + + self.model.eval() + logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") + + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list + with open(self.args.result_file, 'w') as fout: + # one example in batch + for i, batch in enumerate(self.align_loader): + key, feat, feats_length, target, target_length = batch + + # 1. Encoder + encoder_out, encoder_mask = self.model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + + # 2. alignment + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = ctc_utils.forced_align(ctc_probs, target) + logger.info("align ids", key[0], alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + # 3. gen praat + # segment alignment + align_segs = text_grid.segment_alignment(alignment) + logger.info("align tokens", key[0], align_segs) + # IntervalTier, List["start end token\n"] + subsample = utility.get_subsample(self.config) + tierformat = text_grid.align_to_tierformat( + align_segs, subsample, token_dict) + # write tier + align_output_path = os.path.join( + os.path.dirname(self.args.result_file), "align") + tier_path = os.path.join(align_output_path, key[0] + ".tier") + with open(tier_path, 'w') as f: + f.writelines(tierformat) + # write textgrid + textgrid_path = os.path.join(align_output_path, + key[0] + ".TextGrid") + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame + text_grid.generate_textgrid( + maxtime=second_per_example, + intervals=tierformat, + output=textgrid_path) + + def run_align(self): + self.resume_or_scratch() + try: + self.align() + except KeyboardInterrupt: + sys.exit(-1) + def load_inferspec(self): """infer model and input spec. diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 73669fea6..09543d48d 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): + # add non-blank into new_hyp if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) + # skip repeat label prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 return new_hyp -def insert_blank(label: np.ndarray, blank_id: int=0): +def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: """Insert blank token between every two label token. "abcdefg" -> "-a-b-c-d-e-f-g-" Args: - label ([np.ndarray]): label ids, (L). + label ([np.ndarray]): label ids, List[int], (L). blank_id (int, optional): blank id. Defaults to 0. Returns: @@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0): label = np.expand_dims(label, 1) #[L, 1] blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) #[L, 2] - label = label.reshape(-1) #[2L] - label = np.append(label, label[0]) #[2L + 1] + label = label.reshape(-1) #[2L], -l-l-l + label = np.append(label, label[0]) #[2L + 1], -l-l-l- return label def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, - blank_id=0) -> list: + blank_id=0) -> List[int]: """ctc forced alignment. https://distill.pub/2017/ctc/ @@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y (paddle.Tensor): label id sequence tensor, 1d tensor (L) blank_id (int): blank symbol index Returns: - paddle.Tensor: best alignment result, (T). + List[int]: best alignment result, (T). """ - y_insert_blank = insert_blank(y, blank_id) + y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero + # TODO(Hui Zhang): zeros not support paddle.int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 - ) # state path + (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + ) # state path, Tuple((T, 2L+1)) # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): + for t in range(1, ctc_probs.size(0)): # T + for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: candidates = paddle.to_tensor( @@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ - y_insert_blank[s]] + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( + y_insert_blank[s])] state_path[t, s] = prev_state[paddle.argmax(candidates)] - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) + # TODO(Hui Zhang): zeros not support paddle.int16 + state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb diff --git a/deepspeech/utils/text_grid.py b/deepspeech/utils/text_grid.py new file mode 100644 index 000000000..3af58c9ba --- /dev/null +++ b/deepspeech/utils/text_grid.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict +from typing import List +from typing import Text + +import textgrid + + +def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]: + """segment ctc alignment ids by continuous blank and repeat label. + + Args: + alignment (List[int]): ctc alignment id sequence. + e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3] + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + List[List[int]]: token align, segment aligment id sequence. + e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]] + """ + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + align_segs = [] + # get frames level duration for each token + start = 0 + end = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == blank_id: # blank + end += 1 + if end == len(alignment): + align_segs[-1].extend(alignment[start:]) + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[ + end]: # repeat label + end += 1 + align_segs.append(alignment[start:end]) + start = end + return align_segs + + +def align_to_tierformat(align_segs: List[List[int]], + subsample: int, + token_dict: Dict[int, Text], + blank_id=0) -> List[Text]: + """Generate textgrid.Interval format from alignment segmentations. + + Args: + align_segs (List[List[int]]): segmented ctc alignment ids. + subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample + token_dict (Dict[int, Text]): int -> str map. + + Returns: + List[Text]: list of textgrid.Interval text, str(start, end, text). + """ + hop_length = 10 # ms + second_ms = 1000 # ms + frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length + second_per_frame = 1.0 / frame_per_second + + begin = 0 + duration = 0 + tierformat = [] + + for idx, tokens in enumerate(align_segs): + token_len = len(tokens) + token = tokens[-1] + # time duration in second + duration = token_len * subsample * second_per_frame + if idx < len(align_segs) - 1: + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + else: + for i in tokens: + if i != blank_id: + token = i + break + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + begin = begin + duration + + return tierformat + + +def generate_textgrid(maxtime: float, + intervals: List[Text], + output: Text, + name: Text='ali') -> None: + """Create alignment textgrid file. + + Args: + maxtime (float): audio duartion. + intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item. + output (Text): textgrid filepath. + name (Text, optional): tier or layer name. Defaults to 'ali'. + """ + # Download Praat: https://www.fon.hum.uva.nl/praat/ + avg_interval = maxtime / (len(intervals) + 1) + print(f"average second/token: {avg_interval}") + margin = 0.0001 + + tg = textgrid.TextGrid(maxTime=maxtime) + tier = textgrid.IntervalTier(name=name, maxTime=maxtime) + + i = 0 + for dur in intervals: + s, e, text = dur.split() + tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text) + + tg.append(tier) + + tg.write(output) + print("successfully generator textgrid {}.".format(output)) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 64570026b..a0639e065 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float: a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp + + +def get_subsample(config): + """Subsample rate from config. + + Args: + config (yacs.config.CfgNode): yaml config + + Returns: + int: subsample rate. + """ + input_layer = config["model"]["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if input_layer == "conv2d": + return 4 + elif input_layer == "conv2d6": + return 6 + elif input_layer == "conv2d8": + return 8 diff --git a/examples/aishell/s1/local/align.sh b/examples/aishell/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/aishell/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 4cf09553b..562cfa04d 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -30,10 +30,15 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/librispeech/s1/local/align.sh b/examples/librispeech/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/librispeech/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index 65194d902..b81e8dcfd 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/tiny/s1/local/align.sh b/examples/tiny/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/tiny/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index b148869b7..41f845b05 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi + diff --git a/tools/Makefile b/tools/Makefile index dd5902373..94e5ea2f7 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -19,7 +19,7 @@ kenlm.done: apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50 test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install - cd kenlm && python setup.py install + source venv/bin/activate; cd kenlm && python setup.py install touch kenlm.done sox.done: @@ -32,4 +32,4 @@ sox.done: soxbindings.done: test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git source venv/bin/activate; cd soxbindings && python setup.py install - touch soxbindings.done \ No newline at end of file + touch soxbindings.done