diff --git a/.gitignore b/.gitignore
index e25ec327b..639472001 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
 .DS_Store
 *.pyc
 .vscode
-*log
+*.log
 *.wav
 *.pdmodel
 *.pdiparams*
diff --git a/.mergify.yml b/.mergify.yml
index 6dae66d04..68b248101 100644
--- a/.mergify.yml
+++ b/.mergify.yml
@@ -52,7 +52,7 @@ pull_request_rules:
         add: ["T2S"]
   - name: "auto add label=Audio"
     conditions:
-      - files~=^audio/
+      - files~=^paddleaudio/
     actions:
       label:
         add: ["Audio"]
diff --git a/paddlespeech/s2t/frontend/speech.py b/paddlespeech/s2t/frontend/speech.py
index 0340831a6..969971047 100644
--- a/paddlespeech/s2t/frontend/speech.py
+++ b/paddlespeech/s2t/frontend/speech.py
@@ -108,7 +108,12 @@ class SpeechSegment(AudioSegment):
                    token_ids)
 
     @classmethod
-    def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
+    def from_pcm(cls,
+                 samples,
+                 sample_rate,
+                 transcript,
+                 tokens=None,
+                 token_ids=None):
         """Create speech segment from pcm on online mode
         Args:
             samples (numpy.ndarray): Audio samples [num_samples x num_channels].
diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py
index 45ded33d8..81824c85c 100644
--- a/paddlespeech/server/bin/main.py
+++ b/paddlespeech/server/bin/main.py
@@ -18,8 +18,8 @@ from fastapi import FastAPI
 
 from paddlespeech.server.engine.engine_pool import init_engine_pool
 from paddlespeech.server.restful.api import setup_router as setup_http_router
-from paddlespeech.server.ws.api import setup_router as setup_ws_router
 from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
 
 app = FastAPI(
    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
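The `from_pcm` classmethod reformatted above is how raw PCM buffers become `SpeechSegment`s in online mode, and the online ASR engine below re-sorts its import of this class. A minimal usage sketch (the 16 kHz rate, silent samples, and empty transcript are illustrative assumptions, not values from this patch):

```python
import numpy as np

from paddlespeech.s2t.frontend.speech import SpeechSegment

# Illustrative input: 1 s of silent 16 kHz mono PCM with no transcript yet.
samples = np.zeros(16000, dtype=np.float32)
seg = SpeechSegment.from_pcm(
    samples, sample_rate=16000, transcript="", tokens=None, token_ids=None)
print(seg.sample_rate, seg.duration)
```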
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index d5c1aa7bd..389175a0a 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -11,29 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 import os
-import time
 from typing import Optional
 
-import pickle
-import numpy as np
-from numpy import float32
-import soundfile
+import numpy as np
 import paddle
+from numpy import float32
 from yacs.config import CfgNode
 
-from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.utils import MODEL_HOME
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.modules.ctc import CTCDecoder
 from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
-from paddlespeech.server.utils.config import get_config
 from paddlespeech.server.utils.paddle_predictor import init_predictor
-from paddlespeech.server.utils.paddle_predictor import run_model
 
 __all__ = ['ASREngine']
@@ -141,10 +135,10 @@ class ASRServerExecutor(ASRExecutor):
             reduction=True,  # sum
             batch_average=True,  # sum / batch_size
             grad_norm_type=self.config.get('ctc_grad_norm_type', None))
-
+
         # init decoder
         cfg = self.config.decode
-        decode_batch_size = 1 # for online
+        decode_batch_size = 1  # for online
         self.decoder.init_decoder(
             decode_batch_size, self.text_feature.vocab_list,
             cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
@@ -182,10 +176,11 @@ class ASRServerExecutor(ASRExecutor):
         Returns:
             [type]: [description]
         """
-        if "deepspeech2online" in model_type :
+        if "deepspeech2online" in model_type:
             input_names = self.am_predictor.get_input_names()
             audio_handle = self.am_predictor.get_input_handle(input_names[0])
-            audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
+            audio_len_handle = self.am_predictor.get_input_handle(
+                input_names[1])
 
             h_box_handle = self.am_predictor.get_input_handle(input_names[2])
             c_box_handle = self.am_predictor.get_input_handle(input_names[3])
@@ -203,7 +198,8 @@ class ASRServerExecutor(ASRExecutor):
             output_names = self.am_predictor.get_output_names()
             output_handle = self.am_predictor.get_output_handle(output_names[0])
-            output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
+            output_lens_handle = self.am_predictor.get_output_handle(
+                output_names[1])
             output_state_h_handle = self.am_predictor.get_output_handle(
                 output_names[2])
             output_state_c_handle = self.am_predictor.get_output_handle(
                 output_names[3])
@@ -341,7 +337,8 @@ class ASREngine(BaseEngine):
             x_chunk_lens (numpy.array): shape[B]
             decoder_chunk_size(int)
         """
-        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
+        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
+                                                     self.config.model_type)
 
     def postprocess(self):
         """postprocess
diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py
index 4c1a3958a..682357b34 100644
--- a/paddlespeech/server/utils/buffer.py
+++ b/paddlespeech/server/utils/buffer.py
@@ -43,10 +43,10 @@ class ChunkBuffer(object):
         audio = self.remained_audio + audio
         self.remained_audio = b''
 
-        n = int(self.sample_rate *
-                (self.frame_duration_ms / 1000.0) * self.sample_width)
-        shift_n = int(self.sample_rate *
-                      (self.shift_ms / 1000.0) * self.sample_width)
+        n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
+                self.sample_width)
+        shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
+                      self.sample_width)
         offset = 0
         timestamp = 0.0
         duration = (float(n) / self.sample_rate) / self.sample_width
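The reflowed expressions in `ChunkBuffer` above turn frame and shift durations into byte counts over the PCM stream. A standalone sketch of that arithmetic (the 16 kHz rate, 16-bit width, and 25 ms / 10 ms windows are assumed example values, not taken from this patch):

```python
# Assumed example parameters (not from the patch).
sample_rate = 16000       # samples per second
sample_width = 2          # bytes per sample (16-bit PCM)
frame_duration_ms = 25.0  # analysis window
shift_ms = 10.0           # hop between windows

# Same formulas as ChunkBuffer: sizes in bytes, not samples.
n = int(sample_rate * (frame_duration_ms / 1000.0) * sample_width)
shift_n = int(sample_rate * (shift_ms / 1000.0) * sample_width)
duration = (float(n) / sample_rate) / sample_width

print(n, shift_n, duration)  # 800 bytes, 320 bytes, 0.025 s per frame
```

So each yielded frame covers 800 bytes and advances by 320, i.e. consecutive frames overlap by 15 ms.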
diff --git a/speechx/examples/decoder/decoder_test_main.cc b/speechx/examples/decoder/decoder_test_main.cc
index 79fe63fcd..0e249cc6b 100644
--- a/speechx/examples/decoder/decoder_test_main.cc
+++ b/speechx/examples/decoder/decoder_test_main.cc
@@ -24,11 +24,11 @@
 DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
 DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
 DEFINE_string(lm_path, "lm.klm", "language model");
-
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+// test decoder by feeding nnet posterior probability
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -37,6 +37,8 @@ int main(int argc, char* argv[]) {
         FLAGS_nnet_prob_respecifier);
     std::string dict_file = FLAGS_dict_file;
     std::string lm_path = FLAGS_lm_path;
+    LOG(INFO) << "dict path: " << dict_file;
+    LOG(INFO) << "lm path: " << lm_path;
 
     int32 num_done = 0, num_err = 0;
@@ -53,6 +55,9 @@ int main(int argc, char* argv[]) {
     for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
         string utt = likelihood_reader.Key();
         const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
+        LOG(INFO) << "process utt: " << utt;
+        LOG(INFO) << "rows: " << likelihood.NumRows();
+        LOG(INFO) << "cols: " << likelihood.NumCols();
         decodable->Acceptlikelihood(likelihood);
         decoder.AdvanceDecode(decodable);
         std::string result;
diff --git a/speechx/examples/decoder/local/model.sh b/speechx/examples/decoder/local/model.sh
new file mode 100644
index 000000000..5c609a6cf
--- /dev/null
+++ b/speechx/examples/decoder/local/model.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+
diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc
index 3a94cc947..6bd83b9b1 100644
--- a/speechx/examples/decoder/offline_decoder_main.cc
+++ b/speechx/examples/decoder/offline_decoder_main.cc
@@ -17,7 +17,7 @@
 #include "base/flags.h"
 #include "base/log.h"
 #include "decoder/ctc_beam_search_decoder.h"
-#include "frontend/raw_audio.h"
+#include "frontend/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/paddle_nnet.h"
@@ -34,6 +34,7 @@ using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+// test decoder by feeding speech feature, deprecated.
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -59,8 +60,7 @@ int main(int argc, char* argv[]) {
     model_opts.params_path = model_params;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
-    std::shared_ptr<ppspeech::RawDataCache> raw_data(
-        new ppspeech::RawDataCache());
+    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
     std::shared_ptr<ppspeech::Decodable> decodable(
         new ppspeech::Decodable(nnet, raw_data));
     LOG(INFO) << "Init decodeable.";
diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
index ad72b7723..4d5ffe145 100644
--- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
+++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
@@ -17,7 +17,7 @@
 #include "base/flags.h"
 #include "base/log.h"
 #include "decoder/ctc_beam_search_decoder.h"
-#include "frontend/raw_audio.h"
+#include "frontend/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/paddle_nnet.h"
@@ -27,12 +27,19 @@
 DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
 DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
 DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
 DEFINE_string(lm_path, "lm.klm", "language model");
-
+DEFINE_int32(receptive_field_length,
+             7,
+             "receptive field of two CNN(kernel=5) downsampling module.");
+DEFINE_int32(downsampling_rate,
+             4,
+             "two CNN(kernel=5) module downsampling rate.");
 
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+
+// test ds2 online decoder by feeding speech feature
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -43,6 +50,11 @@ int main(int argc, char* argv[]) {
     std::string model_params = FLAGS_param_path;
     std::string dict_file = FLAGS_dict_file;
     std::string lm_path = FLAGS_lm_path;
+    LOG(INFO) << "model path: " << model_graph;
+    LOG(INFO) << "model param: " << model_params;
+    LOG(INFO) << "dict path: " << dict_file;
+    LOG(INFO) << "lm path: " << lm_path;
+
 
     int32 num_done = 0, num_err = 0;
@@ -57,34 +69,44 @@ int main(int argc, char* argv[]) {
     model_opts.cache_shape = "5-1-1024,5-1-1024";
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
-    std::shared_ptr<ppspeech::RawDataCache> raw_data(
-        new ppspeech::RawDataCache());
+    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
     std::shared_ptr<ppspeech::Decodable> decodable(
         new ppspeech::Decodable(nnet, raw_data));
 
-    int32 chunk_size = 7;
-    int32 chunk_stride = 4;
-    int32 receptive_field_length = 7;
+    int32 chunk_size = FLAGS_receptive_field_length;
+    int32 chunk_stride = FLAGS_downsampling_rate;
+    int32 receptive_field_length = FLAGS_receptive_field_length;
+    LOG(INFO) << "chunk size (frame): " << chunk_size;
+    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
+    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
     decoder.InitDecoder();
 
     for (; !feature_reader.Done(); feature_reader.Next()) {
         string utt = feature_reader.Key();
         kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
         raw_data->SetDim(feature.NumCols());
+        LOG(INFO) << "process utt: " << utt;
+        LOG(INFO) << "rows: " << feature.NumRows();
+        LOG(INFO) << "cols: " << feature.NumCols();
+
         int32 row_idx = 0;
         int32 padding_len = 0;
-        int32 ori_feature_len = feature.NumRows();
-        if ( (feature.NumRows() - chunk_size) % chunk_stride != 0) {
-            padding_len = chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
-            feature.Resize(feature.NumRows() + padding_len, feature.NumCols(), kaldi::kCopyData);
+        int32 ori_feature_len = feature.NumRows();
+        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
+            padding_len =
+                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
+            feature.Resize(feature.NumRows() + padding_len,
+                           feature.NumCols(),
+                           kaldi::kCopyData);
         }
         int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
         for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
             kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size * feature.NumCols());
-            int32 feature_chunk_size = 0;
-            if ( ori_feature_len > chunk_idx * chunk_stride) {
-                feature_chunk_size = std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size);
+            int32 feature_chunk_size = 0;
+            if (ori_feature_len > chunk_idx * chunk_stride) {
+                feature_chunk_size = std::min(
+                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
             }
             if (feature_chunk_size < receptive_field_length) break;
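To make the padding and chunk-count logic in `offline_decoder_sliding_chunk_main.cc` concrete, here is the same arithmetic as a short sketch (`chunk_size = 7` and `chunk_stride = 4` mirror the new flag defaults above; the 100-frame utterance length is an assumed example):

```python
chunk_size = 7    # FLAGS_receptive_field_length default
chunk_stride = 4  # FLAGS_downsampling_rate default
num_rows = 100    # assumed utterance length in frames

# Pad so that (rows - chunk_size) divides evenly by the stride.
padding_len = 0
if (num_rows - chunk_size) % chunk_stride != 0:
    padding_len = chunk_stride - (num_rows - chunk_size) % chunk_stride
num_rows += padding_len

num_chunks = (num_rows - chunk_size) // chunk_stride + 1
print(padding_len, num_chunks)  # 3 padding frames -> 25 chunks
```

With 100 frames, 3 frames of padding make (103 - 7) a multiple of the stride, giving 25 chunks; any trailing chunk shorter than the receptive field is dropped by the `break` above.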
diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc
index cde78c4d3..e1f0a8954 100644
--- a/speechx/examples/feat/linear_spectrogram_main.cc
+++ b/speechx/examples/feat/linear_spectrogram_main.cc
@@ -17,10 +17,11 @@
 #include "frontend/linear_spectrogram.h"
 #include "base/flags.h"
 #include "base/log.h"
+#include "frontend/audio_cache.h"
+#include "frontend/data_cache.h"
 #include "frontend/feature_cache.h"
 #include "frontend/feature_extractor_interface.h"
 #include "frontend/normalizer.h"
-#include "frontend/raw_audio.h"
 #include "kaldi/feat/wave-reader.h"
 #include "kaldi/util/kaldi-io.h"
 #include "kaldi/util/table-types.h"
@@ -170,9 +171,9 @@ int main(int argc, char* argv[]) {
     // window -->linear_spectrogram --> global cmvn -> feat cache
 
     // std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
-    // ppspeech::RawDataCache());
+    // ppspeech::DataCache());
     std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
-        new ppspeech::RawAudioCache());
+        new ppspeech::AudioCache());
 
     ppspeech::DecibelNormalizerOptions db_norm_opt;
     std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
diff --git a/speechx/examples/glog/CMakeLists.txt b/speechx/examples/glog/CMakeLists.txt
new file mode 100644
index 000000000..b4b0e6358
--- /dev/null
+++ b/speechx/examples/glog/CMakeLists.txt
@@ -0,0 +1,8 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(glog_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_test.cc)
+target_link_libraries(glog_test glog)
+
+
+add_executable(glog_logtostderr_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_test.cc)
+target_link_libraries(glog_logtostderr_test glog)
\ No newline at end of file
diff --git a/speechx/examples/glog/README.md b/speechx/examples/glog/README.md
new file mode 100644
index 000000000..996e192e9
--- /dev/null
+++ b/speechx/examples/glog/README.md
@@ -0,0 +1,25 @@
+# [GLOG](https://rpg.ifi.uzh.ch/docs/glog.html)
+
+Unless otherwise specified, glog writes to the filename `/tmp/...log...