From 760e5d4418702f88f9dbe3dfc6bedc22e99b7a03 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 31 Mar 2022 12:24:15 +0000 Subject: [PATCH 1/5] refactor cache --- .../examples/decoder/offline_decoder_main.cc | 5 +- .../examples/feat/linear_spectrogram_main.cc | 7 ++- speechx/speechx/frontend/CMakeLists.txt | 4 +- .../frontend/{raw_audio.cc => audio_cache.cc} | 43 +++++++------ speechx/speechx/frontend/audio_cache.h | 61 +++++++++++++++++++ .../frontend/{raw_audio.h => data_cache.h} | 46 +++----------- 6 files changed, 101 insertions(+), 65 deletions(-) rename speechx/speechx/frontend/{raw_audio.cc => audio_cache.cc} (64%) create mode 100644 speechx/speechx/frontend/audio_cache.h rename speechx/speechx/frontend/{raw_audio.h => data_cache.h} (54%) diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc index 3a858ad16..c73d59682 100644 --- a/speechx/examples/decoder/offline_decoder_main.cc +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -17,7 +17,7 @@ #include "base/flags.h" #include "base/log.h" #include "decoder/ctc_beam_search_decoder.h" -#include "frontend/raw_audio.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" @@ -60,8 +60,7 @@ int main(int argc, char* argv[]) { model_opts.params_path = model_params; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::RawDataCache()); + std::shared_ptr raw_data(new ppspeech::DataCache()); std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); LOG(INFO) << "Init decodeable."; diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc index cde78c4d3..e1f0a8954 100644 --- a/speechx/examples/feat/linear_spectrogram_main.cc +++ b/speechx/examples/feat/linear_spectrogram_main.cc @@ -17,10 +17,11 @@ #include "frontend/linear_spectrogram.h" #include "base/flags.h" #include "base/log.h" +#include "frontend/audio_cache.h" +#include "frontend/data_cache.h" #include "frontend/feature_cache.h" #include "frontend/feature_extractor_interface.h" #include "frontend/normalizer.h" -#include "frontend/raw_audio.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" @@ -170,9 +171,9 @@ int main(int argc, char* argv[]) { // window -->linear_spectrogram --> global cmvn -> feat cache // std::unique_ptr data_source(new - // ppspeech::RawDataCache()); + // ppspeech::DataCache()); std::unique_ptr data_source( - new ppspeech::RawAudioCache()); + new ppspeech::AudioCache()); ppspeech::DecibelNormalizerOptions db_norm_opt; std::unique_ptr db_norm( diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/frontend/CMakeLists.txt index 44ca52cdc..d0ec008ee 100644 --- a/speechx/speechx/frontend/CMakeLists.txt +++ b/speechx/speechx/frontend/CMakeLists.txt @@ -3,8 +3,8 @@ project(frontend) add_library(frontend STATIC normalizer.cc linear_spectrogram.cc - raw_audio.cc + audio_cache.cc feature_cache.cc ) -target_link_libraries(frontend PUBLIC kaldi-matrix) +target_link_libraries(frontend PUBLIC kaldi-matrix) \ No newline at end of file diff --git a/speechx/speechx/frontend/raw_audio.cc b/speechx/speechx/frontend/audio_cache.cc similarity index 64% rename from speechx/speechx/frontend/raw_audio.cc rename to speechx/speechx/frontend/audio_cache.cc index 21f643628..d44ed592c 100644 --- a/speechx/speechx/frontend/raw_audio.cc +++ b/speechx/speechx/frontend/audio_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/raw_audio.h" +#include "frontend/audio_cache.h" #include "kaldi/base/timer.h" namespace ppspeech { @@ -21,38 +21,43 @@ using kaldi::BaseFloat; using kaldi::VectorBase; using kaldi::Vector; -RawAudioCache::RawAudioCache(int buffer_size) - : finished_(false), data_length_(0), start_(0), timeout_(1) { - ring_buffer_.resize(buffer_size); +AudioCache::AudioCache(int buffer_size) + : finished_(false), + capacity_(buffer_size), + size_(0), + offset_(0), + timeout_(1) { + ring_buffer_.resize(capacity_); } -void RawAudioCache::Accept(const VectorBase& waves) { +void AudioCache::Accept(const VectorBase& waves) { std::unique_lock lock(mutex_); - while (data_length_ + waves.Dim() > ring_buffer_.size()) { + while (size_ + waves.Dim() > ring_buffer_.size()) { ready_feed_condition_.wait(lock); } for (size_t idx = 0; idx < waves.Dim(); ++idx) { - int32 buffer_idx = (idx + start_) % ring_buffer_.size(); + int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); ring_buffer_[buffer_idx] = waves(idx); } - data_length_ += waves.Dim(); + size_ += waves.Dim(); } -bool RawAudioCache::Read(Vector* waves) { +bool AudioCache::Read(Vector* waves) { size_t chunk_size = waves->Dim(); kaldi::Timer timer; std::unique_lock lock(mutex_); - while (chunk_size > data_length_) { + while (chunk_size > size_) { // when audio is empty and no more data feed - // ready_read_condition will block in dead lock. so replace with - // timeout_ + // ready_read_condition will block in dead lock, + // so replace with timeout_ // ready_read_condition_.wait(lock); int32 elapsed = static_cast(timer.Elapsed() * 1000); if (elapsed > timeout_) { - if (finished_ == true) { // read last chunk data + if (finished_ == true) { + // read last chunk data break; } - if (chunk_size > data_length_) { + if (chunk_size > size_) { return false; } } @@ -60,17 +65,17 @@ bool RawAudioCache::Read(Vector* waves) { } // read last chunk data - if (chunk_size > data_length_) { - chunk_size = data_length_; + if (chunk_size > size_) { + chunk_size = size_; waves->Resize(chunk_size); } for (size_t idx = 0; idx < chunk_size; ++idx) { - int buff_idx = (start_ + idx) % ring_buffer_.size(); + int buff_idx = (offset_ + idx) % ring_buffer_.size(); waves->Data()[idx] = ring_buffer_[buff_idx]; } - data_length_ -= chunk_size; - start_ = (start_ + chunk_size) % ring_buffer_.size(); + size_ -= chunk_size; + offset_ = (offset_ + chunk_size) % ring_buffer_.size(); ready_feed_condition_.notify_one(); return true; } diff --git a/speechx/speechx/frontend/audio_cache.h b/speechx/speechx/frontend/audio_cache.h new file mode 100644 index 000000000..b6c82c69e --- /dev/null +++ b/speechx/speechx/frontend/audio_cache.h @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + +#include "base/common.h" +#include "frontend/feature_extractor_interface.h" + +namespace ppspeech { + +// waves cache +class AudioCache : public FeatureExtractorInterface { + public: + explicit AudioCache(int buffer_size = kint16max); + + virtual void Accept(const kaldi::VectorBase& waves); + + virtual bool Read(kaldi::Vector* waves); + + // the audio dim is 1, one sample + virtual size_t Dim() const { return 1; } + + virtual void SetFinished() { + std::lock_guard lock(mutex_); + finished_ = true; + } + + virtual bool IsFinished() const { return finished_; } + + virtual void Reset() { + offset_ = 0; + size_ = 0; + finished_ = false; + } + + private: + std::vector ring_buffer_; + size_t offset_; // offset in ring_buffer_ + size_t size_; // samples in ring_buffer_ now + size_t capacity_; // capacity of ring_buffer_ + bool finished_; // reach audio end + mutable std::mutex mutex_; + std::condition_variable ready_feed_condition_; + kaldi::int32 timeout_; // millisecond + + DISALLOW_COPY_AND_ASSIGN(AudioCache); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/raw_audio.h b/speechx/speechx/frontend/data_cache.h similarity index 54% rename from speechx/speechx/frontend/raw_audio.h rename to speechx/speechx/frontend/data_cache.h index 7a28f2c98..dea51d76e 100644 --- a/speechx/speechx/frontend/raw_audio.h +++ b/speechx/speechx/frontend/data_cache.h @@ -15,51 +15,22 @@ #pragma once + #include "base/common.h" #include "frontend/feature_extractor_interface.h" -#pragma once namespace ppspeech { - -class RawAudioCache : public FeatureExtractorInterface { +// A data source for testing different frontend module. +// It accepts waves or feats. +class DataCache : public FeatureExtractorInterface { public: - explicit RawAudioCache(int buffer_size = kint16max); - virtual void Accept(const kaldi::VectorBase& waves); - virtual bool Read(kaldi::Vector* waves); - // the audio dim is 1 - virtual size_t Dim() const { return 1; } - virtual void SetFinished() { - std::lock_guard lock(mutex_); - finished_ = true; - } - virtual bool IsFinished() const { return finished_; } - virtual void Reset() { - start_ = 0; - data_length_ = 0; - finished_ = false; - } - - private: - std::vector ring_buffer_; - size_t start_; - size_t data_length_; - bool finished_; - mutable std::mutex mutex_; - std::condition_variable ready_feed_condition_; - kaldi::int32 timeout_; - - DISALLOW_COPY_AND_ASSIGN(RawAudioCache); -}; + explicit DataCache() { finished_ = false; } -// it is a datasource for testing different frontend module. -// it accepts waves or feats. -class RawDataCache : public FeatureExtractorInterface { - public: - explicit RawDataCache() { finished_ = false; } virtual void Accept(const kaldi::VectorBase& inputs) { data_ = inputs; } + virtual bool Read(kaldi::Vector* feats) { if (data_.Dim() == 0) { return false; @@ -80,7 +51,6 @@ class RawDataCache : public FeatureExtractorInterface { bool finished_; int32 dim_; - DISALLOW_COPY_AND_ASSIGN(RawDataCache); + DISALLOW_COPY_AND_ASSIGN(DataCache); }; - -} // namespace ppspeech +} \ No newline at end of file From cb39777a60e53cfbac4dd2382a28ce7ce10c8cef Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 31 Mar 2022 12:24:23 +0000 Subject: [PATCH 2/5] format code --- paddlespeech/s2t/frontend/speech.py | 7 ++++- paddlespeech/server/bin/main.py | 2 +- .../server/engine/asr/online/asr_engine.py | 27 +++++++++---------- paddlespeech/server/utils/buffer.py | 8 +++--- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/paddlespeech/s2t/frontend/speech.py b/paddlespeech/s2t/frontend/speech.py index 0340831a6..969971047 100644 --- a/paddlespeech/s2t/frontend/speech.py +++ b/paddlespeech/s2t/frontend/speech.py @@ -108,7 +108,12 @@ class SpeechSegment(AudioSegment): token_ids) @classmethod - def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None): + def from_pcm(cls, + samples, + sample_rate, + transcript, + tokens=None, + token_ids=None): """Create speech segment from pcm on online mode Args: samples (numpy.ndarray): Audio samples [num_samples x num_channels]. diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py index 45ded33d8..81824c85c 100644 --- a/paddlespeech/server/bin/main.py +++ b/paddlespeech/server/bin/main.py @@ -18,8 +18,8 @@ from fastapi import FastAPI from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.restful.api import setup_router as setup_http_router -from paddlespeech.server.ws.api import setup_router as setup_ws_router from paddlespeech.server.utils.config import get_config +from paddlespeech.server.ws.api import setup_router as setup_ws_router app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index d5c1aa7bd..389175a0a 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,29 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import os -import time from typing import Optional -import pickle -import numpy as np -from numpy import float32 -import soundfile +import numpy as np import paddle +from numpy import float32 from yacs.config import CfgNode -from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine -from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.paddle_predictor import init_predictor -from paddlespeech.server.utils.paddle_predictor import run_model __all__ = ['ASREngine'] @@ -141,10 +135,10 @@ class ASRServerExecutor(ASRExecutor): reduction=True, # sum batch_average=True, # sum / batch_size grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - + # init decoder cfg = self.config.decode - decode_batch_size = 1 # for online + decode_batch_size = 1 # for online self.decoder.init_decoder( decode_batch_size, self.text_feature.vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, @@ -182,10 +176,11 @@ class ASRServerExecutor(ASRExecutor): Returns: [type]: [description] """ - if "deepspeech2online" in model_type : + if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) + audio_len_handle = self.am_predictor.get_input_handle( + input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) @@ -203,7 +198,8 @@ class ASRServerExecutor(ASRExecutor): output_names = self.am_predictor.get_output_names() output_handle = self.am_predictor.get_output_handle(output_names[0]) - output_lens_handle = self.am_predictor.get_output_handle(output_names[1]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) output_state_h_handle = self.am_predictor.get_output_handle( output_names[2]) output_state_c_handle = self.am_predictor.get_output_handle( @@ -341,7 +337,8 @@ class ASREngine(BaseEngine): x_chunk_lens (numpy.array): shape[B] decoder_chunk_size(int) """ - self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type) + self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, + self.config.model_type) def postprocess(self): """postprocess diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index 4c1a3958a..682357b34 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -43,10 +43,10 @@ class ChunkBuffer(object): audio = self.remained_audio + audio self.remained_audio = b'' - n = int(self.sample_rate * - (self.frame_duration_ms / 1000.0) * self.sample_width) - shift_n = int(self.sample_rate * - (self.shift_ms / 1000.0) * self.sample_width) + n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) * + self.sample_width) + shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) * + self.sample_width) offset = 0 timestamp = 0.0 duration = (float(n) / self.sample_rate) / self.sample_width From 92d699c1f820090ae5575f37267c464ba8b4081d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 1 Apr 2022 03:02:14 +0000 Subject: [PATCH 3/5] fix raw data --- .../examples/decoder/offline_decoder_sliding_chunk_main.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc index ad72b7723..27bd7b1bc 100644 --- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -17,7 +17,7 @@ #include "base/flags.h" #include "base/log.h" #include "decoder/ctc_beam_search_decoder.h" -#include "frontend/raw_audio.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" @@ -57,8 +57,8 @@ int main(int argc, char* argv[]) { model_opts.cache_shape = "5-1-1024,5-1-1024"; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::RawDataCache()); + std::shared_ptr raw_data( + new ppspeech::DataCache()); std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); From 93c3e03bc846053889c823a349d7921686a329c3 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 1 Apr 2022 07:49:53 +0000 Subject: [PATCH 4/5] more comment --- speechx/examples/decoder/decoder_test_main.cc | 7 ++- .../examples/decoder/offline_decoder_main.cc | 2 +- .../offline_decoder_sliding_chunk_main.cc | 48 ++++++++++++++----- speechx/speechx/nnet/decodable.cc | 2 +- 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/speechx/examples/decoder/decoder_test_main.cc b/speechx/examples/decoder/decoder_test_main.cc index 79fe63fcd..0e249cc6b 100644 --- a/speechx/examples/decoder/decoder_test_main.cc +++ b/speechx/examples/decoder/decoder_test_main.cc @@ -24,11 +24,11 @@ DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(lm_path, "lm.klm", "language model"); - using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; +// test decoder by feeding nnet posterior probability int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -37,6 +37,8 @@ int main(int argc, char* argv[]) { FLAGS_nnet_prob_respecifier); std::string dict_file = FLAGS_dict_file; std::string lm_path = FLAGS_lm_path; + LOG(INFO) << "dict path: " << dict_file; + LOG(INFO) << "lm path: " << lm_path; int32 num_done = 0, num_err = 0; @@ -53,6 +55,9 @@ int main(int argc, char* argv[]) { for (; !likelihood_reader.Done(); likelihood_reader.Next()) { string utt = likelihood_reader.Key(); const kaldi::Matrix likelihood = likelihood_reader.Value(); + LOG(INFO) << "process utt: " << utt; + LOG(INFO) << "rows: " << likelihood.NumRows(); + LOG(INFO) << "cols: " << likelihood.NumCols(); decodable->Acceptlikelihood(likelihood); decoder.AdvanceDecode(decodable); std::string result; diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc index c73d59682..6bd83b9b1 100644 --- a/speechx/examples/decoder/offline_decoder_main.cc +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -34,6 +34,7 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; +// test decoder by feeding speech feature, deprecated. int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -55,7 +56,6 @@ int main(int argc, char* argv[]) { // frontend + nnet is decodable ppspeech::ModelOptions model_opts; - model_opts.cache_shape = "5-1-1024,5-1-1024"; model_opts.model_path = model_graph; model_opts.params_path = model_params; std::shared_ptr nnet( diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc index 27bd7b1bc..4d5ffe145 100644 --- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -27,12 +27,19 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(lm_path, "lm.klm", "language model"); - +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=5) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=5) module downsampling rate."); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; + +// test ds2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -43,6 +50,11 @@ int main(int argc, char* argv[]) { std::string model_params = FLAGS_param_path; std::string dict_file = FLAGS_dict_file; std::string lm_path = FLAGS_lm_path; + LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model param: " << model_params; + LOG(INFO) << "dict path: " << dict_file; + LOG(INFO) << "lm path: " << lm_path; + int32 num_done = 0, num_err = 0; @@ -57,34 +69,44 @@ int main(int argc, char* argv[]) { model_opts.cache_shape = "5-1-1024,5-1-1024"; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::DataCache()); + std::shared_ptr raw_data(new ppspeech::DataCache()); std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); - int32 chunk_size = 7; - int32 chunk_stride = 4; - int32 receptive_field_length = 7; + int32 chunk_size = FLAGS_receptive_field_length; + int32 chunk_stride = FLAGS_downsampling_rate; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; decoder.InitDecoder(); for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); kaldi::Matrix feature = feature_reader.Value(); raw_data->SetDim(feature.NumCols()); + LOG(INFO) << "process utt: " << utt; + LOG(INFO) << "rows: " << feature.NumRows(); + LOG(INFO) << "cols: " << feature.NumCols(); + int32 row_idx = 0; int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ( (feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, feature.NumCols(), kaldi::kCopyData); + int32 ori_feature_len = feature.NumRows(); + if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = + chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + feature.Resize(feature.NumRows() + padding_len, + feature.NumCols(), + kaldi::kCopyData); } int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { kaldi::Vector feature_chunk(chunk_size * feature.NumCols()); - int32 feature_chunk_size = 0; - if ( ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size); + int32 feature_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (feature_chunk_size < receptive_field_length) break; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index cd72bf767..e6315d07a 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -82,7 +82,7 @@ void Decodable::Reset() { if (nnet_ != nullptr) nnet_->Reset(); frame_offset_ = 0; frames_ready_ = 0; - nnet_cache_.Resize(0,0); + nnet_cache_.Resize(0, 0); } } // namespace ppspeech \ No newline at end of file From fea437abe3f6dc548408fcc92db36da769aaaf9b Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 1 Apr 2022 08:19:47 +0000 Subject: [PATCH 5/5] add glog test --- .gitignore | 2 +- .mergify.yml | 2 +- speechx/examples/decoder/local/model.sh | 3 +++ speechx/examples/glog/CMakeLists.txt | 8 ++++++ speechx/examples/glog/README.md | 25 +++++++++++++++++++ .../examples/glog/glog_logtostderr_test.cc | 25 +++++++++++++++++++ speechx/examples/glog/glog_test.cc | 23 +++++++++++++++++ speechx/examples/glog/path.sh | 14 +++++++++++ speechx/examples/glog/run.sh | 22 ++++++++++++++++ 9 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 speechx/examples/decoder/local/model.sh create mode 100644 speechx/examples/glog/CMakeLists.txt create mode 100644 speechx/examples/glog/README.md create mode 100644 speechx/examples/glog/glog_logtostderr_test.cc create mode 100644 speechx/examples/glog/glog_test.cc create mode 100644 speechx/examples/glog/path.sh create mode 100755 speechx/examples/glog/run.sh diff --git a/.gitignore b/.gitignore index e25ec327b..639472001 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store *.pyc .vscode -*log +*.log *.wav *.pdmodel *.pdiparams* diff --git a/.mergify.yml b/.mergify.yml index 6dae66d04..68b248101 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -52,7 +52,7 @@ pull_request_rules: add: ["T2S"] - name: "auto add label=Audio" conditions: - - files~=^audio/ + - files~=^paddleaudio/ actions: label: add: ["Audio"] diff --git a/speechx/examples/decoder/local/model.sh b/speechx/examples/decoder/local/model.sh new file mode 100644 index 000000000..5c609a6cf --- /dev/null +++ b/speechx/examples/decoder/local/model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + + diff --git a/speechx/examples/glog/CMakeLists.txt b/speechx/examples/glog/CMakeLists.txt new file mode 100644 index 000000000..b4b0e6358 --- /dev/null +++ b/speechx/examples/glog/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(glog_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_test.cc) +target_link_libraries(glog_test glog) + + +add_executable(glog_logtostderr_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_test.cc) +target_link_libraries(glog_logtostderr_test glog) \ No newline at end of file diff --git a/speechx/examples/glog/README.md b/speechx/examples/glog/README.md new file mode 100644 index 000000000..996e192e9 --- /dev/null +++ b/speechx/examples/glog/README.md @@ -0,0 +1,25 @@ +# [GLOG](https://rpg.ifi.uzh.ch/docs/glog.html) + +Unless otherwise specified, glog writes to the filename `/tmp/...log...