Merge pull request #2502 from zh794390558/u2pp_export

[s2t]  streaming conformer u2 and u2pp jit export
pull/2511/head
Zth9730 2 years ago committed by GitHub
commit c9b0c96b7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,3 +12,36 @@ show model.tar.gz
``` ```
tar tf model.tar.gz tar tf model.tar.gz
``` ```
other way is:
```bash
tar cvzf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz model.yaml conf/tuning/ conf/chunk_conformer.yaml conf/preprocess.yaml data/mean_std.json exp/chunk_conformer/checkpoints/
```
## Export Static Model
>> Need Paddle >= 2.4
>> `data/test_meeting/data.list`
>> {"input": [{"name": "input1", "shape": [3.2230625, 80], "feat": "/home/PaddleSpeech/dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0163.wav", "filetype": "sound"}], "output": [{"name": "target1", "shape": [9, 5538], "text": "\u697c\u5e02\u8c03\u63a7\u5c06\u53bb\u5411\u4f55\u65b9", "token": "\u697c \u5e02 \u8c03 \u63a7 \u5c06 \u53bb \u5411 \u4f55 \u65b9", "tokenid": "1891 1121 3502 1543 1018 477 528 163 1657"}], "utt": "BAC009S0764W0163", "utt2spk": "S0764"}
>> Test Wav:
>> wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
### U2 chunk conformer
>> UiDecoder
>> Make sure `reverse_weight` in config is `0.0`
>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz
```
tar zxvf asr1_chunk_conformer_u2_wenetspeech_ckpt_1.1.0.model.tar.gz
./local/export.sh conf/chunk_conformer.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.ji
```
### U2++ chunk conformer
>> BiDecoder
>> https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.0.model.tar.gz
>> Make sure `reverse_weight` in config is not `0.0`
```
./local/export.sh conf/chunk_conformer_u2pp.yaml exp/chunk_conformer/checkpoints/avg_10 ./export.ji
```

@ -0,0 +1,101 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 8
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: swish
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.0 # unidecoder
length_normalized_loss: false
init_type: 'kaiming_uniform'
# https://yaml.org/type/float.html
###########################################
# Data #
###########################################
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list
###########################################
# Dataloader #
###########################################
use_streaming_data: True
unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
preprocess_config: conf/preprocess.yaml
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
do_filter: True
maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced
minlen_in: 10
minlen_out: 0
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 26
accum_grad: 32
global_grad_clip: 5.0
dist_sampler: True
log_interval: 1
checkpoint:
kbest_n: 50
latest_n: 5
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 5000
lr_decay: 1.0

@ -0,0 +1,100 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 8
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: swish
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer
decoder_conf:
attention_heads: 8
linear_units: 2048
num_blocks: 3 # the number of encoder blocks
r_num_blocks: 3 #only for bitransformer
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
reverse_weight: 0.3 # only for bitransformer decoder
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list
###########################################
# Dataloader #
###########################################
use_stream_data: True
vocab_filepath: data/lang_char/vocab.txt
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
do_filter: True
maxlen_in: 1200 # if do_filter == False && input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 100 # if do_filter == False && output length > maxlen-out, batchsize is automatically reduced
minlen_in: 10
minlen_out: 0
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 150
accum_grad: 8
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: 0.1 dither: 1.0
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -0,0 +1,12 @@
beam_size: 10
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: True # simulate streaming inference. Defaults to False.
decode_batch_size: 128
error_rate_type: cer

@ -1,11 +1,12 @@
decode_batch_size: 128
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
beam_size: 10 beam_size: 10
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk. # <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set. # >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here. # 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False. simulate_streaming: False # simulate streaming inference. Defaults to False.
decode_batch_size: 128
error_rate_type: cer

@ -12,9 +12,14 @@ config_path=$1
ckpt_path_prefix=$2 ckpt_path_prefix=$2
jit_model_export_path=$3 jit_model_export_path=$3
# export can not using StreamdataDataloader, set use_stream_dta False
# u2: reverse_weight should be 0.0
# u2pp: reverse_weight should be same with config file. e.g. 0.3
python3 -u ${BIN_DIR}/export.py \ python3 -u ${BIN_DIR}/export.py \
--ngpu ${ngpu} \ --ngpu ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--opts use_stream_data False \
--checkpoint_path ${ckpt_path_prefix} \ --checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path} --export_path ${jit_model_export_path}

@ -0,0 +1,59 @@
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
if [ $? -ne 0 ]; then
exit 1
fi
if [ ! -f ${audio_file} ]; then
echo "Plase input the right audio_file path"
exit 1
fi
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention_rescoring; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/quant.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size} \
--audio_file ${audio_file}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0

@ -74,16 +74,16 @@ def _feature_window_function(
window_size: int, window_size: int,
blackman_coeff: float, blackman_coeff: float,
dtype: int, ) -> Tensor: dtype: int, ) -> Tensor:
if window_type == HANNING: if window_type == "hann":
return get_window('hann', window_size, fftbins=False, dtype=dtype) return get_window('hann', window_size, fftbins=False, dtype=dtype)
elif window_type == HAMMING: elif window_type == "hamming":
return get_window('hamming', window_size, fftbins=False, dtype=dtype) return get_window('hamming', window_size, fftbins=False, dtype=dtype)
elif window_type == POVEY: elif window_type == "povey":
return get_window( return get_window(
'hann', window_size, fftbins=False, dtype=dtype).pow(0.85) 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR: elif window_type == "rect":
return paddle.ones([window_size], dtype=dtype) return paddle.ones([window_size], dtype=dtype)
elif window_type == BLACKMAN: elif window_type == "blackman":
a = 2 * math.pi / (window_size - 1) a = 2 * math.pi / (window_size - 1)
window_function = paddle.arange(window_size, dtype=dtype) window_function = paddle.arange(window_size, dtype=dtype)
return (blackman_coeff - 0.5 * paddle.cos(a * window_function) + return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
@ -216,7 +216,7 @@ def spectrogram(waveform: Tensor,
sr: int=16000, sr: int=16000,
snip_edges: bool=True, snip_edges: bool=True,
subtract_mean: bool=False, subtract_mean: bool=False,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's. """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
Args: Args:
@ -236,7 +236,7 @@ def spectrogram(waveform: Tensor,
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
Returns: Returns:
Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames
@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int,
('Bad values in options: vtln-low {} and vtln-high {}, versus ' ('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = paddle.arange(num_bins).unsqueeze(1) bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)
# left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
# center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
# right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) center_mel = left_mel + mel_freq_delta
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) right_mel = center_mel + mel_freq_delta
if vtln_warp_factor != 1.0: if vtln_warp_factor != 1.0:
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
@ -373,7 +376,8 @@ def _get_mel_banks(num_bins: int,
center_freqs = _inverse_mel_scale(center_mel) # (num_bins) center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
# (1, num_fft_bins) # (1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) mel = _mel_scale(fft_bin_width * paddle.arange(
num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
# (num_bins, num_fft_bins) # (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel) up_slope = (mel - left_mel) / (center_mel - left_mel)
@ -418,11 +422,11 @@ def fbank(waveform: Tensor,
vtln_high: float=-500.0, vtln_high: float=-500.0,
vtln_low: float=100.0, vtln_low: float=100.0,
vtln_warp: float=1.0, vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's. """Compute and return filter banks from a waveform. The output is identical to Kaldi's.
Args: Args:
waveform (Tensor): A waveform tensor with shape `(C, T)`. waveform (Tensor): A waveform tensor with shape `(C, T)`. `C` is in the range [0,1].
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1. channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0. dither (float, optional): Dithering constant . Defaults to 0.0.
@ -448,7 +452,7 @@ def fbank(waveform: Tensor,
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0. vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0. vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0. vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
Returns: Returns:
Tensor: A filter banks tensor with shape `(m, n_mels)`. Tensor: A filter banks tensor with shape `(m, n_mels)`.
@ -472,7 +476,8 @@ def fbank(waveform: Tensor,
# (n_mels, padded_window_size // 2) # (n_mels, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq, mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
high_freq, vtln_low, vtln_high, vtln_warp) high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.astype(dtype) # mel_energies = mel_energies.astype(dtype)
assert mel_energies.dtype == dtype
# (n_mels, padded_window_size // 2 + 1) # (n_mels, padded_window_size // 2 + 1)
mel_energies = paddle.nn.functional.pad( mel_energies = paddle.nn.functional.pad(
@ -537,7 +542,7 @@ def mfcc(waveform: Tensor,
vtln_high: float=-500.0, vtln_high: float=-500.0,
vtln_low: float=100.0, vtln_low: float=100.0,
vtln_warp: float=1.0, vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is """Compute and return mel frequency cepstral coefficients from a waveform. The output is
identical to Kaldi's. identical to Kaldi's.

@ -152,8 +152,8 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0]) # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])
B = ys_pad.shape[0] B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos _eos = paddle.full([B, 1], eos, dtype=ys_pad.dtype)
ys_in = paddle.cat([_sos, ys_pad], dim=1) ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id) mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos) ys_in = ys_in.masked_fill(mask_pad, eos)
@ -279,8 +279,8 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> tensor([[3, 2, 1], # >>> tensor([[3, 2, 1],
# >>> [4, 8, 9], # >>> [4, 8, 9],
# >>> [2, 2, 2]]) # >>> [2, 2, 2]])
eos = paddle.full([1], eos, dtype=r_hyps.dtype) _eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, eos) r_hyps = paddle.where(seq_mask, r_hyps, _eos)
# >>> r_hyps # >>> r_hyps
# >>> tensor([[3, 2, 1], # >>> tensor([[3, 2, 1],
# >>> [4, 8, 9], # >>> [4, 8, 9],

@ -22,7 +22,6 @@ from paddle.nn import functional as F
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
########### hack logging ############# ########### hack logging #############
@ -167,13 +166,17 @@ def broadcast_shape(shp1, shp2):
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
# will be nan when value is `inf`.
# mask = mask.astype(xs.dtype)
# return xs * (1.0 - mask) + mask * value
bshape = broadcast_shape(xs.shape, mask.shape) bshape = broadcast_shape(xs.shape, mask.shape)
mask.stop_gradient = True mask.stop_gradient = True
tmp = paddle.ones(shape=[len(bshape)], dtype='int32') # tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
for index in range(len(bshape)): # for index in range(len(bshape)):
tmp[index] = bshape[index] # tmp[index] = bshape[index]
mask = mask.broadcast_to(tmp) mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value trues = paddle.full_like(xs, fill_value=value)
xs = paddle.where(mask, trues, xs) xs = paddle.where(mask, trues, xs)
return xs return xs

@ -0,0 +1,224 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import os
import sys
from pathlib import Path
import paddle
import soundfile
from paddleslim import PTQ
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class U2Infer():
def __init__(self, config, args):
self.args = args
self.config = config
self.audio_file = args.audio_file
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
spm_model_prefix=config.spm_model_prefix)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model
model_conf = config
with UpdateConfig(model_conf):
model_conf.input_dim = config.feat_dim
model_conf.output_dim = self.text_feature.vocab_size
model = U2Model.from_config(model_conf)
self.model = model
self.model.eval()
self.ptq = PTQ()
self.model = self.ptq.quantize(model)
# load model
params_path = self.args.checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
def run(self):
check(args.audio_file)
with paddle.no_grad():
# read
audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
# fbank
feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}")
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode
logger.info(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode(
xs,
ilen,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size,
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=decode_config.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}")
# print(self.model)
# print(self.model.forward_encoder_chunk)
logger.info("-------------start quant ----------------------")
batch_size = 1
feat_dim = 80
model_size = 512
num_left_chunks = -1
reverse_weight = 0.3
logger.info(
f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}"
)
# ######################## self.model.forward_encoder_chunk ############
# input_spec = [
# # (T,), int16
# paddle.static.InputSpec(shape=[None], dtype='int16'),
# ]
# self.model.forward_feature = paddle.jit.to_static(
# self.model.forward_feature, input_spec=input_spec)
######################### self.model.forward_encoder_chunk ############
input_spec = [
# xs, (B, T, D)
paddle.static.InputSpec(
shape=[batch_size, None, feat_dim], dtype='float32'),
# offset, int, but need be tensor
paddle.static.InputSpec(shape=[1], dtype='int32'),
# required_cache_size, int
num_left_chunks,
# att_cache
paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32'),
# cnn_cache
paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32')
]
self.model.forward_encoder_chunk = paddle.jit.to_static(
self.model.forward_encoder_chunk, input_spec=input_spec)
######################### self.model.ctc_activation ########################
input_spec = [
# encoder_out, (B,T,D)
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32')
]
self.model.ctc_activation = paddle.jit.to_static(
self.model.ctc_activation, input_spec=input_spec)
######################### self.model.forward_attention_decoder ########################
input_spec = [
# hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# hyps_lens, (B,)
paddle.static.InputSpec(shape=[None], dtype='int64'),
# encoder_out, (B,T,D)
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32'),
reverse_weight
]
self.model.forward_attention_decoder = paddle.jit.to_static(
self.model.forward_attention_decoder, input_spec=input_spec)
################################################################################
# jit save
logger.info(f"export save: {self.args.export_path}")
config = {
'is_static': True,
'combine_params': True,
'skip_forward': True
}
self.ptq.save_quantized_model(self.model, self.args.export_path)
# paddle.jit.save(
# self.model,
# self.args.export_path,
# combine_params=True,
# skip_forward=True)
def check(audio_file):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
sys.exit(-1)
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(audio_file)
except Exception as e:
logger.error(str(e))
logger.error(
"can not open the wav file, please check the audio file format")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
assert (sample_rate == 16000)
logger.info("The audio file format is right")
def main(config, args):
U2Infer(config, args).run()
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--export_path",
type=str,
default='export',
help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args)

@ -20,8 +20,6 @@ from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args): def main_sp(config, args):
exp = Tester(config, args) exp = Tester(config, args)

@ -68,7 +68,6 @@ class U2Infer():
# read # read
audio, sample_rate = soundfile.read( audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True) self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0] audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}") logger.info(f"audio shape: {audio.shape}")
@ -77,8 +76,9 @@ class U2Infer():
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode decode_config = self.config.decode
logger.info(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
ilen, ilen,
@ -88,7 +88,8 @@ class U2Infer():
ctc_weight=decode_config.ctc_weight, ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size, decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks, num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming) simulate_streaming=decode_config.simulate_streaming,
reverse_weight=decode_config.reverse_weight)
rsl = result_transcripts[0][0] rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}") logger.info(f"hyp: {utt} {result_transcripts[0][0]}")

@ -350,7 +350,8 @@ class U2Tester(U2Trainer):
ctc_weight=decode_config.ctc_weight, ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size, decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks, num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming) simulate_streaming=decode_config.simulate_streaming,
reverse_weight=decode_config.reverse_weight)
decode_time = time.time() - start_time decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip( for utt, target, result, rec_tids in zip(
@ -462,20 +463,120 @@ class U2Tester(U2Trainer):
infer_model = U2InferModel.from_pretrained(self.test_loader, infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.clone(), self.config.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
batch_size = 1
feat_dim = self.test_loader.feat_dim feat_dim = self.test_loader.feat_dim
input_spec = [ model_size = self.config.encoder_conf.output_size
paddle.static.InputSpec(shape=[1, None, feat_dim], num_left_chunks = -1
dtype='float32'), # audio, [B,T,D] logger.info(
paddle.static.InputSpec(shape=[1], f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}"
dtype='int64'), # audio_length, [B] )
]
return infer_model, input_spec return infer_model, (batch_size, feat_dim, model_size, num_left_chunks)
@paddle.no_grad() @paddle.no_grad()
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
infer_model.eval() infer_model.eval()
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) paddle.set_device('cpu')
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path) assert isinstance(input_spec, (list, tuple)), type(input_spec)
batch_size, feat_dim, model_size, num_left_chunks = input_spec
######################## infer_model.forward_encoder_chunk ############
input_spec = [
# (T,), int16
paddle.static.InputSpec(shape=[None], dtype='int16'),
]
infer_model.forward_feature = paddle.jit.to_static(
infer_model.forward_feature, input_spec=input_spec)
######################### infer_model.forward_encoder_chunk ############
input_spec = [
# xs, (B, T, D)
paddle.static.InputSpec(
shape=[batch_size, None, feat_dim], dtype='float32'),
# offset, int, but need be tensor
paddle.static.InputSpec(shape=[1], dtype='int32'),
# required_cache_size, int
num_left_chunks,
# att_cache
paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32'),
# cnn_cache
paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32')
]
infer_model.forward_encoder_chunk = paddle.jit.to_static(
infer_model.forward_encoder_chunk, input_spec=input_spec)
######################### infer_model.ctc_activation ########################
input_spec = [
# encoder_out, (B,T,D)
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)
######################### infer_model.forward_attention_decoder ########################
reverse_weight = 0.3
input_spec = [
# hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# hyps_lens, (B,)
paddle.static.InputSpec(shape=[None], dtype='int64'),
# encoder_out, (B,T,D)
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32'),
reverse_weight
]
infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec)
# jit save
logger.info(f"export save: {self.args.export_path}")
paddle.jit.save(
infer_model,
self.args.export_path,
combine_params=True,
skip_forward=True)
# test dy2static
def flatten(out):
if isinstance(out, paddle.Tensor):
return [out]
flatten_out = []
for var in out:
if isinstance(var, (list, tuple)):
flatten_out.extend(flatten(var))
else:
flatten_out.append(var)
return flatten_out
# forward_encoder_chunk dygraph
xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
offset = paddle.to_tensor([0], dtype='int32')
required_cache_size = num_left_chunks
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(
xs1, offset, required_cache_size, att_cache, cnn_cache)
# load static model
from paddle.jit.layer import Layer
layer = Layer()
logger.info(f"load export model: {self.args.export_path}")
layer.load(self.args.export_path, paddle.CPUPlace())
# forward_encoder_chunk static
xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
offset = paddle.to_tensor([0], dtype='int32')
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
func = getattr(layer, 'forward_encoder_chunk')
xs_s, att_cache_s, cnn_cache_s = func(xs1, offset, att_cache, cnn_cache)
np.testing.assert_allclose(xs_d, xs_s, atol=1e-5)
np.testing.assert_allclose(att_cache_d, att_cache_s, atol=1e-4)
np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
# logger.info(f"forward_encoder_chunk output: {xs_s}")

@ -20,8 +20,6 @@ from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args): def main_sp(config, args):
exp = Tester(config, args) exp = Tester(config, args)

@ -124,17 +124,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}") #logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch # 2a. Attention-decoder branch
loss_att = None loss_att = None
if self.ctc_weight != 1.0: if self.ctc_weight != 1.0:
start = time.time() start = time.time()
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths) text, text_lengths,
self.reverse_weight)
decoder_time = time.time() - start decoder_time = time.time() - start
#logger.debug(f"decoder time: {decoder_time}") #logger.debug(f"decoder time: {decoder_time}")
@ -155,12 +153,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
return loss, loss_att, loss_ctc return loss, loss_att, loss_ctc
def _calc_att_loss( def _calc_att_loss(self,
self, encoder_out: paddle.Tensor,
encoder_out: paddle.Tensor, encoder_mask: paddle.Tensor,
encoder_mask: paddle.Tensor, ys_pad: paddle.Tensor,
ys_pad: paddle.Tensor, ys_pad_lens: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: reverse_weight: float) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss. """Calc attention loss.
Args: Args:
@ -168,6 +166,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
encoder_mask (paddle.Tensor): [B, 1, Tmax] encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax] ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B] ys_pad_lens (paddle.Tensor): [B]
reverse_weight (float): reverse decoder weight.
Returns: Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
@ -182,15 +181,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
# 1. Forward decoder # 1. Forward decoder
decoder_out, r_decoder_out, _ = self.decoder( decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad, encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
self.reverse_weight) reverse_weight)
# 2. Compute attention loss # 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad) loss_att = self.criterion_att(decoder_out, ys_out_pad)
r_loss_att = paddle.to_tensor(0.0) r_loss_att = paddle.to_tensor(0.0)
if self.reverse_weight > 0.0: if reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - self.reverse_weight loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
) + r_loss_att * self.reverse_weight
acc_att = th_accuracy( acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size), decoder_out.view(-1, self.vocab_size),
ys_out_pad, ys_out_pad,
@ -291,8 +289,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
# 2. Decoder forward step by step # 2. Decoder forward step by step
for i in range(1, maxlen + 1): for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos # Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break break
# 2.1 Forward decoder step # 2.1 Forward decoder step
@ -378,9 +375,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming) num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.shape[1] maxlen = encoder_out.shape[1]
# (TODO Hui Zhang): bool no support reduce_sum encoder_out_lens = encoder_mask.squeeze(1).sum(1)
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
@ -514,7 +509,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_chunk_size: int=-1, decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1, num_decoding_left_chunks: int=-1,
ctc_weight: float=0.0, ctc_weight: float=0.0,
simulate_streaming: bool=False) -> List[int]: simulate_streaming: bool=False,
reverse_weight: float=0.0) -> List[int]:
""" Apply attention rescoring decoding, CTC prefix beam search """ Apply attention rescoring decoding, CTC prefix beam search
is applied first to get nbest, then we resoring the nbest on is applied first to get nbest, then we resoring the nbest on
attention decoder with corresponding encoder out attention decoder with corresponding encoder out
@ -529,12 +525,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
0: used for training, it's prohibited here 0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a simulate_streaming (bool): whether do encoder forward in a
streaming fashion streaming fashion
reverse_weight (float): reverse deocder weight.
Returns: Returns:
List[int]: Attention rescoring result List[int]: Attention rescoring result
""" """
assert speech.shape[0] == speech_lengths.shape[0] assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0 assert decoding_chunk_size != 0
if self.reverse_weight > 0.0: if reverse_weight > 0.0:
# decoder should be a bitransformer decoder if reverse_weight > 0.0 # decoder should be a bitransformer decoder if reverse_weight > 0.0
assert hasattr(self.decoder, 'right_decoder') assert hasattr(self.decoder, 'right_decoder')
device = speech.place device = speech.place
@ -558,28 +555,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyp_content, place=device, dtype=paddle.long) hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content) hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id) hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
ori_hyps_pad = hyps_pad
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device, [len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1) logger.debug(
encoder_mask = paddle.ones( f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
self.reverse_weight) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain # ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) # (beam_size, max_hyps_len, vocab_size)
decoder_out = decoder_out.numpy() decoder_out, r_decoder_out = self.forward_attention_decoder(
hyps_pad, hyps_lens, encoder_out, reverse_weight)
decoder_out = decoder_out.numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder. # conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy() r_decoder_out = r_decoder_out.numpy()
# Only use decoder score for rescoring # Only use decoder score for rescoring
@ -592,46 +583,68 @@ class U2BaseModel(ASRInterface, nn.Layer):
score += decoder_out[i][j][w] score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token. # last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos] score += decoder_out[i][len(hyp[0])][self.eos]
if self.reverse_weight > 0:
logger.debug(
f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
)
if reverse_weight > 0:
r_score = 0.0 r_score = 0.0
for j, w in enumerate(hyp[0]): for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos] r_score += r_decoder_out[i][len(hyp[0])][self.eos]
score = score * (1 - self.reverse_weight
) + r_score * self.reverse_weight logger.debug(
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
)
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain) # add ctc score (which in ln domain)
score += hyp[1] * ctc_weight score += hyp[1] * ctc_weight
if score > best_score: if score > best_score:
best_score = score best_score = score
best_index = i best_index = i
logger.debug(f"result: {hyps[best_index]}")
return hyps[best_index][0] return hyps[best_index][0]
#@jit.to_static @jit.to_static(property=True)
def subsampling_rate(self) -> int: def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the """ Export interface for c++ call, return subsampling_rate of the
model model
""" """
return self.encoder.embed.subsampling_rate return self.encoder.embed.subsampling_rate
#@jit.to_static @jit.to_static(property=True)
def right_context(self) -> int: def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model """ Export interface for c++ call, return right_context of the model
""" """
return self.encoder.embed.right_context return self.encoder.embed.right_context
#@jit.to_static @jit.to_static(property=True)
def sos_symbol(self) -> int: def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model """ Export interface for c++ call, return sos symbol id of the model
""" """
return self.sos return self.sos
#@jit.to_static @jit.to_static(property=True)
def eos_symbol(self) -> int: def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model """ Export interface for c++ call, return eos symbol id of the model
""" """
return self.eos return self.eos
@jit.to_static @jit.to_static(property=True)
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
paddle.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
# @jit.to_static
def forward_encoder_chunk( def forward_encoder_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
@ -681,28 +694,16 @@ class U2BaseModel(ASRInterface, nn.Layer):
Args: Args:
xs (paddle.Tensor): encoder output, (B, T, D) xs (paddle.Tensor): encoder output, (B, T, D)
Returns: Returns:
paddle.Tensor: activation before ctc paddle.Tensor: activation before ctc. (B, Tmax, odim)
""" """
return self.ctc.log_softmax(xs) return self.ctc.log_softmax(xs)
# @jit.to_static # @jit.to_static
def is_bidirectional_decoder(self) -> bool: def forward_attention_decoder(self,
""" hyps: paddle.Tensor,
Returns: hyps_lens: paddle.Tensor,
paddle.Tensor: decoder output encoder_out: paddle.Tensor,
""" reverse_weight: float=0.0) -> paddle.Tensor:
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor,
reverse_weight: float=0.0, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple """ Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output hypothesis from ctc prefix beam search and one encoder output
Args: Args:
@ -747,7 +748,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
ctc_weight: float=0.0, ctc_weight: float=0.0,
decoding_chunk_size: int=-1, decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1, num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False): simulate_streaming: bool=False,
reverse_weight: float=0.0):
"""u2 decoding. """u2 decoding.
Args: Args:
@ -766,6 +768,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
num_decoding_left_chunks (int, optional): num_decoding_left_chunks (int, optional):
number of left chunks for decoding. Defaults to -1. number of left chunks for decoding. Defaults to -1.
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
reverse_weight (float, optional): reverse decoder weight, used by `attention_rescoring`.
Raises: Raises:
ValueError: when not support decoding_method. ValueError: when not support decoding_method.
@ -819,7 +822,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_chunk_size=decoding_chunk_size, decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks, num_decoding_left_chunks=num_decoding_left_chunks,
ctc_weight=ctc_weight, ctc_weight=ctc_weight,
simulate_streaming=simulate_streaming) simulate_streaming=simulate_streaming,
reverse_weight=reverse_weight)
hyps = [hyp] hyps = [hyp]
else: else:
raise ValueError(f"Not support decoding method: {decoding_method}") raise ValueError(f"Not support decoding method: {decoding_method}")
@ -980,6 +984,49 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict): def __init__(self, configs: dict):
super().__init__(configs) super().__init__(configs)
from paddlespeech.s2t.modules.fbank import KaldiFbank
import yaml
import json
import numpy as np
input_dim = configs['input_dim']
process = configs['preprocess_config']
with open(process, encoding="utf-8") as f:
conf = yaml.safe_load(f)
assert isinstance(conf, dict), type(self.conf)
for idx, process in enumerate(conf['process']):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
if process_type == 'fbank_kaldi':
opts.update({'n_mels': input_dim})
opts['dither'] = 0.0
self.fbank = KaldiFbank(**opts)
logger.info(f"{self.__class__.__name__} export: {self.fbank}")
if process_type == 'cmvn_json':
# align with paddlespeech.audio.transform.cmvn:GlobalCMVN
std_floor = 1.0e-20
cmvn = opts['cmvn_path']
if isinstance(cmvn, dict):
cmvn_stats = cmvn
else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
count = cmvn_stats['frame_num']
mean = np.array(cmvn_stats['mean_stat']) / count
square_sums = np.array(cmvn_stats['var_stat'])
var = square_sums / count - mean**2
std = np.maximum(np.sqrt(var), std_floor)
istd = 1.0 / std
self.global_cmvn = GlobalCMVN(
paddle.to_tensor(mean, dtype=paddle.float),
paddle.to_tensor(istd, dtype=paddle.float))
logger.info(
f"{self.__class__.__name__} export: {self.global_cmvn}")
def forward(self, def forward(self,
feats, feats,
feats_lengths, feats_lengths,
@ -995,9 +1042,25 @@ class U2InferModel(U2Model):
Returns: Returns:
List[List[int]]: best path result List[List[int]]: best path result
""" """
return self.ctc_greedy_search( # dummy code for dy2st
feats, # return self.ctc_greedy_search(
feats_lengths, # feats,
decoding_chunk_size=decoding_chunk_size, # feats_lengths,
num_decoding_left_chunks=num_decoding_left_chunks, # decoding_chunk_size=decoding_chunk_size,
simulate_streaming=simulate_streaming) # num_decoding_left_chunks=num_decoding_left_chunks,
# simulate_streaming=simulate_streaming)
return feats, feats_lengths
def forward_feature(self, x):
"""feature pipeline.
Args:
x (paddle.Tensor): waveform (T,).
Return:
feat (paddle.Tensor): feature (T, D)
"""
x = paddle.cast(x, paddle.float32)
feat = self.fbank(x)
feat = self.global_cmvn(feat)
return feat

@ -111,10 +111,7 @@ class U2STBaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}") #logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch # 2a. ST-decoder branch
start = time.time() start = time.time()

@ -19,6 +19,7 @@ from typing import Tuple
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.align import Linear
@ -45,6 +46,7 @@ class MultiHeadedAttention(nn.Layer):
""" """
super().__init__() super().__init__()
assert n_feat % n_head == 0 assert n_feat % n_head == 0
self.n_feat = n_feat
# We assume d_v always equals d_k # We assume d_v always equals d_k
self.d_k = n_feat // n_head self.d_k = n_feat // n_head
self.h = n_head self.h = n_head
@ -54,6 +56,16 @@ class MultiHeadedAttention(nn.Layer):
self.linear_out = Linear(n_feat, n_feat) self.linear_out = Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate) self.dropout = nn.Dropout(p=dropout_rate)
def _build_once(self, *args, **kwargs):
super()._build_once(*args, **kwargs)
# if self.self_att:
# self.linear_kv = Linear(self.n_feat, self.n_feat*2)
if not self.training:
self.weight = paddle.concat(
[self.linear_k.weight, self.linear_v.weight], axis=-1)
self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
self._built = True
def forward_qkv(self, def forward_qkv(self,
query: paddle.Tensor, query: paddle.Tensor,
key: paddle.Tensor, key: paddle.Tensor,
@ -73,9 +85,16 @@ class MultiHeadedAttention(nn.Layer):
(#batch, n_head, time2, d_k). (#batch, n_head, time2, d_k).
""" """
n_batch = query.shape[0] n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) if self.training:
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
else:
k, v = F.linear(key, self.weight, self.bias).view(
n_batch, -1, 2 * self.h, self.d_k).split(
2, axis=2)
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
@ -108,10 +127,10 @@ class MultiHeadedAttention(nn.Layer):
# When will `if mask.size(2) > 0` be False? # When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0) # 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4) # 2. jit (16/-1, -1/-1, 16/0, 16/4)
if paddle.shape(mask)[2] > 0: # time2 > 0 if mask.shape[2] > 0: # time2 > 0
mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2) mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2)
# for last chunk, time2 might be larger than scores.size(-1) # for last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :paddle.shape(scores)[-1]] mask = mask[:, :, :, :scores.shape[-1]]
scores = scores.masked_fill(mask, -float('inf')) scores = scores.masked_fill(mask, -float('inf'))
attn = paddle.softmax( attn = paddle.softmax(
scores, axis=-1).masked_fill(mask, scores, axis=-1).masked_fill(mask,
@ -179,7 +198,7 @@ class MultiHeadedAttention(nn.Layer):
# >>> torch.equal(b, c) # True # >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1) # >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True # >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0: if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val) # last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1) key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2) k = paddle.concat([key_cache, k], axis=2)
@ -188,8 +207,9 @@ class MultiHeadedAttention(nn.Layer):
# non-trivial to calculate `next_cache_start` here. # non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1) new_cache = paddle.concat((k, v), axis=-1)
scores = paddle.matmul(q, # scores = paddle.matmul(q,
k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) # k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache return self.forward_attention(v, scores, mask), new_cache
@ -270,7 +290,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
and `head * d_k == size` and `head * d_k == size`
""" """
q, k, v = self.forward_qkv(query, key, value) q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed # when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
@ -287,7 +307,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# >>> torch.equal(b, c) # True # >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1) # >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True # >>> torch.equal(d[0], d[1]) # True
if paddle.shape(cache)[0] > 0: if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val) # last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1) key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2) k = paddle.concat([key_cache, k], axis=2)
@ -301,19 +321,23 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k) # (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
# (batch, head, time1, d_k) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
# compute attention score # compute attention score
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 # as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2) # (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)
# compute matrix b and matrix d # compute matrix b and matrix d
# (batch, head, time1, time2) # (batch, head, time1, time2)
matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
# Remove rel_shift since it is useless in speech recognition, # Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming. # and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd) # matrix_bd = self.rel_shift(matrix_bd)

@ -40,6 +40,13 @@ class GlobalCMVN(nn.Layer):
self.register_buffer("mean", mean) self.register_buffer("mean", mean)
self.register_buffer("istd", istd) self.register_buffer("istd", istd)
def __repr__(self):
return ("{name}(mean={mean}, istd={istd}, norm_var={norm_var})".format(
name=self.__class__.__name__,
mean=self.mean,
istd=self.istd,
norm_var=self.norm_var))
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
""" """
Args: Args:

@ -127,11 +127,11 @@ class ConvolutionModule(nn.Layer):
x = x.transpose([0, 2, 1]) # [B, C, T] x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding # mask batch padding
if paddle.shape(mask_pad)[2] > 0: # time > 0 if mask_pad.shape[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0) x = x.masked_fill(mask_pad, 0.0)
if self.lorder > 0: if self.lorder > 0:
if paddle.shape(cache)[2] == 0: # cache_t == 0 if cache.shape[2] == 0: # cache_t == 0
x = nn.functional.pad( x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else: else:
@ -161,7 +161,7 @@ class ConvolutionModule(nn.Layer):
x = self.pointwise_conv2(x) x = self.pointwise_conv2(x)
# mask batch padding # mask batch padding
if paddle.shape(mask_pad)[2] > 0: # time > 0 if mask_pad.shape[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0) x = x.masked_fill(mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C] x = x.transpose([0, 2, 1]) # [B, T, C]

@ -140,9 +140,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
# m: (1, L, L) # m: (1, L, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor tgt_mask = tgt_mask & m
# tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt) x, _ = self.embed(tgt)
for layer in self.decoders: for layer in self.decoders:
@ -153,9 +151,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
if self.use_output_layer: if self.use_output_layer:
x = self.output_layer(x) x = self.output_layer(x)
# TODO(Hui Zhang): reduce_sum not support bool type olens = tgt_mask.sum(1)
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, paddle.to_tensor(0.0), olens return x, paddle.to_tensor(0.0), olens
def forward_one_step( def forward_one_step(
@ -247,7 +243,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
] ]
# batch decoding # batch decoding
ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0) # (B,L,L) ys_mask = subsequent_mask(ys.shape[-1]).unsqueeze(0) # (B,L,L)
xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T)
logp, states = self.forward_one_step( logp, states = self.forward_one_step(
xs, xs_mask, ys, ys_mask, cache=batch_state) xs, xs_mask, ys, ys_mask, cache=batch_state)
@ -343,7 +339,7 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
""" """
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens) ys_in_lens)
r_x = paddle.to_tensor(0.0) r_x = paddle.zeros([1])
if reverse_weight > 0.0: if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens) ys_in_lens)

@ -114,10 +114,7 @@ class DecoderLayer(nn.Layer):
], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}" ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}"
tgt_q = tgt[:, -1:, :] tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :] residual = residual[:, -1:, :]
# TODO(Hui Zhang): slice not support bool type tgt_q_mask = tgt_mask[:, -1:, :]
# tgt_q_mask = tgt_mask[:, -1:, :]
tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast(
paddle.bool)
if self.concat_after: if self.concat_after:
tgt_concat = paddle.cat( tgt_concat = paddle.cat(

@ -89,7 +89,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
self.max_len = max_len self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate) self.dropout = nn.Dropout(p=dropout_rate)
self.pe = paddle.zeros([self.max_len, self.d_model]) #[T,D] self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]
position = paddle.arange( position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1] 0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
@ -97,9 +97,8 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model)) -(math.log(10000.0) / self.d_model))
self.pe[:, 0::2] = paddle.sin(position * div_term) self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, 1::2] = paddle.cos(position * div_term) self.pe[:, :, 1::2] = paddle.cos(position * div_term)
self.pe = self.pe.unsqueeze(0) #[1, T, D]
def forward(self, x: paddle.Tensor, def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
@ -111,12 +110,10 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...) paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
""" """
T = x.shape[1]
assert offset + x.shape[ assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len) offset, x.shape[1], self.max_len)
#TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]]
pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
@ -165,6 +162,5 @@ class RelPositionalEncoding(PositionalEncoding):
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len) offset, x.shape[1], self.max_len)
x = x * self.xscale x = x * self.xscale
#TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]] pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)

@ -164,12 +164,8 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor xs, pos_emb, masks = self.embed(xs, masks, offset=0)
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) mask_pad = ~masks
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask( chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size, decoding_chunk_size, self.static_chunk_size,
@ -215,11 +211,8 @@ class BaseEncoder(nn.Layer):
same shape as the original cnn_cache same shape as the original cnn_cache
""" """
assert xs.shape[0] == 1 # batch size must be one assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility # tmp_masks is just for interface compatibility, [B=1, C=1, T]
# TODO(Hui Zhang): stride_slice not support bool tensor tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool)
# tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
@ -228,9 +221,8 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim) # after embed, xs=(B=1, chunk_size, hidden-dim)
elayers = paddle.shape(att_cache)[0] elayers, _, cache_t1, _ = att_cache.shape
cache_t1 = paddle.shape(att_cache)[2] chunk_size = xs.shape[1]
chunk_size = paddle.shape(xs)[1]
attention_key_size = cache_t1 + chunk_size attention_key_size = cache_t1 + chunk_size
# only used when using `RelPositionMultiHeadedAttention` # only used when using `RelPositionMultiHeadedAttention`
@ -249,25 +241,30 @@ class BaseEncoder(nn.Layer):
for i, layer in enumerate(self.encoders): for i, layer in enumerate(self.encoders):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
# WARNING: eliminate if-else cond op in graph
# tensor zeros([0,0,0,0]) support [i:i+1] slice, will return zeros([0,0,0,0]) tensor
# raw code as below:
# att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
# cnn_cache=cnn_cache[i:i+1] if cnn_cache.shape[0] > 0 else cnn_cache,
xs, _, new_att_cache, new_cnn_cache = layer( xs, _, new_att_cache, new_cnn_cache = layer(
xs, xs,
att_mask, att_mask,
pos_emb, pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, att_cache=att_cache[i:i + 1],
cnn_cache=cnn_cache[i:i + 1] cnn_cache=cnn_cache[i:i + 1], )
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2) # new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2) # new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim r_cnn_cache.append(new_cnn_cache) # add elayer dim
if self.normalize_before: if self.normalize_before:
xs = self.after_norm(xs) xs = self.after_norm(xs)
# r_att_cache (elayers, head, T, d_k*2) # r_att_cache (elayers, head, T, d_k*2)
# r_cnn_cache elayers, B=1, hidden-dim, cache_t2) # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
r_att_cache = paddle.concat(r_att_cache, axis=0) r_att_cache = paddle.concat(r_att_cache, axis=0)
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk( def forward_chunk_by_chunk(
@ -397,11 +394,7 @@ class TransformerEncoder(BaseEncoder):
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor xs, pos_emb, masks = self.embed(xs, masks, offset=0)
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
if cache is None: if cache is None:
cache = [None for _ in range(len(self.encoders))] cache = [None for _ in range(len(self.encoders))]
new_cache = [] new_cache = []

@ -0,0 +1,72 @@
import paddle
from paddle import nn
from paddlespeech.audio.compliance import kaldi
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['KaldiFbank']
class KaldiFbank(nn.Layer):
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.0):
"""
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
win_length (int): number of points in a frame windows
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant. Default 0.0
"""
super().__init__()
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def forward(self, x: paddle.Tensor):
"""
Args:
x (paddle.Tensor): shape (Ti).
Not support: [Time, Channel] and Batch mode.
Returns:
paddle.Tensor: (T, D)
"""
assert x.ndim == 1
feat = kaldi.fbank(
x.unsqueeze(0), # append channel dim, (C, Ti)
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=self.dither,
energy_floor=self.energy_floor,
sr=self.fs)
assert feat.ndim == 2 # (T,D)
return feat

@ -85,7 +85,7 @@ class CTCLoss(nn.Layer):
Returns: Returns:
[paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}. [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
""" """
B = paddle.shape(logits)[0] B = logits.shape[0]
# warp-ctc need logits, and do softmax on logits by itself # warp-ctc need logits, and do softmax on logits by itself
# warp-ctc need activation with shape [T, B, V + 1] # warp-ctc need activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D) # logits: (B, L, D) -> (L, B, D)
@ -158,7 +158,7 @@ class LabelSmoothingLoss(nn.Layer):
Returns: Returns:
loss (paddle.Tensor) : The KL loss, scalar float value loss (paddle.Tensor) : The KL loss, scalar float value
""" """
B, T, D = paddle.shape(x) B, T, D = x.shape
assert D == self.size assert D == self.size
x = x.reshape((-1, self.size)) x = x.reshape((-1, self.size))
target = target.reshape([-1]) target = target.reshape([-1])

@ -109,12 +109,7 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]] [1, 1, 1]]
""" """
ret = paddle.ones([size, size], dtype=paddle.bool) ret = paddle.ones([size, size], dtype=paddle.bool)
#TODO(Hui Zhang): tril not support bool return paddle.tril(ret)
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask( def subsequent_chunk_mask(

@ -139,8 +139,8 @@ class Conv2dSubsampling4(Conv2dSubsampling):
""" """
x = x.unsqueeze(1) # (b, c=1, t, f) x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x) x = self.conv(x)
b, c, t, f = paddle.shape(x) b, c, t, f = x.shape
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
@ -192,8 +192,8 @@ class Conv2dSubsampling6(Conv2dSubsampling):
""" """
x = x.unsqueeze(1) # (b, c, t, f) x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x) x = self.conv(x)
b, c, t, f = paddle.shape(x) b, c, t, f = x.shape
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
@ -245,6 +245,7 @@ class Conv2dSubsampling8(Conv2dSubsampling):
""" """
x = x.unsqueeze(1) # (b, c, t, f) x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x) x = self.conv(x)
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) b, c, t, f = x.shape
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]

@ -184,13 +184,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2) pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum( numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets)) denominator = paddle.sum(mask)
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator) return float(numerator) / float(denominator)

@ -22,7 +22,6 @@ from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
@ -610,22 +609,15 @@ class PaddleASRConnectionHanddler:
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id) self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
self.model.sos, self.model.eos)
decoder_out, r_decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
self.model.reverse_weight) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain # ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) # (beam_size, max_hyps_len, vocab_size)
decoder_out = decoder_out.numpy() decoder_out, r_decoder_out = self.model.forward_attention_decoder(
hyps_pad, hyps_lens, self.encoder_out, self.model.reverse_weight)
decoder_out = decoder_out.numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder. # conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy() r_decoder_out = r_decoder_out.numpy()
# Only use decoder score for rescoring # Only use decoder score for rescoring

@ -0,0 +1,156 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddlespeech.s2t # noqa: F401
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
# from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
def reverse_pad_list(ys_pad: paddle.Tensor,
ys_lens: paddle.Tensor,
pad_value: float=-1.0) -> paddle.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
pad_value (int): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad = pad_sequence([(paddle.flip(y[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True, pad_value)
return r_ys_pad
def naive_reverse_pad_list_with_sos_eos(r_hyps,
r_hyps_lens,
sos=5000,
eos=5000,
ignore_id=-1):
r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(ignore_id))
r_hyps, _ = add_sos_eos(r_hyps, sos, eos, ignore_id)
return r_hyps
def reverse_pad_list_with_sos_eos(r_hyps,
r_hyps_lens,
sos=5000,
eos=5000,
ignore_id=-1):
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = paddle.max(r_hyps_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
r_hyps = paddle_gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = paddle.where(seq_mask, r_hyps, eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
B = r_hyps.shape[0]
_sos = paddle.ones([B, 1], dtype=r_hyps.dtype) * sos
# r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
r_hyps = paddle.concat([_sos, r_hyps], axis=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
return r_hyps
class TestU2Model(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
self.sos = 5000
self.eos = 5000
self.ignore_id = -1
self.reverse_hyps = paddle.to_tensor([[4, 3, 2, 1, -1],
[5, 4, 3, 2, 1]])
self.reverse_hyps_sos_eos = paddle.to_tensor(
[[self.sos, 4, 3, 2, 1, self.eos], [self.sos, 5, 4, 3, 2, 1]])
self.hyps = paddle.to_tensor([[1, 2, 3, 4, -1], [1, 2, 3, 4, 5]])
self.hyps_lens = paddle.to_tensor([4, 5], paddle.int32)
def test_reverse_pad_list(self):
r_hyps = reverse_pad_list(self.hyps, self.hyps_lens)
self.assertSequenceEqual(r_hyps.tolist(), self.reverse_hyps.tolist())
def test_naive_reverse_pad_list_with_sos_eos(self):
r_hyps_sos_eos = naive_reverse_pad_list_with_sos_eos(self.hyps,
self.hyps_lens)
self.assertSequenceEqual(r_hyps_sos_eos.tolist(),
self.reverse_hyps_sos_eos.tolist())
def test_static_reverse_pad_list_with_sos_eos(self):
r_hyps_sos_eos_static = reverse_pad_list_with_sos_eos(self.hyps,
self.hyps_lens)
self.assertSequenceEqual(r_hyps_sos_eos_static.tolist(),
self.reverse_hyps_sos_eos.tolist())
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save