pull/1707/head
Hui Zhang 3 years ago
parent cad09b4910
commit c7b987c55d

@ -20,6 +20,7 @@ from diskcache import Cache
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import File from fastapi import File
from fastapi import UploadFile from fastapi import UploadFile
from logs import LOGGER
from milvus_helpers import MilvusHelper from milvus_helpers import MilvusHelper
from mysql_helpers import MySQLHelper from mysql_helpers import MySQLHelper
from operations.count import do_count from operations.count import do_count
@ -31,8 +32,6 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import FileResponse from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from logs import LOGGER from logs import LOGGER
from paddlespeech.cli import VectorExecutor from paddlespeech.cli import VectorExecutor
vector_executor = VectorExecutor() vector_executor = VectorExecutor()

@ -20,7 +20,6 @@ from config import MYSQL_HOST
from config import MYSQL_PORT from config import MYSQL_PORT
from config import MYSQL_PWD from config import MYSQL_PWD
from config import MYSQL_USER from config import MYSQL_USER
from logs import LOGGER from logs import LOGGER

@ -14,7 +14,6 @@
import sys import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from logs import LOGGER from logs import LOGGER

@ -14,7 +14,6 @@
import sys import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from logs import LOGGER from logs import LOGGER

@ -17,7 +17,6 @@ import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from diskcache import Cache from diskcache import Cache
from encode import get_audio_embedding from encode import get_audio_embedding
from logs import LOGGER from logs import LOGGER
@ -27,9 +26,8 @@ def get_audios(path):
""" """
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [ return [
item item for sublist in [[os.path.join(dir, file) for file in files]
for sublist in [[os.path.join(dir, file) for file in files] for dir, _, files in list(os.walk(path))]
for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats for item in sublist if os.path.splitext(item)[1] in supported_formats
] ]

@ -17,7 +17,6 @@ import numpy
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from config import TOP_K from config import TOP_K
from encode import get_audio_embedding from encode import get_audio_embedding
from logs import LOGGER from logs import LOGGER

@ -18,6 +18,7 @@ from config import UPLOAD_PATH
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import File from fastapi import File
from fastapi import UploadFile from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr from operations.count import do_count_vpr
from operations.count import do_get from operations.count import do_get
@ -30,8 +31,6 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import FileResponse from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,

@ -84,7 +84,7 @@ setuptools.setup(
install_requires=[ install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2', 'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos' 'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos'
], ],
extras_require={ extras_require={
'test': [ 'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',

@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import asyncio
import base64 import base64
import io import io
import json import json
import logging
import os import os
import random import random
import time import time
from typing import List from typing import List
import logging
import asyncio
import numpy as np import numpy as np
import requests import requests
@ -30,9 +30,9 @@ from ..executor import BaseExecutor
from ..util import cli_client_register from ..util import cli_client_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64 from paddlespeech.server.utils.util import wav2base64
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor'] __all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']
@ -234,7 +234,8 @@ class ASRClientExecutor(BaseExecutor):
@cli_client_register( @cli_client_register(
name='paddlespeech_client.asr_online', description='visit asr online service') name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASRClientExecutor(BaseExecutor): class ASRClientExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super(ASRClientExecutor, self).__init__() super(ASRClientExecutor, self).__init__()

@ -1,12 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng) # Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse import argparse
from flask import Flask, render_template from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network') parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id') parser.add_argument('--port', default=19999, type=int, help='port id')
@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__) app = Flask(__name__)
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True) app.run(host='0.0.0.0', port=args.port, debug=True)

@ -15,4 +15,4 @@
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。 在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png) ![图片](./paddle_web_demo.png)

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from paddle.io import Dataset from paddle.io import Dataset
from paddleaudio import load as load_audio from paddleaudio import load as load_audio

@ -8,4 +8,4 @@ Deepspeech2 Streaming Decoding under aishell dataset.
The following are for development and offline testing: The following are for development and offline testing:
* nnet * nnet
* feat * feat
* decoder * decoder

@ -9,4 +9,4 @@ feed nnet output logprob, and only test decoder
feed streaming audio features, decode in a streaming manner. feed streaming audio features, decode in a streaming manner.
* offline_wfst_decoder_main.cc * offline_wfst_decoder_main.cc
feed streaming audio features, decode using WFST in a streaming manner. feed streaming audio features, decode using WFST in a streaming manner.

@ -34,9 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_input_names, DEFINE_string(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", model_input_names,
"model input names"); "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names"); "model output names");
@ -57,7 +58,7 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file; std::string dict_file = FLAGS_dict_file;

@ -5,4 +5,3 @@ ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr
* linear_spectrogram_without_db_norm_main.cc * linear_spectrogram_without_db_norm_main.cc
compute linear spectrogram w/o db norm in a streaming manner. compute linear spectrogram w/o db norm in a streaming manner.

@ -30,8 +30,8 @@ using namespace simdjson;
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file ; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
padded_string json = padded_string::load(FLAGS_json_file); padded_string json = padded_string::load(FLAGS_json_file);
ondemand::parser parser; ondemand::parser parser;
@ -43,9 +43,11 @@ int main(int argc, char* argv[]) {
for (double x : mean_stat) { for (double x : mean_stat) {
mean_stat_vec.push_back(x); mean_stat_vec.push_back(x);
} }
// LOG(INFO) << mean_stat; this line will cause simdjson::simdjson_error("Objects and arrays can only be iterated when they are first encountered") // LOG(INFO) << mean_stat; this line will cause
// simdjson::simdjson_error("Objects and arrays can only be iterated when
// they are first encountered")
ondemand::array var_stat = val["var_stat"]; ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec; std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) { for (double x : var_stat) {
var_stat_vec.push_back(x); var_stat_vec.push_back(x);
@ -53,7 +55,7 @@ int main(int argc, char* argv[]) {
kaldi::int32 frame_num = uint64_t(val["frame_num"]); kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num; LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size(); size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1); kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) { for (size_t idx = 0; idx < mean_size; ++idx) {

@ -1,3 +1,3 @@
# Deepspeech2 Streaming NNet Test # Deepspeech2 Streaming NNet Test
Used for the ds2 streaming nnet inference test. Used for the ds2 streaming nnet inference test.

@ -14,8 +14,6 @@
// deepspeech2 online model info // deepspeech2 online model info
#include "base/flags.h"
#include "base/log.h"
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <functional> #include <functional>
@ -23,6 +21,8 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
using std::cout; using std::cout;
@ -40,7 +40,7 @@ void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) { void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = FLAGS_chunk_size; // chunk_size in frame int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = FLAGS_feat_dim; // feat dim int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl; cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl; cout << "feat dim: " << col_size << endl;
@ -197,7 +197,7 @@ void model_forward_test() {
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
model_forward_test(); model_forward_test();
return 0; return 0;
} }

@ -1,3 +1 @@
# NGram Train # NGram Train

@ -92,8 +92,7 @@ void CTCBeamSearch::AdvanceDecode(
while (1) { while (1) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob; vector<BaseFloat> frame_prob;
bool flag = bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break; if (flag == false) break;
likelihood.push_back(frame_prob); likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);

@ -46,10 +46,10 @@ class LinearSpectrogram : public FrontendInterface {
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); } virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { virtual void Reset() {
base_extractor_->Reset(); base_extractor_->Reset();
reminded_wav_.Resize(0); reminded_wav_.Resize(0);
} }
private: private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves, bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,

@ -49,19 +49,19 @@ bool Decodable::IsLastFrame(int32 frame) {
int32 Decodable::NumIndices() const { return 0; } int32 Decodable::NumIndices() const { return 0; }
// the ilabel(TokenId) of wfst(TLG) inserts <eps>(id = 0) in front of Nnet prob id. // the ilabel(TokenId) of wfst(TLG) inserts <eps>(id = 0) in front of Nnet prob
int32 Decodable::TokenId2NnetId(int32 token_id) { // id.
return token_id - 1; int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
// the nnet output is prob rather than log prob // the nnet output is prob rather than log prob
// the index - 1, because the ilabel // the index - 1, because the ilabel
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) + return acoustic_scale_ *
std::numeric_limits<float>::min()); std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
} }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {

@ -45,7 +45,8 @@ struct ModelOptions {
thread_num(2), thread_num(2),
use_gpu(false), use_gpu(false),
input_names( input_names(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
"box"),
output_names( output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/" "save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"

@ -37,8 +37,7 @@ std::string ReadFile2String(const std::string& path) {
if (!input_file.is_open()) { if (!input_file.is_open()) {
std::cerr << "please input a valid file" << std::endl; std::cerr << "please input a valid file" << std::endl;
} }
return std::string((std::istreambuf_iterator<char>(input_file)), return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
} }
} }

@ -20,5 +20,4 @@ bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data); std::vector<std::string>* data);
std::string ReadFile2String(const std::string& path); std::string ReadFile2String(const std::string& path);
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// fstbin/fstaddselfloops.cc // fstbin/fstaddselfloops.cc
// Copyright 2009-2011 Microsoft Corporation // Copyright 2009-2011 Microsoft Corporation
@ -35,66 +49,68 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation " "Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n" "symbols through it\n"
"They are added on each final state and each state with non-epsilon " "They are added on each final state and each state with "
"output symbols\n" "non-epsilon "
"on at least one arc out of the state. Useful in conjunction with " "output symbols\n"
"predeterminize\n" "on at least one arc out of the state. Useful in conjunction with "
"\n" "predeterminize\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst " "\n"
"[out.fst] ]\n" "Usage: fstaddselfloops in-disambig-list out-disambig-list "
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n" "[in.fst "
"in.list and out.list are lists of integers, one per line, of the\n" "[out.fst] ]\n"
"same length.\n"; "E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
ParseOptions po(usage); "same length.\n";
po.Read(argc, argv);
ParseOptions po(usage);
if (po.NumArgs() < 2 || po.NumArgs() > 4) { po.Read(argc, argv);
po.PrintUsage();
exit(1); if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR << "fstaddselfloops: mismatch in size of disambiguation "
"symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0; return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// fstbin/fstdeterminizestar.cc // fstbin/fstdeterminizestar.cc
// Copyright 2009-2011 Microsoft Corporation // Copyright 2009-2011 Microsoft Corporation
@ -56,59 +70,61 @@ bool debug_location = false;
void signal_handler(int) { debug_location = true; } void signal_handler(int) { debug_location = true; }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Removes epsilons and determinizes in one step\n" "Removes epsilons and determinizes in one step\n"
"\n" "\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n" "Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n" "\n"
"See also: fstdeterminizelog, lattice-determinize\n"; "See also: fstdeterminizelog, lattice-determinize\n";
float delta = kDelta; float delta = kDelta;
int max_states = -1; int max_states = -1;
bool use_log = false; bool use_log = false;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring."); po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta", &delta, po.Register("delta",
"Delta value used to determine equivalence of weights."); &delta,
po.Register( "Delta value used to determine equivalence of weights.");
"max-states", &max_states, po.Register("max-states",
"Maximum number of states in determinized FST before it will abort."); &max_states,
po.Read(argc, argv); "Maximum number of states in determinized FST before it "
"will abort.");
po.Read(argc, argv);
if (po.NumArgs() > 2) { if (po.NumArgs() > 2) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2); std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
// This enables us to get traceback info from determinization that is // This enables us to get traceback info from determinization that is
// not seeming to terminate. // not seeming to terminate.
#if !defined(_MSC_VER) && !defined(__APPLE__) #if !defined(_MSC_VER) && !defined(__APPLE__)
signal(SIGUSR1, signal_handler); signal(SIGUSR1, signal_handler);
#endif #endif
// Normal case: just files. // Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str); VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed. ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) { if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states); DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else { } else {
VectorFst<StdArc> det_fst; VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states); DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes *fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway. // out of scope anyway.
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// fstbin/fstisstochastic.cc // fstbin/fstisstochastic.cc
// Copyright 2009-2011 Microsoft Corporation // Copyright 2009-2011 Microsoft Corporation
@ -42,50 +56,51 @@
// though not stochastic because we gave it an absurdly large delta. // though not stochastic because we gave it an absurdly large delta.
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Checks whether an FST is stochastic and exits with success if so.\n" "Checks whether an FST is stochastic and exits with success if "
"Prints out maximum error (in log units).\n" "so.\n"
"\n" "Prints out maximum error (in log units).\n"
"Usage: fstisstochastic [ in.fst ]\n"; "\n"
"Usage: fstisstochastic [ in.fst ]\n";
float delta = 0.01; float delta = 0.01;
bool test_in_log = true; bool test_in_log = true;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept."); po.Register("delta", &delta, "Maximum error to accept.");
po.Register("test-in-log", &test_in_log, po.Register(
"Test stochasticity in log semiring."); "test-in-log", &test_in_log, "Test stochasticity in log semiring.");
po.Read(argc, argv); po.Read(argc, argv);
if (po.NumArgs() > 1) { if (po.NumArgs() > 1) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_filename = po.GetOptArg(1); std::string fst_in_filename = po.GetOptArg(1);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename); Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
bool ans; bool ans;
StdArc::Weight min, max; StdArc::Weight min, max;
if (test_in_log) if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max); ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else else
ans = IsStochasticFst(*fst, delta, &min, &max); ans = IsStochasticFst(*fst, delta, &min, &max);
std::cout << min.Value() << " " << max.Value() << '\n'; std::cout << min.Value() << " " << max.Value() << '\n';
delete fst; delete fst;
if (ans) if (ans)
return 0; // success; return 0; // success;
else else
return 1; return 1;
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << e.what(); std::cerr << e.what();
return -1; return -1;
} }
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// fstbin/fstminimizeencoded.cc // fstbin/fstminimizeencoded.cc
// Copyright 2009-2011 Microsoft Corporation // Copyright 2009-2011 Microsoft Corporation
@ -33,42 +47,43 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
const char *usage = const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no " "Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n" "weight-pushing]\n"
"\n" "\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n"; "Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
float delta = kDelta; float delta = kDelta;
ParseOptions po(usage); ParseOptions po(usage);
po.Register("delta", &delta, po.Register("delta",
"Delta likelihood used for quantization of weights"); &delta,
po.Read(argc, argv); "Delta likelihood used for quantization of weights");
po.Read(argc, argv);
if (po.NumArgs() > 2) { if (po.NumArgs() > 2) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
std::string fst_in_filename = po.GetOptArg(1), std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2); fst_out_filename = po.GetOptArg(2);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename); VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
MinimizeEncoded(fst, delta); MinimizeEncoded(fst, delta);
WriteFstKaldi(*fst, fst_out_filename); WriteFstKaldi(*fst, fst_out_filename);
delete fst; delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0; return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// fstbin/fsttablecompose.cc // fstbin/fsttablecompose.cc
// Copyright 2009-2011 Microsoft Corporation // Copyright 2009-2011 Microsoft Corporation
@ -37,97 +51,104 @@
*/ */
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
using namespace fst; // NOLINT using namespace fst; // NOLINT
using kaldi::int32; using kaldi::int32;
/* /*
fsttablecompose should always give equivalent results to compose, fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs. but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST olabel). The same with the input symbols of the right-hand FST
is possible. is possible.
*/ */
const char *usage = const char *usage =
"Composition algorithm [between two FSTs of standard type, in " "Composition algorithm [between two FSTs of standard type, in "
"tropical\n" "tropical\n"
"semiring] that is more efficient for certain cases-- in particular,\n" "semiring] that is more efficient for certain cases-- in "
"where one of the FSTs (the left one, if --match-side=left) has large\n" "particular,\n"
"out-degree\n" "where one of the FSTs (the left one, if --match-side=left) has "
"\n" "large\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) " "out-degree\n"
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n"; "\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
ParseOptions po(usage); "(fst2-rxfilename|fst2-rspecifier) "
"[(out-rxfilename|out-rspecifier)]\n";
TableComposeOptions opts;
std::string match_side = "left"; ParseOptions po(usage);
std::string compose_filter = "sequence";
TableComposeOptions opts;
po.Register("connect", &opts.connect, "If true, trim FST before output."); std::string match_side = "left";
po.Register("match-side", &match_side, std::string compose_filter = "sequence";
"Side of composition to do table "
"match, one of: \"left\" or \"right\"."); po.Register(
po.Register("compose-filter", &compose_filter, "connect", &opts.connect, "If true, trim FST before output.");
"Composition filter to use, " po.Register("match-side",
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\""); &match_side,
"Side of composition to do table "
po.Read(argc, argv); "match, one of: \"left\" or \"right\".");
po.Register(
if (match_side == "left") { "compose-filter",
opts.table_match_type = MATCH_OUTPUT; &compose_filter,
} else if (match_side == "right") { "Composition filter to use, "
opts.table_match_type = MATCH_INPUT; "one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
} else {
KALDI_ERR << "Invalid match-side option: " << match_side; po.Read(argc, argv);
}
if (match_side == "left") {
if (compose_filter == "alt_sequence") { opts.table_match_type = MATCH_OUTPUT;
opts.filter_type = ALT_SEQUENCE_FILTER; } else if (match_side == "right") {
} else if (compose_filter == "auto") { opts.table_match_type = MATCH_INPUT;
opts.filter_type = AUTO_FILTER; } else {
} else if (compose_filter == "match") { KALDI_ERR << "Invalid match-side option: " << match_side;
opts.filter_type = MATCH_FILTER; }
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER; if (compose_filter == "alt_sequence") {
} else { opts.filter_type = ALT_SEQUENCE_FILTER;
KALDI_ERR << "Invalid compose-filter option: " << compose_filter; } else if (compose_filter == "auto") {
} opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
if (po.NumArgs() < 2 || po.NumArgs() > 3) { opts.filter_type = MATCH_FILTER;
po.PrintUsage(); } else if (compose_filter == "sequence") {
exit(1); opts.filter_type = SEQUENCE_FILTER;
} } else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2), }
fst_out_str = po.GetOptArg(3);
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str); po.PrintUsage();
exit(1);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str); }
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted. std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
if (fst1->Properties(fst::kOLabelSorted, true) == 0) { fst_out_str = po.GetOptArg(3);
KALDI_WARN << "The first FST is not olabel sorted.";
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// bin/arpa2fst.cc // bin/arpa2fst.cc
// //
// Copyright 2009-2011 Gilles Boulianne. // Copyright 2009-2011 Gilles Boulianne.
@ -24,122 +38,130 @@
#include "util/parse-options.h" #include "util/parse-options.h"
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
using namespace kaldi; // NOLINT using namespace kaldi; // NOLINT
try { try {
const char *usage = const char *usage =
"Convert an ARPA format language model into an FST\n" "Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n" "Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table=" " e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n" "data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will contain\n" "Note: When called without switches, the output G.fst will "
"an embedded symbol table. This is compatible with the way a previous\n" "contain\n"
"version of arpa2fst worked.\n"; "an embedded symbol table. This is compatible with the way a "
"previous\n"
ParseOptions po(usage); "version of arpa2fst worked.\n";
ArpaParseOptions options; ParseOptions po(usage);
options.Register(&po);
ArpaParseOptions options;
// Option flags. options.Register(&po);
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>"; // Option flags.
std::string disambig_symbol; std::string bos_symbol = "<s>";
std::string read_syms_filename; std::string eos_symbol = "</s>";
std::string write_syms_filename; std::string disambig_symbol;
bool keep_symbols = false; std::string read_syms_filename;
bool ilabel_sort = true; std::string write_syms_filename;
bool keep_symbols = false;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol"); bool ilabel_sort = true;
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register("disambig-symbol", &disambig_symbol, po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
"Disambiguator. If provided (e. g. #0), used on input side of " po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
"backoff links, and <s> and </s> are replaced with epsilons"); po.Register(
po.Register("read-symbol-table", &read_syms_filename, "disambig-symbol",
"Use existing symbol table"); &disambig_symbol,
po.Register("write-symbol-table", &write_syms_filename, "Disambiguator. If provided (e. g. #0), used on input side of "
"Write generated symbol table to a file"); "backoff links, and <s> and </s> are replaced with epsilons");
po.Register("keep-symbols", &keep_symbols, po.Register("read-symbol-table",
"Store symbol table with FST. Symbols always saved to FST if " &read_syms_filename,
"symbol tables are neither read or written (otherwise symbols " "Use existing symbol table");
"would be lost entirely)"); po.Register("write-symbol-table",
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST"); &write_syms_filename,
"Write generated symbol table to a file");
po.Read(argc, argv); po.Register(
"keep-symbols",
if (po.NumArgs() != 1 && po.NumArgs() != 2) { &keep_symbols,
po.PrintUsage(); "Store symbol table with FST. Symbols always saved to FST if "
exit(1); "symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
} }
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
} }

Loading…
Cancel
Save