fix bug: bug of space id in score.cpp, add detokenize

pull/856/head
huangyuxin 3 years ago
parent db831ae702
commit 08281eca72

@ -26,6 +26,7 @@
#include "decoder_utils.h" #include "decoder_utils.h"
using namespace lm::ngram; using namespace lm::ngram;
const std::string kSPACE = "<space>";
Scorer::Scorer(double alpha, Scorer::Scorer(double alpha,
double beta, double beta,
@ -165,7 +166,7 @@ void Scorer::set_char_map(const std::vector<std::string>& char_list) {
// Set the char map for the FST for spelling correction // Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) { for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") { if (char_list_[i] == kSPACE) {
SPACE_ID_ = i; SPACE_ID_ = i;
} }
// The initial state of FST is state 0, hence the index of chars in // The initial state of FST is state 0, hence the index of chars in

@ -27,6 +27,7 @@ from paddle import inference
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradBatchSampler
@ -271,6 +272,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
def ordid2token(self, texts, texts_len): def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """ """ ord() id to chr() chr """
@ -299,6 +302,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
result_transcripts = self.compute_result_transcripts(audio, audio_len, result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg) vocab_list, cfg)
for utt, target, result in zip(utts, target_transcripts, for utt, target, result in zip(utts, target_transcripts,
result_transcripts): result_transcripts):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
@ -335,6 +339,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_prob=cfg.cutoff_prob, cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n, cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch) num_processes=cfg.num_proc_bsearch)
#replace the <space> with ' '
result_transcripts = [
self._text_featurizer.detokenize(sentence)
for sentence in result_transcripts
]
self.autolog.times.stamp() self.autolog.times.stamp()
self.autolog.times.stamp() self.autolog.times.stamp()
self.autolog.times.end() self.autolog.times.end()
@ -455,6 +465,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_probs, output_lens, vocab_list, cfg.decoding_method, output_probs, output_lens, vocab_list, cfg.decoding_method,
cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
#replace the <space> with ' '
result_transcripts = [
self._text_featurizer.detokenize(sentence)
for sentence in result_transcripts
]
return result_transcripts return result_transcripts

Loading…
Cancel
Save