Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into refactor_decoder

pull/1633/head
Yang Zhou 4 years ago
commit 9b7189404f

@ -4,6 +4,7 @@
The directory contains many speech applications for multiple scenarios.
* audio searching - mass audio similarity retrieval
* audio tagging - multi-label tagging of an audio file
* automatic_video_subtitiles - generate subtitles from a video
* metaverse - 2D AR with TTS

@ -4,6 +4,7 @@
This directory contains demos of PaddleSpeech-based speech applications for different scenarios.
* audio searching - similarity retrieval over massive audio collections.
* audio tagging - multi-label audio classification over the 527 AudioSet labels.
* automatic video subtitles - recognize the speech in a video as text and post-process it.
* metaverse - 2D augmented reality driven by speech synthesis.

@ -64,7 +64,7 @@ services:
webclient:
container_name: audio-webclient
image: paddlepaddle/paddlespeech-audio-search-client:2.3
networks:
app_net:
ipv4_address: 172.16.23.13

@ -11,11 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import librosa
import numpy as np
from config import DEFAULT_TABLE
from logs import LOGGER
from paddlespeech.cli import VectorExecutor

@ -23,7 +23,3 @@ process:
n_mask: 2
inplace: true
replace_with_zero: false

@ -227,7 +227,9 @@ Pretrained FastSpeech2 model with no silence in the edge of audios:
- [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)
The static model can be downloaded here:
- [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip)
- [fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)

Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:

@ -208,6 +208,18 @@ class AudioSegment():
io.BytesIO(bytes), dtype='float32')
return cls(samples, sample_rate)
@classmethod
def from_pcm(cls, samples, sample_rate):
"""Create audio segment from a numpy.ndarray of PCM samples.
:param samples: Audio samples [num_samples x num_channels].
:type samples: numpy.ndarray
:param sample_rate: Audio sample rate.
:type sample_rate: int
:return: Audio segment instance.
:rtype: AudioSegment
"""
return cls(samples, sample_rate)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of audio segments together.

@ -107,6 +107,22 @@ class SpeechSegment(AudioSegment):
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
"""Create a speech segment from PCM samples in online (streaming) mode.
Args:
samples (numpy.ndarray): Audio samples [num_samples x num_channels].
sample_rate (int): Audio sample rate.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio = AudioSegment.from_pcm(samples, sample_rate)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of speech segments together, both

@ -14,8 +14,11 @@
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank
import paddleaudio.compliance.kaldi as kaldi
def stft(x,
n_fft,
@ -309,6 +312,77 @@ class IStft():
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.1):
"""
The Kaldi implementation of LogMelSpectrogram
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
win_length (int): number of points in a frame windows
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant
Returns:
LogMelSpectrogramKaldi
"""
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
mat = kaldi.fbank(
waveform,
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=dither,
energy_floor=self.energy_floor,
sr=self.fs)
mat = np.squeeze(mat.numpy())
return mat
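A quick sketch of calling the new kaldi-fbank front end above (the random waveform stands in for real audio and assumes a 16 kHz mono signal):
import numpy as np
frontend = LogMelSpectrogramKaldi(fs=16000, n_mels=80, n_shift=160, win_length=400)
wav = np.random.randn(16000).astype('float32')  # 1 s of placeholder audio
feat = frontend(wav, train=False)  # (T, 80); dither is disabled outside training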
class LogMelSpectrogramKaldi_decay():
def __init__(
self,
fs=16000,

@ -31,6 +31,7 @@ import_alias = dict(
freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask", freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment", spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation", speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation", volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection", noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation", bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",

@ -17,7 +17,8 @@ import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.ws.api import setup_router as setup_ws_router
from paddlespeech.server.utils.config import get_config
app = FastAPI(
@ -35,7 +36,12 @@ def init(config):
""" """
# init api # init api
api_list = list(engine.split("_")[0] for engine in config.engine_list) api_list = list(engine.split("_")[0] for engine in config.engine_list)
api_router = setup_router(api_list) if config.protocol == "websocket":
api_router = setup_ws_router(api_list)
elif config.protocol == "http":
api_router = setup_http_router(api_list)
else:
raise Exception("unsupported protocol")
app.include_router(api_router) app.include_router(api_router)
if not init_engine_pool(config): if not init_engine_pool(config):

@ -150,7 +150,7 @@ class TTSClientExecutor(BaseExecutor):
res = requests.post(url, json.dumps(request))
response_dict = res.json()
if output is not None:
self.postprocess(response_dict["result"]["audio"], output)
return res

@ -8,7 +8,9 @@ port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
# protocol = ['websocket', 'http'] (only one can be selected).
# http only supports offline engine types.
protocol: 'http'
engine_list: ['asr_python', 'tts_python', 'cls_python']
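For reference, a hedged sketch (config path assumed) of how the server reads this switch before choosing a router, mirroring the dispatch added above:
from paddlespeech.server.utils.config import get_config
config = get_config("./conf/application.yaml")  # assumed location of this file
assert config.protocol in ("http", "websocket")
print(config.protocol, config.engine_list)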
@ -48,6 +50,24 @@ asr_inference:
summary: True # False -> do not show predictor config
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
################################### TTS #########################################
################### speech task: tts; engine_type: python #######################
tts_python:

@ -0,0 +1,51 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only supports online engine types.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 80
shift_ms: 40
sample_rate: 16000
sample_width: 2
vad_conf:
aggressiveness: 2
sample_rate: 16000
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,355 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import pickle
import numpy as np
from numpy import float32
import soundfile
import paddle
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'd5e076217cf60486519f72c217d21b9b',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
}
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
pass
def _init_from_path(self,
model_type: str='wenetspeech',
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path,
pretrained_models[tag]['model'])
self.am_params = os.path.join(res_path,
pretrained_models[tag]['params'])
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
raise Exception("wrong type")
else:
raise Exception("wrong type")
# AM predictor
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
# decoder
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio
"""
self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
str: best-path decoding result of the audio decoded so far
"""
if "deepspeech2online" in model_type :
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
raise Exception("invalid model name")
else:
raise Exception("invalid model name")
def _pcm16to32(self, audio):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if audio.dtype == np.int16:
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
return audio
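# Worked example (illustrative): with 16-bit input the scale factor is
# 2**(16 - 1) = 32768, so an int16 sample of 16384 becomes 16384 / 32768 = 0.5.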
def extract_feat(self, samples, sample_rate):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
# pcm16 -> pcm 32
samples = self._pcm16to32(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
class ASREngine(BaseEngine):
"""ASR server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config (dict): engine config read from the yaml file
Returns:
bool: init failed or success
"""
self.input = None
self.output = ""
self.executor = ASRServerExecutor()
self.config = config
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.")
return True
def preprocess(self, samples, sample_rate):
"""preprocess
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
"""run online engine
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
def postprocess(self):
"""postprocess
"""
return self.output
def reset(self):
"""reset engine decoder and inference state
"""
self.executor.reset_decoder_and_chunk()
self.output = ""

@ -25,6 +25,9 @@ class EngineFactory(object):
elif engine_name == 'asr' and engine_type == 'python':
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
return TTSEngine()

@ -0,0 +1,161 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
record wave from the mic
"""
import asyncio
import json
import logging
import threading
import wave
from signal import SIGINT
from signal import SIGTERM
import pyaudio
import websockets
class ASRAudioHandler(threading.Thread):
def __init__(self, url="127.0.0.1", port=8091):
threading.Thread.__init__(self)
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
self.fileName = "./output.wav"
self.chunk = 5120
self.format = pyaudio.paInt16
self.channels = 1
self.rate = 16000
self._running = True
self._frames = []
self.data_backup = []
def startrecord(self):
"""
start a new thread to record wave
"""
threading._start_new_thread(self.recording, ())
def recording(self):
"""
recording wave
"""
self._running = True
self._frames = []
p = pyaudio.PyAudio()
stream = p.open(
format=self.format,
channels=self.channels,
rate=self.rate,
input=True,
frames_per_buffer=self.chunk)
while (self._running):
data = stream.read(self.chunk)
self._frames.append(data)
self.data_backup.append(data)
stream.stop_stream()
stream.close()
p.terminate()
def save(self):
"""
save wave data
"""
p = pyaudio.PyAudio()
wf = wave.open(self.fileName, 'wb')
wf.setnchannels(self.channels)
wf.setsampwidth(p.get_sample_size(self.format))
wf.setframerate(self.rate)
wf.writeframes(b''.join(self.data_backup))
wf.close()
p.terminate()
def stoprecord(self):
"""
stop recording
"""
self._running = False
async def run(self):
aa = input("是否开始录音? (y/n)")
if aa.strip() == "y":
self.startrecord()
logging.info("*" * 10 + "开始录音,请输入语音")
async with websockets.connect(self.url) as ws:
# send the start signal
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "start",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# send bytes data
logging.info("结束录音请: Ctrl + c。继续请按回车。")
try:
while True:
while len(self._frames) > 0:
await ws.send(self._frames.pop(0))
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
except asyncio.CancelledError:
# quit
# send finished
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "end",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
self.stoprecord()
logging.info("*" * 10 + "录音结束")
self.save()
elif aa.strip() == "n":
exit()
else:
print("无效输入!")
exit()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(handler.run())
for signal in [SIGINT, SIGTERM]:
loop.add_signal_handler(signal, main_task.cancel)
try:
loop.run_until_complete(main_task)
finally:
loop.close()
logging.info("asr websocket client finished")

@ -0,0 +1,115 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import json
import logging
import numpy as np
import soundfile
import websockets
class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090):
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
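# Worked example (illustrative): for a 1 s, 16 kHz wav, x_len = 16000,
# chunk_size = 1280 (80 ms) and chunk_stride = 640 (40 ms), so
# (16000 - 1280) % 640 == 0, no padding is added, and
# num_chunk = (16000 - 1280) / 640 + 1 = 24 chunks are yielded.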
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# read the audio
# self.read_wave()
# send the websocket handshake
async with websockets.connect(self.url) as ws:
# the server has accepted the handshake
# send the start signal
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "start",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# finished
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "end",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--wavfile",
action="store",
help="wav file path ",
default="./16_audio.wav")
args = parser.parse_args()
main(args)

@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
self.duration = duration
class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_width = sample_width # int16 = 2; float32 = 4
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate *
(self.frame_duration_ms / 1000.0) * self.sample_width)
shift_n = int(self.sample_rate *
(self.shift_ms / 1000.0) * self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n
self.remained_audio += audio[offset:]
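A small usage sketch of the buffer (the half second of silence is made up for illustration):
buf = ChunkBuffer(frame_duration_ms=80, shift_ms=40, sample_rate=16000, sample_width=2)
pcm = b'\x00\x00' * 8000  # 0.5 s of int16 silence
frames = list(buf.frame_generator(pcm))  # 11 frames of 2560 bytes (80 ms), hopping 1280 bytes (40 ms)
# the unconsumed tail stays in buf.remained_audio for the next call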

@ -0,0 +1,78 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import webrtcvad
class VADAudio():
def __init__(self,
aggressiveness=2,
rate=16000,
frame_duration_ms=20,
sample_width=2,
padding_ms=200,
padding_ratio=0.9):
"""Initializes VAD with given aggressivenes and sets up internal queues"""
self.vad = webrtcvad.Vad(aggressiveness)
self.rate = rate
self.sample_width = sample_width
self.frame_duration_ms = frame_duration_ms
self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
self.sample_width)
self._buffer_queue = collections.deque()
self.ring_buffer = collections.deque(maxlen=padding_ms //
frame_duration_ms)
self._ratio = padding_ratio
self.triggered = False
def add_audio(self, audio):
"""Adds new audio to internal queue"""
for x in audio:
self._buffer_queue.append(x)
def frame_generator(self):
"""Generator that yields audio frames of frame_duration_ms"""
while len(self._buffer_queue) > self._frame_length:
frame = bytearray()
for _ in range(self._frame_length):
frame.append(self._buffer_queue.popleft())
yield bytes(frame)
def vad_collector(self):
"""Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None.
Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
|---utterence---| |---utterence---|
"""
for frame in self.frame_generator():
is_speech = self.vad.is_speech(frame, self.rate)
if not self.triggered:
self.ring_buffer.append((frame, is_speech))
num_voiced = len(
[f for f, speech in self.ring_buffer if speech])
if num_voiced > self._ratio * self.ring_buffer.maxlen:
self.triggered = True
for f, s in self.ring_buffer:
yield f
self.ring_buffer.clear()
else:
yield frame
self.ring_buffer.append((frame, is_speech))
num_unvoiced = len(
[f for f, speech in self.ring_buffer if not speech])
if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
self.triggered = False
yield None
self.ring_buffer.clear()
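A short usage sketch mirroring how the websocket endpoint below drives the VAD (pcm_bytes is a hypothetical chunk of int16 mono audio at 16 kHz):
vad = VADAudio(aggressiveness=2, rate=16000, frame_duration_ms=20)
vad.add_audio(pcm_bytes)
voiced = b''.join(f for f in vad.vad_collector() if f is not None)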

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,38 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router
_router = APIRouter()
def setup_router(api_list: List):
"""setup router for fastapi
Args:
api_list (List): [asr, tts]
Returns:
APIRouter
"""
for api_name in api_list:
if api_name == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
pass
else:
pass
return _router

@ -0,0 +1,100 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
router = APIRouter()
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
vad_conf = asr_engine.config.vad_conf
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try:
while True:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
if "text" in message:
message = json.loads(message["text"])
if 'signal' not in message:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset the engine for a new connection
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
await websocket.send_json(resp)
break
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

@ -73,15 +73,21 @@ class LengthRegulator(nn.Layer):
batch_size, t_enc = paddle.shape(durations)
slens = paddle.sum(durations, -1)
t_dec = paddle.max(slens)
t_dec_1 = t_dec + 1
flatten_duration = paddle.cumsum(
paddle.reshape(durations, [batch_size * t_enc])) + 1
init = paddle.zeros(t_dec_1)
m_batch = batch_size * t_enc
M = paddle.zeros([t_dec_1, m_batch])
for i in range(m_batch):
d = flatten_duration[i]
m = paddle.concat(
[paddle.ones(d), paddle.zeros(t_dec_1 - d)], axis=0)
M[:, i] = m - init
init = m
M = paddle.reshape(M, shape=[t_dec_1, batch_size, t_enc])
M = M[1:, :, :]
M = paddle.transpose(M, (1, 0, 2))
encodings = paddle.matmul(M, encodings)
return encodings
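For intuition, a hand-worked example of what the rewritten regulator computes: with durations = [[2, 1]] the flattened cumulative sums are [3, 4], so after the column differences, reshape, slice and transpose, the expansion matrix is M[0] = [[1, 0], [1, 0], [0, 1]], and paddle.matmul(M, encodings) repeats the first encoder frame twice and the second frame once, matching the duration-based expansion the layer is meant to perform.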

@ -5,7 +5,7 @@
We develop under:
* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7
* os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0
> We make sure everything works fine under docker, and recommend using it to develop and deploy.
@ -24,7 +24,7 @@ nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspac
* More `Paddle` docker images can be found [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
* If you only want to work on CPU, please download the corresponding [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and use `docker` instead of `nvidia-docker`.
2. Build `speechx` and `examples`.

@ -13,7 +13,7 @@ ExternalProject_Add(openfst
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread" "LIBS=-lgflags_nothreads -lglog -lpthread"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4 BUILD_COMMAND make -j 4
) )
link_directories(${openfst_PREFIX_DIR}/lib) link_directories(${openfst_PREFIX_DIR}/lib)

@ -3,3 +3,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(glog)

@ -1,8 +1,9 @@
# Examples
* glog - glog usage
* feat - mfcc, linear
* nnet - ds2 nn
* decoder - online decoder to work as offline

## How to run

@ -22,11 +22,12 @@
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test feature rspecifier"); DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model"); DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(chunk_size, 35, "feat chunk size");
using kaldi::BaseFloat; using kaldi::BaseFloat;
@ -43,14 +44,16 @@ int main(int argc, char* argv[]) {
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 chunk_size = FLAGS_chunk_size;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
LOG(INFO) << "chunk size (frame): " << chunk_size;
int32 num_done = 0, num_err = 0;
// frontend + nnet is decodable
ppspeech::ModelOptions model_opts;
model_opts.cache_shape = "5-1-1024,5-1-1024";
model_opts.model_path = model_graph;
@ -61,33 +64,50 @@ int main(int argc, char* argv[]) {
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
// init decoder
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
LOG(INFO) << "Init decoder.";
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
LOG(INFO) << "utt: " << utt;
// feat dim
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "dim: " << raw_data->Dim();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
LOG(INFO) << "n chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
// feat chunk
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(feat_one_row);
row_idx++;
}
// feed to raw cache
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// decode step
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;

@ -25,7 +25,10 @@ model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
export GLOG_logtostderr=1
# 3. gen linear feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \

@ -41,7 +41,6 @@
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";

@ -25,6 +25,8 @@
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include <glog/logging.h>
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
@ -149,7 +151,7 @@ void WriteMatrix() {
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, false);
}
int main(int argc, char* argv[]) {
@ -161,43 +163,56 @@ int main(int argc, char* argv[]) {
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
WriteMatrix();
int32 num_done = 0, num_err = 0;
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// ppspeech::RawDataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
new ppspeech::RawAudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = 0.36;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
@ -209,6 +224,7 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {

@ -25,6 +25,7 @@ feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
export GLOG_logtostderr=1
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \

@ -28,10 +28,10 @@ class FeatureCache : public FeatureExtractorInterface {
// Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() {

@ -68,9 +68,10 @@ class RawDataCache : public FeatureExtractorInterface {
data_.Resize(0);
return true;
}
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
virtual size_t Dim() const { return dim_; }
void SetDim(int32 dim) { dim_ = dim; }
virtual void Reset() { finished_ = true; }

@ -127,7 +127,7 @@ decoders_module = [
setup(
name='paddlespeech_ctcdecoders',
version='0.2.0',
description="CTC decoders in paddlespeech",
author="PaddlePaddle Speech and Language Team",
author_email="paddlesl@baidu.com",
