Merge pull request #2502 from zh794390558/u2pp_export
[s2t] streaming conformer u2 and u2pp jit exportpull/2511/head
commit
c9b0c96b7b
@ -0,0 +1,101 @@
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
cmvn_file:
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 512 # dimension of attention
|
||||
attention_heads: 8
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: swish
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
causal: true
|
||||
use_dynamic_chunk: true
|
||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
||||
use_dynamic_left_chunk: false
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
reverse_weight: 0.0 # unidecoder
|
||||
length_normalized_loss: false
|
||||
init_type: 'kaiming_uniform'
|
||||
|
||||
# https://yaml.org/type/float.html
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/train_l/data.list
|
||||
dev_manifest: data/dev/data.list
|
||||
test_manifest: data/test_meeting/data.list
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
use_streaming_data: True
|
||||
unit_type: 'char'
|
||||
vocab_filepath: data/lang_char/vocab.txt
|
||||
preprocess_config: conf/preprocess.yaml
|
||||
spm_model_prefix: ''
|
||||
feat_dim: 80
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
|
||||
batch_size: 32
|
||||
do_filter: True
|
||||
maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced
|
||||
maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced
|
||||
minlen_in: 10
|
||||
minlen_out: 0
|
||||
minibatches: 0 # for debug
|
||||
batch_count: auto
|
||||
batch_bins: 0
|
||||
batch_frames_in: 0
|
||||
batch_frames_out: 0
|
||||
batch_frames_inout: 0
|
||||
num_workers: 0
|
||||
subsampling_factor: 1
|
||||
num_encs: 1
|
||||
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 26
|
||||
accum_grad: 32
|
||||
global_grad_clip: 5.0
|
||||
dist_sampler: True
|
||||
log_interval: 1
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1.0e-6
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 5000
|
||||
lr_decay: 1.0
|
@ -0,0 +1,100 @@
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
cmvn_file:
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 512 # dimension of attention
|
||||
attention_heads: 8
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: swish
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
causal: true
|
||||
use_dynamic_chunk: true
|
||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
||||
use_dynamic_left_chunk: false
|
||||
# decoder related
|
||||
decoder: bitransformer
|
||||
decoder_conf:
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 3 # the number of decoder blocks
|
||||
r_num_blocks: 3 #only for bitransformer
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.1
|
||||
src_attention_dropout_rate: 0.1
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
reverse_weight: 0.3 # only for bitransformer decoder
|
||||
init_type: 'kaiming_uniform' # !Warning: needed for convergence
|
||||
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/train_l/data.list
|
||||
dev_manifest: data/dev/data.list
|
||||
test_manifest: data/test_meeting/data.list
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
use_stream_data: True
|
||||
vocab_filepath: data/lang_char/vocab.txt
|
||||
unit_type: 'char'
|
||||
preprocess_config: conf/preprocess.yaml
|
||||
spm_model_prefix: ''
|
||||
feat_dim: 80
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
|
||||
batch_size: 32
|
||||
do_filter: True
|
||||
maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced
|
||||
maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced
|
||||
minlen_in: 10
|
||||
minlen_out: 0
|
||||
minibatches: 0 # for debug
|
||||
batch_count: auto
|
||||
batch_bins: 0
|
||||
batch_frames_in: 0
|
||||
batch_frames_out: 0
|
||||
batch_frames_inout: 0
|
||||
num_workers: 0
|
||||
subsampling_factor: 1
|
||||
num_encs: 1
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 150
|
||||
accum_grad: 8
|
||||
global_grad_clip: 5.0
|
||||
dist_sampler: False
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.002
|
||||
weight_decay: 1.0e-6
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
@ -0,0 +1,12 @@
|
||||
beam_size: 10
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: True # simulate streaming inference. Defaults to False.
|
||||
decode_batch_size: 128
|
||||
error_rate_type: cer
|
@ -1,11 +1,12 @@
|
||||
decode_batch_size: 128
|
||||
error_rate_type: cer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
beam_size: 10
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
||||
decode_batch_size: 128
|
||||
error_rate_type: cer
|
||||
|
@ -0,0 +1,59 @@
|
||||
#!/bin/bash
# Decode a single audio file with a U2/U2++ checkpoint and export the
# post-training-quantized model via ${BIN_DIR}/quant.py.
#
# Arguments:
#   $1  config_path          model/training yaml
#   $2  decode_config_path   decoding yaml
#   $3  ckpt_path_prefix     checkpoint path without the ".pdparams" suffix;
#                            also used as the output directory
#   $4  audio_file           wav file to decode
#
# BIN_DIR is not set here; it must be provided by the caller's environment.

if [ $# -ne 4 ]; then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
    # Use 1 like every other failure path below; "-1" is not a valid
    # exit status (bash silently maps it to 255).
    exit 1
fi

# Number of comma-separated entries in CUDA_VISIBLE_DEVICES (0 when unset/empty).
ngpu=$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4

# Fetch the demo wav (skipped by -nc when it is already present).
mkdir -p data
if ! wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/; then
    exit 1
fi

if [ ! -f "${audio_file}" ]; then
    echo "Please input the right audio_file path"
    exit 1
fi

# Set when the config file name contains "chunk_"; not consumed below.
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]]; then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

for type in attention_rescoring; do
    echo "decoding ${type}"
    batch_size=1
    output_dir=${ckpt_prefix}
    mkdir -p "${output_dir}"
    if ! python3 -u "${BIN_DIR}/quant.py" \
        --ngpu "${ngpu}" \
        --config "${config_path}" \
        --decode_cfg "${decode_config_path}" \
        --result_file "${output_dir}/${type}.rsl" \
        --checkpoint_path "${ckpt_prefix}" \
        --opts decode.decoding_method "${type}" \
        --opts decode.decode_batch_size "${batch_size}" \
        --audio_file "${audio_file}"; then
        echo "Failed in evaluation!"
        exit 1
    fi
done
exit 0
|
@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for U2 model."""
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import paddle
|
||||
import soundfile
|
||||
from paddleslim import PTQ
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.audio.transform.transformation import Transformation
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.models.u2 import U2Model
|
||||
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class U2Infer():
    """Decode one audio file with a PTQ-wrapped U2 model and export it.

    ``__init__`` builds the feature pipeline and text featurizer, creates the
    U2 model from ``config``, wraps it with PaddleSlim ``PTQ`` and loads the
    checkpoint weights. ``run`` decodes ``args.audio_file`` once, converts the
    streaming entry points to static graphs and saves the quantized model to
    ``args.export_path``.
    """

    def __init__(self, config, args):
        """Build preprocessing, model and quantizer.

        Args:
            config: yacs ``CfgNode`` holding model, data and ``decode`` settings.
            args: parsed CLI arguments; reads ``audio_file``, ``ngpu`` and
                ``checkpoint_path`` here, ``export_path`` in ``run``.
        """
        self.args = args
        self.config = config
        self.audio_file = args.audio_file

        self.preprocess_conf = config.preprocess_config
        self.preprocess_args = {"train": False}
        self.preprocessing = Transformation(self.preprocess_conf)
        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
        self.text_feature = TextFeaturizer(
            unit_type=config.unit_type,
            vocab=config.vocab_filepath,
            spm_model_prefix=config.spm_model_prefix)

        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')

        # model
        model_conf = config
        with UpdateConfig(model_conf):
            model_conf.input_dim = config.feat_dim
            model_conf.output_dim = self.text_feature.vocab_size
        model = U2Model.from_config(model_conf)
        self.model = model
        self.model.eval()
        self.ptq = PTQ()
        self.model = self.ptq.quantize(model)

        # load model: weights are loaded after the PTQ wrapping above
        params_path = self.args.checkpoint_path + ".pdparams"
        model_dict = paddle.load(params_path)
        self.model.set_state_dict(model_dict)

    def run(self):
        """Decode the audio file, then export the quantized static model."""
        # Use the instance's own path instead of the module-level ``args``
        # global, which only exists when this file is run as a script.
        check(self.audio_file)

        with paddle.no_grad():
            # read: keep only the first channel
            audio, sample_rate = soundfile.read(
                self.audio_file, dtype="int16", always_2d=True)
            audio = audio[:, 0]
            logger.info(f"audio shape: {audio.shape}")

            # fbank
            feat = self.preprocessing(audio, **self.preprocess_args)
            logger.info(f"feat shape: {feat.shape}")

            ilen = paddle.to_tensor(feat.shape[0])
            xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
            decode_config = self.config.decode
            logger.info(f"decode cfg: {decode_config}")
            result_transcripts = self.model.decode(
                xs,
                ilen,
                text_feature=self.text_feature,
                decoding_method=decode_config.decoding_method,
                beam_size=decode_config.beam_size,
                ctc_weight=decode_config.ctc_weight,
                decoding_chunk_size=decode_config.decoding_chunk_size,
                num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                simulate_streaming=decode_config.simulate_streaming,
                reverse_weight=decode_config.reverse_weight)
            rsl = result_transcripts[0][0]
            utt = Path(self.audio_file).name
            logger.info(f"hyp: {utt} {rsl}")
            # print(self.model)
            # print(self.model.forward_encoder_chunk)

            logger.info("-------------start quant ----------------------")
            # NOTE(review): the export shapes below are hard-coded; they match
            # feat_dim: 80 / output_size: 512 / reverse_weight: 0.3 of the
            # u2pp config — confirm before exporting a different config.
            batch_size = 1
            feat_dim = 80
            model_size = 512
            num_left_chunks = -1
            reverse_weight = 0.3
            logger.info(
                f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}"
            )

            # ######################## self.model.forward_encoder_chunk ############
            # input_spec = [
            #     # (T,), int16
            #     paddle.static.InputSpec(shape=[None], dtype='int16'),
            # ]
            # self.model.forward_feature = paddle.jit.to_static(
            #     self.model.forward_feature, input_spec=input_spec)

            ######################### self.model.forward_encoder_chunk ############
            input_spec = [
                # xs, (B, T, D)
                paddle.static.InputSpec(
                    shape=[batch_size, None, feat_dim], dtype='float32'),
                # offset, int, but need be tensor
                paddle.static.InputSpec(shape=[1], dtype='int32'),
                # required_cache_size, int
                num_left_chunks,
                # att_cache
                paddle.static.InputSpec(
                    shape=[None, None, None, None], dtype='float32'),
                # cnn_cache
                paddle.static.InputSpec(
                    shape=[None, None, None, None], dtype='float32')
            ]
            self.model.forward_encoder_chunk = paddle.jit.to_static(
                self.model.forward_encoder_chunk, input_spec=input_spec)

            ######################### self.model.ctc_activation ########################
            input_spec = [
                # encoder_out, (B,T,D)
                paddle.static.InputSpec(
                    shape=[batch_size, None, model_size], dtype='float32')
            ]
            self.model.ctc_activation = paddle.jit.to_static(
                self.model.ctc_activation, input_spec=input_spec)

            ######################### self.model.forward_attention_decoder ########################
            input_spec = [
                # hyps, (B, U)
                paddle.static.InputSpec(shape=[None, None], dtype='int64'),
                # hyps_lens, (B,)
                paddle.static.InputSpec(shape=[None], dtype='int64'),
                # encoder_out, (B,T,D)
                paddle.static.InputSpec(
                    shape=[batch_size, None, model_size], dtype='float32'),
                reverse_weight
            ]
            self.model.forward_attention_decoder = paddle.jit.to_static(
                self.model.forward_attention_decoder, input_spec=input_spec)
            ################################################################################

            # jit save
            logger.info(f"export save: {self.args.export_path}")
            self.ptq.save_quantized_model(self.model, self.args.export_path)
            # paddle.jit.save(
            #     self.model,
            #     self.args.export_path,
            #     combine_params=True,
            #     skip_forward=True)
|
||||
|
||||
|
||||
def check(audio_file):
    """Verify that ``audio_file`` exists, can be read, and is sampled at 16 kHz.

    Exits the process with status -1 when the path is not a file or when
    ``soundfile`` cannot read it; fails an assertion when the sample rate is
    not 16000.
    """
    file_exists = os.path.isfile(audio_file)
    if not file_exists:
        print("Please input the right audio file path")
        sys.exit(-1)

    logger.info("checking the audio file format......")
    try:
        _signal, sample_rate = soundfile.read(audio_file)
    except Exception as err:
        logger.error(str(err))
        logger.error(
            "can not open the wav file, please check the audio file format")
        sys.exit(-1)
    else:
        # Only reached when the file was decoded successfully.
        logger.info("The sample rate is %d" % sample_rate)
        assert (sample_rate == 16000)
        logger.info("The audio file format is right")
|
||||
|
||||
|
||||
def main(config, args):
    """Construct a ``U2Infer`` from ``config`` and ``args`` and run it."""
    runner = U2Infer(config, args)
    runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Extend the shared s2t argument parser with the options this script needs.
    parser = default_argument_parser()
    # save asr result to
    parser.add_argument(
        "--result_file", type=str, help="path of save the asr result")
    parser.add_argument(
        "--audio_file", type=str, help="path of the input audio file")
    parser.add_argument(
        "--export_path",
        type=str,
        default='export',
        # The value is handed to PTQ.save_quantized_model in U2Infer.run.
        help="path to save the exported quantized model")
    args = parser.parse_args()

    config = CfgNode(new_allowed=True)

    # Later sources override earlier ones: config file, then the decode
    # config (stored under ``config.decode``), then ``--opts`` overrides.
    if args.config:
        config.merge_from_file(args.config)
    if args.decode_cfg:
        decode_confs = CfgNode(new_allowed=True)
        decode_confs.merge_from_file(args.decode_cfg)
        config.decode = decode_confs
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    main(config, args)
|
@ -0,0 +1,72 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.audio.compliance import kaldi
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['KaldiFbank']
|
||||
|
||||
|
||||
class KaldiFbank(nn.Layer):
    """Layer wrapper around ``kaldi.fbank`` for a single mono waveform."""

    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            energy_floor=0.0,
            dither=0.0):
        """
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame windows
            energy_floor (float): Floor on energy in Spectrogram computation (absolute)
            dither (float): Dithering constant. Default 0.0
        """
        super().__init__()
        self.fs = fs
        self.n_mels = n_mels
        # Samples per millisecond; converts the sample-based window and
        # shift arguments into milliseconds for kaldi.fbank.
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms
        self.n_frame_shift = n_shift / num_point_ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        # Exactly one closing parenthesis: the original template ended with
        # "))" and produced an unbalanced representation.
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def forward(self, x: paddle.Tensor):
        """
        Args:
            x (paddle.Tensor): shape (Ti).
                Not support: [Time, Channel] and Batch mode.

        Returns:
            paddle.Tensor: (T, D)
        """
        assert x.ndim == 1

        feat = kaldi.fbank(
            x.unsqueeze(0),  # append channel dim, (C, Ti)
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=self.dither,
            energy_floor=self.energy_floor,
            sr=self.fs)

        assert feat.ndim == 2  # (T,D)
        return feat
|
@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
|
||||
import paddlespeech.s2t # noqa: F401
|
||||
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
||||
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
||||
|
||||
# from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
||||
|
||||
|
||||
def reverse_pad_list(ys_pad: paddle.Tensor,
                     ys_lens: paddle.Tensor,
                     pad_value: float = -1.0) -> paddle.Tensor:
    """Reverse the valid prefix of every row and pad the result again.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, paddle.to_tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])
    """
    flipped_rows = []
    for row, length in zip(ys_pad, ys_lens):
        # Keep only the first ``length`` tokens, then reverse them.
        flipped_rows.append(paddle.flip(row[:length], [0]))
    return pad_sequence(flipped_rows, True, pad_value)
|
||||
|
||||
|
||||
def naive_reverse_pad_list_with_sos_eos(r_hyps,
                                        r_hyps_lens,
                                        sos=5000,
                                        eos=5000,
                                        ignore_id=-1):
    """Reference implementation: reverse with ``reverse_pad_list``, then ``add_sos_eos``.

    Rows are reversed and re-padded with ``ignore_id``; the first of the two
    tensors returned by ``add_sos_eos`` is returned.
    """
    reversed_padded = reverse_pad_list(r_hyps, r_hyps_lens, float(ignore_id))
    with_sos, _unused = add_sos_eos(reversed_padded, sos, eos, ignore_id)
    return with_sos
|
||||
|
||||
|
||||
def reverse_pad_list_with_sos_eos(r_hyps,
                                  r_hyps_lens,
                                  sos=5000,
                                  eos=5000,
                                  ignore_id=-1):
    """Reverse each hypothesis and prepend ``sos`` using only tensor ops.

    Loop-free counterpart of::

        r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(ignore_id))
        r_hyps, _ = add_sos_eos(r_hyps, sos, eos, ignore_id)

    Positions past a row's length are filled with ``eos`` and one ``sos``
    column is concatenated in front. ``ignore_id`` is accepted but not used.
    """

    def _gather_along(x, dim, index):
        # Pick x[..., index, ...] along ``dim`` by building full N-d
        # coordinates and handing them to ``paddle.gather_nd``.
        out_shape = index.shape
        flat_index = index.flatten()
        rank = len(x.shape)
        if dim < 0:
            dim = rank + dim
        coords = []
        for axis in range(rank):
            if axis == dim:
                coords.append(flat_index)
                continue
            view_shape = [1] * rank
            view_shape[axis] = x.shape[axis]
            axis_ids = paddle.arange(x.shape[axis], dtype=index.dtype)
            axis_ids = axis_ids.reshape(view_shape)
            coords.append(paddle.expand(axis_ids, out_shape).flatten())
        nd_coords = paddle.transpose(paddle.stack(coords),
                                     [1, 0]).astype("int64")
        return paddle.gather_nd(x, nd_coords).reshape(out_shape)

    longest = paddle.max(r_hyps_lens)
    positions = paddle.arange(0, longest, 1)
    lens_column = r_hyps_lens.unsqueeze(1)
    valid = lens_column > positions  # (beam, max_len)

    # Source position of every output token: len - 1, len - 2, ..., which
    # goes negative past the end of a row, e.g. [0, -1, -2] for length 1.
    source_index = (lens_column - 1) - positions  # (beam, max_len)
    # Zero the out-of-range entries so the gather stays in bounds.
    source_index = source_index * valid

    flipped = _gather_along(r_hyps, 1, source_index)
    # Entries gathered for invalid positions are replaced by eos.
    flipped = paddle.where(valid, flipped, eos)

    batch = flipped.shape[0]
    sos_column = paddle.ones([batch, 1], dtype=flipped.dtype) * sos
    # Result rows look like [sos, t_n, ..., t_1, eos, ...].
    return paddle.concat([sos_column, flipped], axis=1)
|
||||
|
||||
|
||||
class TestU2Model(unittest.TestCase):
    """Check the reverse-padding helpers against hand-written expectations."""

    def setUp(self):
        paddle.set_device('cpu')

        self.sos = 5000
        self.eos = 5000
        self.ignore_id = -1

        # Two hypotheses of length 4 and 5, padded with -1.
        self.hyps = paddle.to_tensor([[1, 2, 3, 4, -1], [1, 2, 3, 4, 5]])
        self.hyps_lens = paddle.to_tensor([4, 5], paddle.int32)

        # Expected reversal without and with the sos/eos tokens.
        self.reverse_hyps = paddle.to_tensor([[4, 3, 2, 1, -1],
                                              [5, 4, 3, 2, 1]])
        self.reverse_hyps_sos_eos = paddle.to_tensor(
            [[self.sos, 4, 3, 2, 1, self.eos], [self.sos, 5, 4, 3, 2, 1]])

    def test_reverse_pad_list(self):
        actual = reverse_pad_list(self.hyps, self.hyps_lens).tolist()
        expected = self.reverse_hyps.tolist()
        self.assertSequenceEqual(actual, expected)

    def test_naive_reverse_pad_list_with_sos_eos(self):
        actual = naive_reverse_pad_list_with_sos_eos(
            self.hyps, self.hyps_lens).tolist()
        expected = self.reverse_hyps_sos_eos.tolist()
        self.assertSequenceEqual(actual, expected)

    def test_static_reverse_pad_list_with_sos_eos(self):
        actual = reverse_pad_list_with_sos_eos(self.hyps,
                                               self.hyps_lens).tolist()
        expected = self.reverse_hyps_sos_eos.tolist()
        self.assertSequenceEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run the test cases of this module when the file is executed directly.
    unittest.main()
|
Loading…
Reference in new issue