From 90788b116d85c26cf91bcb76544aaf5b2b189734 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 24 Jun 2021 04:05:34 +0000 Subject: [PATCH] more comment; fix datapipe of align --- deepspeech/exps/u2/model.py | 23 ++++++++++++++--------- deepspeech/utils/ctc_utils.py | 20 +++++++++++--------- deepspeech/utils/text_grid.py | 8 +++++--- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index f00d5af6..ba7bc45c 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -355,7 +355,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. + # 0: used for training, it's prohibited here. num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. simulate_streaming=False, # simulate streaming inference. Defaults to False. )) @@ -512,11 +512,13 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Align Total Examples: {len(self.test_loader.dataset)}") - stride_ms = self.test_loader.dataset.stride_ms - token_dict = self.test_loader.dataset.vocab_list + stride_ms = self.test_loader.collate_fn.stride_ms + token_dict = self.test_loader.collate_fn.vocab_list with open(self.args.result_file, 'w') as fout: + # one example in batch for i, batch in enumerate(self.test_loader): key, feat, feats_length, target, target_length = batch + # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) @@ -529,28 +531,31 @@ class U2Tester(U2Trainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - print(alignment) + print(kye[0], alignment) fout.write('{} {}\n'.format(key[0], alignment)) # 3. gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - print(align_segs) + print(kye[0], align_segs) # IntervalTier, List["start end token\n"] subsample = get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) + # write tier tier_path = os.path.join( os.path.dirname(args.result_file), key[0] + ".tier") with open(tier_path, 'w') as f: f.writelines(tierformat) - + # write textgrid textgrid_path = s.path.join( os.path.dirname(args.result_file), key[0] + ".TextGrid") - second_per_frame = 1. / (1000. / stride_ms - ) # 25ms window, 10ms stride + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame text_grid.generate_textgrid( - maxtime=(len(alignment) + 1) * subsample * second_per_frame, + maxtime=second_per_example, lines=tierformat, output=textgrid_path) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 76c1898b..6201233d 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -38,8 +38,10 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): + # add non-blank into new_hyp if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) + # skip repeat label prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 @@ -52,7 +54,7 @@ def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: "abcdefg" -> "-a-b-c-d-e-f-g-" Args: - label ([np.ndarray]): label ids, (L). + label ([np.ndarray]): label ids, List[int], (L). blank_id (int, optional): blank id. Defaults to 0. Returns: @@ -61,8 +63,8 @@ def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: label = np.expand_dims(label, 1) #[L, 1] blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) #[L, 2] - label = label.reshape(-1) #[2L] - label = np.append(label, label[0]) #[2L + 1] + label = label.reshape(-1) #[2L], -l-l-l + label = np.append(label, label[0]) #[2L + 1], -l-l-l- return label @@ -79,21 +81,21 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, Returns: List[int]: best alignment result, (T). """ - y_insert_blank = insert_blank(y, blank_id) + y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero state_path = (paddle.zeros( (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 - ) # state path + ) # state path, Tuple((T, 2L+1)) # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): + for t in range(1, ctc_probs.size(0)): # T + for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: candidates = paddle.to_tensor( diff --git a/deepspeech/utils/text_grid.py b/deepspeech/utils/text_grid.py index 9afed89e..b774130d 100644 --- a/deepspeech/utils/text_grid.py +++ b/deepspeech/utils/text_grid.py @@ -22,11 +22,13 @@ def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]: """segment ctc alignment ids by continuous blank and repeat label. Args: - alignment (List[int]): ctc alignment id sequence. e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3] + alignment (List[int]): ctc alignment id sequence. + e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3] blank_id (int, optional): blank id. Defaults to 0. Returns: - List[List[int]]: segment aligment id sequence. e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]] + List[List[int]]: token align, segment aligment id sequence. + e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]] """ # convert alignment to a praat format, which is a doing phonetics # by computer and helps analyzing alignment @@ -61,7 +63,7 @@ def align_to_tierformat(align_segs: List[List[int]], token_dict (Dict[int, Text]): int -> str map. Returns: - List[Text]: list of textgrid.Interval. + List[Text]: list of textgrid.Interval text, str(start, end, text). """ hop_length = 10 # ms second_ms = 1000 # ms