Merge pull request #1707 from zh794390558/spx

[speechx] refactor speech egs
pull/1709/head
Hui Zhang 3 years ago committed by GitHub
commit 91e255ceaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -20,6 +20,7 @@ from diskcache import Cache
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from milvus_helpers import MilvusHelper
from mysql_helpers import MySQLHelper
from operations.count import do_count
@ -31,8 +32,6 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI()
app.add_middleware(
CORSMiddleware,

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from logs import LOGGER
from paddlespeech.cli import VectorExecutor
vector_executor = VectorExecutor()

@ -20,7 +20,6 @@ from config import MYSQL_HOST
from config import MYSQL_PORT
from config import MYSQL_PWD
from config import MYSQL_USER
from logs import LOGGER

@ -14,7 +14,6 @@
import sys
from config import DEFAULT_TABLE
from logs import LOGGER

@ -14,7 +14,6 @@
import sys
from config import DEFAULT_TABLE
from logs import LOGGER

@ -17,7 +17,6 @@ import sys
from config import DEFAULT_TABLE
from diskcache import Cache
from encode import get_audio_embedding
from logs import LOGGER
@ -27,9 +26,8 @@ def get_audios(path):
"""
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [
item
for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
item for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats
]

@ -17,7 +17,6 @@ import numpy
from config import DEFAULT_TABLE
from config import TOP_K
from encode import get_audio_embedding
from logs import LOGGER

@ -18,6 +18,7 @@ from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
@ -30,8 +31,6 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI()
app.add_middleware(
CORSMiddleware,

@ -84,7 +84,7 @@ setuptools.setup(
install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos'
],
],
extras_require={
'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',

@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import base64
import io
import json
import logging
import os
import random
import time
from typing import List
import logging
import asyncio
import numpy as np
import requests
@ -30,9 +30,9 @@ from ..executor import BaseExecutor
from ..util import cli_client_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']
@ -234,7 +234,8 @@ class ASRClientExecutor(BaseExecutor):
@cli_client_register(
name='paddlespeech_client.asr_online', description='visit asr online service')
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASRClientExecutor(BaseExecutor):
def __init__(self):
super(ASRClientExecutor, self).__init__()

@ -1,12 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse
from flask import Flask, render_template
from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id')
@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True)

@ -15,4 +15,4 @@
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png)

@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio

@ -1,7 +1,4 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(glog)
add_subdirectory(ds2_ol)
add_subdirectory(dev)

@ -1,17 +1,25 @@
# Examples
# Examples for SpeechX
* dev - for speechx developer, using for test.
* ngram - using to build NGram ARPA lm.
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
The entrypoint is `ds2_ol/aishell/run.sh`
* glog - glog usage
* feat - mfcc, linear
* nnet - ds2 nn
* decoder - online decoder to work as offline
## How to run
`run.sh` is the entry point.
Example to play `decoder`:
Example to play `ds2_ol`:
```
pushd decoder
pushd ds2_ol/aishell
bash run.sh
```
## Display Model with [Netron](https://github.com/lutzroeder/netron)
```
pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
```

@ -1,500 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re, sys, unicodedata
import codecs
remove_tag = True
spacelist= [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
'', '', '', '', '', '',
'', '', '', '', '', '', '', '']
def characterize(string) :
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c==sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
else:
new_sentence.append(x)
return new_sentence
class Calculator :
def __init__(self) :
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab) :
self.space.append([])
for row in self.space :
for element in row :
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec) :
row.append({'dist' : 0, 'error' : 'non'})
for i in range(len(lab)) :
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)) :
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
if i == 0 or j == 0 :
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i-1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist :
min_dist = dist
min_error = error
dist = self.space[i][j-1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist :
min_dist = dist
min_error = error
if lab_token == rec_token :
dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
error = 'cor'
else :
dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist :
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True :
if self.space[i][j]['error'] == 'cor' : # correct
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub' : # substitution
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del' : # deletion
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins' : # insertion
if len(rec[j]) > 0 :
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non' : # starting point
break
else : # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
if token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self) :
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
if unicode_names[i].startswith('DIGIT') : # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')) :
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
unicode_names[i].startswith('COMMERCIAL AT') or
unicode_names[i].startswith('DEGREE CELSIUS') or
unicode_names[i].startswith('EQUALS SIGN') or
unicode_names[i].startswith('FULL STOP') or
unicode_names[i].startswith('HYPHEN-MINUS') or
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')) :
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else :
return 'Other'
if len(unicode_names) == 0 :
return 'Other'
if len(unicode_names) == 1 :
return unicode_names[0]
for i in range(len(unicode_names)-1) :
if unicode_names[i] != unicode_names[i+1] :
return 'Other'
return unicode_names[0]
def usage() :
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
if __name__ == '__main__':
if len(sys.argv) == 1 :
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose= 1
padding_symbol= ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose=0
try:
verbose=int(b)
except:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol= ' '
elif b == 'underline':
padding_symbol= '_'
continue
if True or sys.argv[1].startswith('-'):
#ignore invalid switch
del sys.argv[1]
continue
if not case_sensitive:
ig=set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array)==0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8') :
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab :
if word not in default_words :
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters :
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name] :
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('WER: %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])) :
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length-len_lab)
space['rec'].append(length-len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('lab:', end = ' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['lab'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('rec:', end = ' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['rec'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print('===========================================================================')
print()
result = calculator.overall()
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('Overall -> %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters :
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if len(cluster_file) > 0 : # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8') :
for token in line.decode('utf-8').rstrip('\n').split() :
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else :
cluster.append(token)
print()
print('===========================================================================')

@ -1,18 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -1,121 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(chunk_size, 35, "feat chunk size");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding speech feature, deprecated.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 chunk_size = FLAGS_chunk_size;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
LOG(INFO) << "chunk size (frame): " << chunk_size;
int32 num_done = 0, num_err = 0;
// frontend + nnet is decodable
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
// init decoder
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
LOG(INFO) << "Init decoder.";
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
LOG(INFO) << "utt: " << utt;
// feat dim
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "dim: " << raw_data->Dim();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
LOG(INFO) << "n chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
// feat chunk
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(feat_one_row);
row_idx++;
}
// feed to raw cache
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// decode step
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,43 +0,0 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
export GLOG_logtostderr=1
# 3. gen linear feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
# 4. run decoder
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm

@ -0,0 +1,3 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog)

@ -1,14 +1,15 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_ROOT=$PWD/../../../
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
export LC_AL=C

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)

@ -0,0 +1,11 @@
# Deepspeech2 Streaming
Please go to `aishell` to test it.
* aishell
Deepspeech2 Streaming Decoding under aishell dataset.
The below is for developing and offline testing:
* nnet
* feat
* decoder

@ -0,0 +1,3 @@
data
exp
aishell_*

@ -0,0 +1,21 @@
# Aishell - Deepspeech2 Streaming
## CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC Prefix Beam Search w LM
```
```
## CTC WFST
```
```

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -4,6 +4,9 @@ set -e
. path.sh
nj=40
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip -d $data aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
model_dir=$PWD/aishell_ds2_online_model
if [ ! -d $model_dir ]; then
mkdir -p $model_dir
wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
if [ ! -f $lm ]; then
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
# 3. make feature
aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
lm_model_dir=../paddle_asr_model
label_file=./aishell_result
wer=./aishell_wer
nj=40
export GLOG_logtostderr=1
#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data
# 3. gen linear feat
cmvn=$PWD/cmvn.ark
cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
linear_spectrogram_without_db_norm_main \
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn \
@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \
text=$data/test/text
# 4. recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
offline_decoder_sliding_chunk_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
utils/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
offline_decoder_sliding_chunk_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph
if [ ! -d $ ]; then
@ -97,17 +109,19 @@ if [ ! -d $ ]; then
unzip -d aishell_graph.zip
fi
# 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
offline_wfst_decoder_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
utils/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg

@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -0,0 +1,12 @@
# ASR Decoder
ASR Decoder test bins. We using theses bins to test CTC BeamSearch decoder and WFST decoder.
* decoder_test_main.cc
feed nnet output logprob, and only test decoder
* offline_decoder_sliding_chunk_main.cc
feed streaming audio feature, decode as streaming manner.
* offline_wfst_decoder_main.cc
feed streaming audio feature, decode using WFST as streaming manner.

@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -50,9 +52,13 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
@ -73,6 +79,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,79 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
# output
exp_dir=./exp
mkdir -p $exp_dir
# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# download lm
if [ ! -f $lm ]; then
pushd data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark
export GLOG_logtostderr=1
# dump json cmvn to kaldi
cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
--result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm

@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,

@ -0,0 +1,12 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)

@ -0,0 +1,7 @@
# Deepspeech2 Straming Audio Feature
ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner.
* linear_spectrogram_without_db_norm_main.cc
compute linear spectrogram w/o db norm in streaming manner.

@ -0,0 +1,81 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/simdjson.h"
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace simdjson;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
try {
padded_string json = padded_string::load(FLAGS_json_file);
ondemand::parser parser;
ondemand::document doc = parser.iterate(json);
ondemand::value val = doc;
ondemand::array mean_stat = val["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) {
mean_stat_vec.push_back(x);
}
// LOG(INFO) << mean_stat; this line will casue
// simdjson::simdjson_error("Objects and arrays can only be iterated
// when
// they are first encountered")
ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
}
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;
kaldi::WriteKaldiObject(
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
}
return 0;
}

@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,57 @@
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# input
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/
mkdir -p $exp_dir
# 3. run feat
export GLOG_logtostderr=1
cmvn-json2kaldi \
--json_file $model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."

@ -0,0 +1,6 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})

@ -0,0 +1,3 @@
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.

@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// deepspeech2 online model info
#include <algorithm>
#include <fstream>
#include <functional>
@ -20,21 +21,26 @@
#include <iterator>
#include <numeric>
#include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
@ -57,6 +63,8 @@ void model_forward_test() {
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
CHECK(model_graph != "");
CHECK(model_params != "");
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
@ -106,7 +114,7 @@ void model_forward_test() {
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
@ -119,7 +127,7 @@ void model_forward_test() {
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
@ -187,7 +195,9 @@ void model_forward_test() {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
model_forward_test();
return 0;
}

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/glog
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,38 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
ds2-model-ol-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams

@ -1,18 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)

@ -1,58 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/simdjson.h"
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace simdjson;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ondemand::parser parser;
padded_string json = padded_string::load(FLAGS_json_file);
ondemand::document val = parser.iterate(json);
ondemand::object doc = val;
kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
auto mean_stat = doc["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) {
mean_stat_vec.push_back(x);
}
auto var_stat = doc["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
}
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
return 0;
}

@ -1,719 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++) UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -1,270 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, false);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
WriteMatrix();
int32 num_done = 0, num_err = 0;
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
opt.streaming_chunk = FLAGS_streaming_chunk;
opt.frame_opts.dither = 0.0;
opt.frame_opts.remove_dc_offset = false;
opt.frame_opts.window_type = "hanning";
opt.frame_opts.preemph_coeff = 0.0;
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
feature_cache.Read(&features);
if (features.Dim() == 0) break;
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.Dim() / feature_cache.Dim();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feature_cache.Dim());
for (auto feat : feats) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,32 +0,0 @@
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
export GLOG_logtostderr=1
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn

@ -0,0 +1,2 @@
data
exp

@ -1,5 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})

@ -1,29 +0,0 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
# 4. run decoder
pp-model-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams

@ -92,8 +92,7 @@ void CTCBeamSearch::AdvanceDecode(
while (1) {
vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob;
bool flag =
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);

@ -46,10 +46,10 @@ class LinearSpectrogram : public FrontendInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
virtual void Reset() {
base_extractor_->Reset();
reminded_wav_.Resize(0);
}
}
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,

@ -49,19 +49,19 @@ bool Decodable::IsLastFrame(int32 frame) {
int32 Decodable::NumIndices() const { return 0; }
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob id.
int32 Decodable::TokenId2NnetId(int32 token_id) {
return token_id - 1;
}
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob
// id.
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
// the index - 1, because the ilabel
return acoustic_scale_ *
std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
}
bool Decodable::EnsureFrameHaveComputed(int32 frame) {

@ -37,8 +37,7 @@ std::string ReadFile2String(const std::string& path) {
if (!input_file.is_open()) {
std::cerr << "please input a valid file" << std::endl;
}
return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>());
return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>());
}
}

@ -20,5 +20,4 @@ bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
std::string ReadFile2String(const std::string& path);
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -35,66 +35,68 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n"
"They are added on each final state and each state with non-epsilon "
"output symbols\n"
"on at least one arc out of the state. Useful in conjunction with "
"predeterminize\n"
"\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst "
"[out.fst] ]\n"
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
"same length.\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n"
"They are added on each final state and each state with "
"non-epsilon "
"output symbols\n"
"on at least one arc out of the state. Useful in conjunction with "
"predeterminize\n"
"\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list "
"[in.fst "
"[out.fst] ]\n"
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
"same length.\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR << "fstaddselfloops: mismatch in size of disambiguation "
"symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}

@ -56,59 +56,61 @@ bool debug_location = false;
void signal_handler(int) { debug_location = true; }
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Removes epsilons and determinizes in one step\n"
"\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n"
"See also: fstdeterminizelog, lattice-determinize\n";
const char *usage =
"Removes epsilons and determinizes in one step\n"
"\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n"
"See also: fstdeterminizelog, lattice-determinize\n";
float delta = kDelta;
int max_states = -1;
bool use_log = false;
ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta", &delta,
"Delta value used to determine equivalence of weights.");
po.Register(
"max-states", &max_states,
"Maximum number of states in determinized FST before it will abort.");
po.Read(argc, argv);
float delta = kDelta;
int max_states = -1;
bool use_log = false;
ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta",
&delta,
"Delta value used to determine equivalence of weights.");
po.Register("max-states",
&max_states,
"Maximum number of states in determinized FST before it "
"will abort.");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
// This enables us to get traceback info from determinization that is
// not seeming to terminate.
// This enables us to get traceback info from determinization that is
// not seeming to terminate.
#if !defined(_MSC_VER) && !defined(__APPLE__)
signal(SIGUSR1, signal_handler);
signal(SIGUSR1, signal_handler);
#endif
// Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
// Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway.
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway.
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -42,50 +42,51 @@
// though not stochastic because we gave it an absurdly large delta.
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Checks whether an FST is stochastic and exits with success if so.\n"
"Prints out maximum error (in log units).\n"
"\n"
"Usage: fstisstochastic [ in.fst ]\n";
const char *usage =
"Checks whether an FST is stochastic and exits with success if "
"so.\n"
"Prints out maximum error (in log units).\n"
"\n"
"Usage: fstisstochastic [ in.fst ]\n";
float delta = 0.01;
bool test_in_log = true;
float delta = 0.01;
bool test_in_log = true;
ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept.");
po.Register("test-in-log", &test_in_log,
"Test stochasticity in log semiring.");
po.Read(argc, argv);
ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept.");
po.Register(
"test-in-log", &test_in_log, "Test stochasticity in log semiring.");
po.Read(argc, argv);
if (po.NumArgs() > 1) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 1) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1);
std::string fst_in_filename = po.GetOptArg(1);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
bool ans;
StdArc::Weight min, max;
if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else
ans = IsStochasticFst(*fst, delta, &min, &max);
bool ans;
StdArc::Weight min, max;
if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else
ans = IsStochasticFst(*fst, delta, &min, &max);
std::cout << min.Value() << " " << max.Value() << '\n';
delete fst;
if (ans)
return 0; // success;
else
return 1;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::cout << min.Value() << " " << max.Value() << '\n';
delete fst;
if (ans)
return 0; // success;
else
return 1;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -33,42 +33,43 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n"
"\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n"
"\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
float delta = kDelta;
ParseOptions po(usage);
po.Register("delta", &delta,
"Delta likelihood used for quantization of weights");
po.Read(argc, argv);
float delta = kDelta;
ParseOptions po(usage);
po.Register("delta",
&delta,
"Delta likelihood used for quantization of weights");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2);
std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
MinimizeEncoded(fst, delta);
MinimizeEncoded(fst, delta);
WriteFstKaldi(*fst, fst_out_filename);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}

@ -37,97 +37,104 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
/*
fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST
is possible.
*/
const char *usage =
"Composition algorithm [between two FSTs of standard type, in "
"tropical\n"
"semiring] that is more efficient for certain cases-- in particular,\n"
"where one of the FSTs (the left one, if --match-side=left) has large\n"
"out-degree\n"
"\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n";
ParseOptions po(usage);
TableComposeOptions opts;
std::string match_side = "left";
std::string compose_filter = "sequence";
po.Register("connect", &opts.connect, "If true, trim FST before output.");
po.Register("match-side", &match_side,
"Side of composition to do table "
"match, one of: \"left\" or \"right\".");
po.Register("compose-filter", &compose_filter,
"Composition filter to use, "
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
/*
fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST
is possible.
*/
const char *usage =
"Composition algorithm [between two FSTs of standard type, in "
"tropical\n"
"semiring] that is more efficient for certain cases-- in "
"particular,\n"
"where one of the FSTs (the left one, if --match-side=left) has "
"large\n"
"out-degree\n"
"\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
"(fst2-rxfilename|fst2-rspecifier) "
"[(out-rxfilename|out-rspecifier)]\n";
ParseOptions po(usage);
TableComposeOptions opts;
std::string match_side = "left";
std::string compose_filter = "sequence";
po.Register(
"connect", &opts.connect, "If true, trim FST before output.");
po.Register("match-side",
&match_side,
"Side of composition to do table "
"match, one of: \"left\" or \"right\".");
po.Register(
"compose-filter",
&compose_filter,
"Composition filter to use, "
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -24,122 +24,130 @@
#include "util/parse-options.h"
int main(int argc, char *argv[]) {
using namespace kaldi; // NOLINT
try {
const char *usage =
"Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will contain\n"
"an embedded symbol table. This is compatible with the way a previous\n"
"version of arpa2fst worked.\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
std::string disambig_symbol;
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register("disambig-symbol", &disambig_symbol,
"Disambiguator. If provided (e. g. #0), used on input side of "
"backoff links, and <s> and </s> are replaced with epsilons");
po.Register("read-symbol-table", &read_syms_filename,
"Use existing symbol table");
po.Register("write-symbol-table", &write_syms_filename,
"Write generated symbol table to a file");
po.Register("keep-symbols", &keep_symbols,
"Store symbol table with FST. Symbols always saved to FST if "
"symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
using namespace kaldi; // NOLINT
try {
const char *usage =
"Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will "
"contain\n"
"an embedded symbol table. This is compatible with the way a "
"previous\n"
"version of arpa2fst worked.\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
std::string disambig_symbol;
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register(
"disambig-symbol",
&disambig_symbol,
"Disambiguator. If provided (e. g. #0), used on input side of "
"backoff links, and <s> and </s> are replaced with epsilons");
po.Register("read-symbol-table",
&read_syms_filename,
"Use existing symbol table");
po.Register("write-symbol-table",
&write_syms_filename,
"Write generated symbol table to a file");
po.Register(
"keep-symbols",
&keep_symbols,
"Store symbol table with FST. Symbols always saved to FST if "
"symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

Loading…
Cancel
Save