From f082fcbbec9861198e921df4448b5fa674e02556 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Thu, 5 May 2022 23:22:24 +0800 Subject: [PATCH] update the time stamp type, test=doc --- .../server/engine/asr/online/asr_engine.py | 18 ++++++++++--- .../server/engine/asr/online/ctc_search.py | 25 +++++++++---------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 58cd3488..427e7e36 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -296,6 +296,8 @@ class PaddleASRConnectionHanddler: self.chunk_num = 0 self.global_frame_offset = 0 self.result_transcripts = [''] + self.word_time_stamp = [] + self.time_stamp = [] self.first_char_occur_elapsed = None self.word_time_stamp = None @@ -514,10 +516,7 @@ class PaddleASRConnectionHanddler: return '' def get_word_time_stamp(self): - if self.word_time_stamp is None: - return [] - else: - return self.word_time_stamp + return self.word_time_stamp @paddle.no_grad() def rescoring(self): @@ -581,7 +580,18 @@ class PaddleASRConnectionHanddler: best_index = i # update the one best result + # hyps stores the beam results; the fields of each entry are: + logger.info(f"best index: {best_index}") + # logger.info(f'best result: {hyps[best_index]}') + # the fields of each hyp are: + # hyps[0][0]: the sentence word-id in the vocab with a tuple + # hyps[0][1]: the sentence decoding probability with all paths + # hyps[0][2]: viterbi_blank ending probability + # hyps[0][3]: viterbi_non_blank probability + # hyps[0][4]: current_token_prob, + # hyps[0][5]: times_viterbi_blank, + # hyps[0][6]: times_viterbi_non_blank self.hyps = [hyps[best_index][0]] # update the hyps time stamp diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index 3a808587..4c9ac3ac 100644 --- 
a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -27,7 +27,7 @@ class CTCPrefixBeamSearch: """Implement the ctc prefix beam search Args: - config (yacs.config.CfgNode): _description_ + config (yacs.config.CfgNode): the ctc prefix beam search configuration """ self.config = config self.reset() @@ -69,7 +69,6 @@ class CTCPrefixBeamSearch: # 2. CTC beam search step by step for t in range(0, maxlen): logp = ctc_probs[t] # (vocab_size,) - # key: prefix, value (pb, pnb), default value(-inf, -inf) # next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) next_hyps = defaultdict( lambda: (-float('inf'), -float('inf'), -float('inf'), -float('inf'), -float('inf'), [], [])) @@ -80,7 +79,7 @@ class CTCPrefixBeamSearch: for s in top_k_index: s = s.item() ps = logp[s].item() - for prefix, (pb, pnb, v_s, v_ns, cur_token_prob, times_s, + for prefix, (pb, pnb, v_b_s, v_nb_s, cur_token_prob, times_s, times_ns) in self.cur_hyps: last = prefix[-1] if len(prefix) > 0 else None if s == blank_id: # blank @@ -88,9 +87,9 @@ class CTCPrefixBeamSearch: prefix] n_pb = log_add([n_pb, pb + ps, pnb + ps]) - pre_times = times_s if v_s > v_ns else times_ns + pre_times = times_s if v_b_s > v_nb_s else times_ns n_times_s = copy.deepcopy(pre_times) - viterbi_score = v_s if v_s > v_ns else v_ns + viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s n_v_s = viterbi_score + ps next_hyps[prefix] = (n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, @@ -101,8 +100,8 @@ class CTCPrefixBeamSearch: n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ prefix] n_pnb = log_add([n_pnb, pnb + ps]) - if n_v_ns < v_ns + ps: - n_v_ns = v_ns + ps + if n_v_ns < v_nb_s + ps: + n_v_ns = v_nb_s + ps if n_cur_token_prob < ps: n_cur_token_prob = ps n_times_ns = copy.deepcopy(times_ns) @@ -117,8 +116,8 @@ class CTCPrefixBeamSearch: n_prefix = prefix + (s, ) n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, 
n_times_ns = next_hyps[ n_prefix] - if n_v_ns < v_s + ps: - n_v_ns = v_s + ps + if n_v_ns < v_b_s + ps: + n_v_ns = v_b_s + ps n_cur_token_prob = ps n_times_ns = copy.deepcopy(times_s) n_times_ns.append(self.abs_time_step) @@ -129,10 +128,10 @@ class CTCPrefixBeamSearch: else: # Case 3: *a + b => *ab, *aε + b => *ab n_prefix = prefix + (s, ) - n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_n = next_hyps[ + n_pb, n_pnb, n_v_s, n_v_ns, n_cur_token_prob, n_times_s, n_times_ns = next_hyps[ n_prefix] - viterbi_score = v_s if v_s > v_ns else v_ns - pre_times = times_s if v_s > v_ns else times_ns + viterbi_score = v_b_s if v_b_s > v_nb_s else v_nb_s + pre_times = times_s if v_b_s > v_nb_s else times_ns if n_v_ns < viterbi_score + ps: n_v_ns = viterbi_score + ps n_cur_token_prob = ps @@ -153,7 +152,7 @@ class CTCPrefixBeamSearch: # 2.3 update the absolute time step self.abs_time_step += 1 - # self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + self.hyps = [(y[0], log_add([y[1][0], y[1][1]]), y[1][2], y[1][3], y[1][4], y[1][5], y[1][6]) for y in self.cur_hyps]