From 3637fc0102c6b4ceaf8e0bb5e31af5d04dbd4d2a Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Thu, 7 Apr 2022 20:44:29 +0800 Subject: [PATCH] add aishell test script --- speechx/examples/aishell/local/compute-wer.py | 500 ++++++++++++++++++ speechx/examples/aishell/local/split_data.sh | 22 + speechx/examples/aishell/path.sh | 14 + speechx/examples/aishell/run.sh | 75 +++ speechx/examples/aishell/utils | 1 + .../offline_decoder_sliding_chunk_main.cc | 7 +- .../examples/feat/linear_spectrogram_main.cc | 3 +- 7 files changed, 619 insertions(+), 3 deletions(-) create mode 100755 speechx/examples/aishell/local/compute-wer.py create mode 100755 speechx/examples/aishell/local/split_data.sh create mode 100644 speechx/examples/aishell/path.sh create mode 100755 speechx/examples/aishell/run.sh create mode 120000 speechx/examples/aishell/utils diff --git a/speechx/examples/aishell/local/compute-wer.py b/speechx/examples/aishell/local/compute-wer.py new file mode 100755 index 000000000..a3eefc0dc --- /dev/null +++ b/speechx/examples/aishell/local/compute-wer.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +import re, sys, unicodedata +import codecs + +remove_tag = True +spacelist= [' ', '\t', '\r', '\n'] +puncts = ['!', ',', '?', + '、', '。', '!', ',', ';', '?', + ':', '「', '」', '︰', '『', '』', '《', '》'] + +def characterize(string) : + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': sep = '>' + j = i+1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c==sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + +def stripoff_tags(x): + if not x: return '' + chars = [] + i = 0; T=len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + +class Calculator : + def __init__(self) : + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + def calculate(self, lab, rec) : + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab) : + self.space.append([]) + for row in self.space : + for element in row : + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec) : + row.append({'dist' : 0, 'error' : 'non'}) + for i in range(len(lab)) : + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)) : + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + for token in rec : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + # Computing edit distance + for i, lab_token in enumerate(lab) : + for j, rec_token in enumerate(rec) : + if i == 0 or j == 0 : + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i-1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist : + min_dist = dist + min_error = error + dist = self.space[i][j-1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist : + min_dist = dist + min_error = error + if lab_token == rec_token : + dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] + error = 'cor' + else : + dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist : + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + i = len(lab) - 1 + j = len(rec) - 1 + while True : + if self.space[i][j]['error'] == 'cor' : # correct + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub' : # substitution + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del' : # deletion + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins' : # insertion + if len(rec[j]) > 0 : + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non' : # starting point + break + else : # shouldn't reach here + print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) + return result + def overall(self) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def cluster(self, data) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in data : + if token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def keys(self) : + return list(self.data.keys()) + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + +def default_cluster(word) : + unicode_names = [ unicodedata.name(char) for char in word ] + for i in reversed(range(len(unicode_names))) : + if unicode_names[i].startswith('DIGIT') : # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')) : + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')) : + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else : + return 'Other' + if len(unicode_names) == 0 : + return 'Other' + if len(unicode_names) == 1 : + return unicode_names[0] + for i in range(len(unicode_names)-1) : + if unicode_names[i] != unicode_names[i+1] : + return 'Other' + return unicode_names[0] + +def usage() : + print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") + print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + +if __name__ == '__main__': + if len(sys.argv) == 1 : + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose= 1 + padding_symbol= ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose=0 + try: + verbose=int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol= ' ' + elif b == 'underline': + padding_symbol= '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig=set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array)==0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8') : + if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array)==0: continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab : + if word not in default_words : + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters : + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name] : + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('WER: %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])) : + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length-len_lab) + space['rec'].append(length-len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('lab:', end = ' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['lab'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('rec:', end = ' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['rec'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print('===========================================================================') + print() + + result = calculator.overall() + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('Overall -> %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters : + result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if len(cluster_file) > 0 : # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8') : + for token in line.decode('utf-8').rstrip('\n').split() : + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else : + cluster.append(token) + print() + print('===========================================================================') diff --git a/speechx/examples/aishell/local/split_data.sh b/speechx/examples/aishell/local/split_data.sh new file mode 100755 index 000000000..b79c64d6b --- /dev/null +++ b/speechx/examples/aishell/local/split_data.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +data=$1 +feat_scp=$2 +numsplit=$3 + +if ! [ "$numsplit" -gt 0 ]; then + echo "Invalid num-split argument"; + exit 1; +fi + +directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done) +feat_split_scp=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/feats.scp; done) +echo $feat_split_scp +# if this mkdir fails due to argument-list being too long, iterate. +if ! mkdir -p $directories >&/dev/null; then + for n in `seq $numsplit`; do + mkdir -p $data/split${numsplit}/$n + done +fi + +utils/split_scp.pl $feat_scp $feat_split_scp diff --git a/speechx/examples/aishell/path.sh b/speechx/examples/aishell/path.sh new file mode 100644 index 000000000..a0e7c9aed --- /dev/null +++ b/speechx/examples/aishell/path.sh @@ -0,0 +1,14 @@ +# This contains the locations of binarys build required for running the examples. + +SPEECHX_ROOT=$PWD/../.. +SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples + +SPEECHX_TOOLS=$SPEECHX_ROOT/tools +TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin + +[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } + +export LC_AL=C + +SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/aishell/run.sh b/speechx/examples/aishell/run.sh new file mode 100755 index 000000000..2ff25ae72 --- /dev/null +++ b/speechx/examples/aishell/run.sh @@ -0,0 +1,75 @@ +#!/bin/bash +set +x +set -e + +. path.sh + +# 1. compile +if [ ! -d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + + +# 2. download model +if [ ! -d ../paddle_asr_model ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz + tar xzfv paddle_asr_model.tar.gz + mv ./paddle_asr_model ../ + # produce wav scp + echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp +fi + +mkdir -p data +if [ ! -d ./test ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + realpath ./test/*/*.wav > wavlist + awk -F '/' '{ print $(NF) }' wavlist | awk -F '.' '{ print $1 }' > utt_id + paste utt_id wavlist > aishell_test.scp +fi + +if [ ! -d aishell_ds2_online_model ]; then + mkdir -p aishell_ds2_online_model + wget -P ./aishell_ds2_online_model -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz + tar xzfv ./aishell_ds2_online_model/aishell_ds2_online_cer8.00_release.tar.gz -C ./aishell_ds2_online_model +fi + +# 3. make feature +aishell_wav_scp=./aishell_test.scp +aishell_online_model=./aishell_ds2_online_model/exp/deepspeech2_online/checkpoints +model_dir=../paddle_asr_model +feat_ark=./feats.ark +feat_scp=./aishell_feat.scp +cmvn=./cmvn.ark +label_file=./aishell_result +wer=./aishell_wer + +export GLOG_logtostderr=1 + +# 3. gen linear feat +linear_spectrogram_main \ + --wav_rspecifier=scp:$aishell_wav_scp \ + --feature_wspecifier=ark,scp:$feat_ark,$feat_scp \ + --cmvn_write_path=$cmvn \ + --streaming_chunk=10 + +nj=10 +data=./data +text=./test/text +# recognizer +./local/split_data.sh data aishell_feat.scp $nj + +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ + offline_decoder_sliding_chunk_main \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feats.scp \ + --model_path=$aishell_online_model/avg_1.jit.pdmodel \ + --param_path=$aishell_online_model/avg_1.jit.pdiparams \ + --dict_file=$model_dir/vocab.txt \ + --lm_path=$model_dir/avg_1.jit.klm \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result + +cat $data/split${nj}/*/result > $label_file + +local/compute-wer.py --char=1 --v=1 $label_file $text > $wer diff --git a/speechx/examples/aishell/utils b/speechx/examples/aishell/utils new file mode 120000 index 000000000..973afe674 --- /dev/null +++ b/speechx/examples/aishell/utils @@ -0,0 +1 @@ +../../../utils \ No newline at end of file diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc index 7f6c572ca..1d7c722fc 100644 --- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -22,7 +22,8 @@ #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" -DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); @@ -45,7 +46,8 @@ int main(int argc, char* argv[]) { google::InitGoogleLogging(argv[0]); kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_respecifier); + FLAGS_feature_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); std::string model_graph = FLAGS_model_path; std::string model_params = FLAGS_param_path; std::string dict_file = FLAGS_dict_file; @@ -130,6 +132,7 @@ int main(int argc, char* argv[]) { std::string result; result = decoder.GetFinalBestPath(); KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); decodable->Reset(); decoder.Reset(); ++num_done; diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc index 2d75bb5df..7061d2b2d 100644 --- a/speechx/examples/feat/linear_spectrogram_main.cc +++ b/speechx/examples/feat/linear_spectrogram_main.cc @@ -30,6 +30,7 @@ DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); std::vector mean_{ @@ -198,7 +199,7 @@ int main(int argc, char* argv[]) { LOG(INFO) << "feat dim: " << feature_cache.Dim(); int sample_rate = 16000; - float streaming_chunk = 0.36; + float streaming_chunk = FLAGS_streaming_chunk; int chunk_sample_size = streaming_chunk * sample_rate; LOG(INFO) << "sr: " << sample_rate; LOG(INFO) << "chunk size (s): " << streaming_chunk;