From f4b11b19e5cccc93116eb395446ca2b5140c5c40 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 11 Jul 2022 07:59:42 +0000 Subject: [PATCH] rename time_s and time_ns to time_b and time_nb --- .../server/engine/acs/python/acs_engine.py | 5 +- .../server/engine/asr/online/ctc_search.py | 50 +++++++++---------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/paddlespeech/server/engine/acs/python/acs_engine.py b/paddlespeech/server/engine/acs/python/acs_engine.py index 63964a825..a607aa07a 100644 --- a/paddlespeech/server/engine/acs/python/acs_engine.py +++ b/paddlespeech/server/engine/acs/python/acs_engine.py @@ -192,12 +192,15 @@ class ACSEngine(BaseEngine): # search for each word in self.word_list offset = self.config.offset + # last time in time_stamp max_ed = time_stamp[-1]['ed'] for w in self.word_list: # search the w in asr_result and the index in asr_result + # https://docs.python.org/3/library/re.html#re.finditer for m in re.finditer(w, asr_result): + # match start and end char index in timestamp + # https://docs.python.org/3/library/re.html#re.Match.start start = max(time_stamp[m.start(0)]['bg'] - offset, 0) - end = min(time_stamp[m.end(0) - 1]['ed'] + offset, max_ed) logger.debug(f'start: {start}, end: {end}') acs_result.append({'w': w, 'bg': start, 'ed': end}) diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index 46f310c80..ad9647ef9 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -86,8 +86,8 @@ class CTCPrefixBeamSearch: # 2. viterbi_blank ending, # 3. viterbi_non_blank, # 4. current_token_prob, - # 5. times_viterbi_blank, - # 6. times_titerbi_non_blank + # 5. times_viterbi_blank, times_b + # 6. times_titerbi_non_blank, times_nb if self.cur_hyps is None: self.cur_hyps = [(tuple(), (0.0, -float('inf'), 0.0, 0.0, -float('inf'), [], []))] @@ -106,69 +106,69 @@ class CTCPrefixBeamSearch: for s in top_k_index: s = s.item() ps = logp[s].item() - for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_s, - times_ns) in self.cur_hyps: + for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_b, + times_nb) in self.cur_hyps: last = prefix[-1] if len(prefix) > 0 else None if s == blank_id: # blank - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_b, n_times_nb = next_hyps[ prefix] n_pb = log_add([n_pb, pb + ps, pnb + ps]) - pre_times = times_s if v_b_s > v_nb_s else times_ns - n_times_s = copy.deepcopy(pre_times) + pre_times = times_b if v_b_s > v_nb_s else times_nb + n_times_b = copy.deepcopy(pre_times) viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s n_v_s = viterbi_score + ps next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, - n_cur_token_prob, n_times_s, - n_times_ns) + n_cur_token_prob, n_times_b, + n_times_nb) elif s == last: # Update *ss -> *s; # case1: *a + a => *a - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_b, n_times_nb = next_hyps[ prefix] n_pnb = log_add([n_pnb, pnb + ps]) if n_v_ns < v_nb_s + ps: n_v_ns = v_nb_s + ps if n_cur_token_prob < ps: n_cur_token_prob = ps - n_times_ns = copy.deepcopy(times_ns) - n_times_ns[ + n_times_nb = copy.deepcopy(times_nb) + n_times_nb[ -1] = self.abs_time_step # 注意,这里要重新使用绝对时间 next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, - n_cur_token_prob, n_times_s, - n_times_ns) + n_cur_token_prob, n_times_b, + n_times_nb) # Update *s-s -> *ss, - is for blank # Case 2: *aε + a => *aa n_prefix = prefix + (s, ) - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_b, n_times_nb = next_hyps[ n_prefix] if n_v_ns < v_b_s + ps: n_v_ns = v_b_s + ps n_cur_token_prob = ps - n_times_ns = copy.deepcopy(times_s) - n_times_ns.append(self.abs_time_step) + n_times_nb = copy.deepcopy(times_b) + n_times_nb.append(self.abs_time_step) n_pnb = log_add([n_pnb, pb + ps]) next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, - n_cur_token_prob, n_times_s, - n_times_ns) + n_cur_token_prob, n_times_b, + n_times_nb) else: # Case 3: *a + b => *ab, *aε + b => *ab n_prefix = prefix + (s, ) - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_b, n_times_nb = next_hyps[ n_prefix] viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s - pre_times = times_s if v_b_s > v_nb_s else times_ns + pre_times = times_b if v_b_s > v_nb_s else times_nb if n_v_ns < viterbi_score + ps: n_v_ns = viterbi_score + ps n_cur_token_prob = ps - n_times_ns = copy.deepcopy(pre_times) - n_times_ns.append(self.abs_time_step) + n_times_nb = copy.deepcopy(pre_times) + n_times_nb.append(self.abs_time_step) n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) next_hyps[n_prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, - n_cur_token_prob, n_times_s, - n_times_ns) + n_cur_token_prob, n_times_b, + n_times_nb) # 2.2 Second beam prune next_hyps = sorted(