From cb39777a60e53cfbac4dd2382a28ce7ce10c8cef Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 31 Mar 2022 12:24:23 +0000 Subject: [PATCH] format code --- paddlespeech/s2t/frontend/speech.py | 7 ++++- paddlespeech/server/bin/main.py | 2 +- .../server/engine/asr/online/asr_engine.py | 27 +++++++++---------- paddlespeech/server/utils/buffer.py | 8 +++--- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/paddlespeech/s2t/frontend/speech.py b/paddlespeech/s2t/frontend/speech.py index 0340831a6..969971047 100644 --- a/paddlespeech/s2t/frontend/speech.py +++ b/paddlespeech/s2t/frontend/speech.py @@ -108,7 +108,12 @@ class SpeechSegment(AudioSegment): token_ids) @classmethod - def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None): + def from_pcm(cls, + samples, + sample_rate, + transcript, + tokens=None, + token_ids=None): """Create speech segment from pcm on online mode Args: samples (numpy.ndarray): Audio samples [num_samples x num_channels]. diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py index 45ded33d8..81824c85c 100644 --- a/paddlespeech/server/bin/main.py +++ b/paddlespeech/server/bin/main.py @@ -18,8 +18,8 @@ from fastapi import FastAPI from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.restful.api import setup_router as setup_http_router -from paddlespeech.server.ws.api import setup_router as setup_ws_router from paddlespeech.server.utils.config import get_config +from paddlespeech.server.ws.api import setup_router as setup_ws_router app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index d5c1aa7bd..389175a0a 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,29 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import os -import time from typing import Optional -import pickle -import numpy as np -from numpy import float32 -import soundfile +import numpy as np import paddle +from numpy import float32 from yacs.config import CfgNode -from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine -from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.paddle_predictor import init_predictor -from paddlespeech.server.utils.paddle_predictor import run_model __all__ = ['ASREngine'] @@ -141,10 +135,10 @@ class ASRServerExecutor(ASRExecutor): reduction=True, # sum batch_average=True, # sum / batch_size grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - + # init decoder cfg = self.config.decode - decode_batch_size = 1 # for online + decode_batch_size = 1 # for online self.decoder.init_decoder( decode_batch_size, self.text_feature.vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, @@ -182,10 +176,11 @@ class ASRServerExecutor(ASRExecutor): Returns: [type]: [description] """ - if "deepspeech2online" in model_type : + if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) + audio_len_handle = self.am_predictor.get_input_handle( + input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) @@ -203,7 +198,8 @@ class ASRServerExecutor(ASRExecutor): output_names = self.am_predictor.get_output_names() output_handle = self.am_predictor.get_output_handle(output_names[0]) - output_lens_handle = self.am_predictor.get_output_handle(output_names[1]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) output_state_h_handle = self.am_predictor.get_output_handle( output_names[2]) output_state_c_handle = self.am_predictor.get_output_handle( @@ -341,7 +337,8 @@ class ASREngine(BaseEngine): x_chunk_lens (numpy.array): shape[B] decoder_chunk_size(int) """ - self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type) + self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, + self.config.model_type) def postprocess(self): """postprocess diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index 4c1a3958a..682357b34 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -43,10 +43,10 @@ class ChunkBuffer(object): audio = self.remained_audio + audio self.remained_audio = b'' - n = int(self.sample_rate * - (self.frame_duration_ms / 1000.0) * self.sample_width) - shift_n = int(self.sample_rate * - (self.shift_ms / 1000.0) * self.sample_width) + n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) * + self.sample_width) + shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) * + self.sample_width) offset = 0 timestamp = 0.0 duration = (float(n) / self.sample_rate) / self.sample_width