pull/2212/head
Hui Zhang 3 years ago
parent 549d477592
commit 53d6baff0b

@ -376,7 +376,8 @@ def _get_mel_banks(num_bins: int,
center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
# (1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
mel = _mel_scale(fft_bin_width * paddle.arange(
num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
# (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)

@ -68,7 +68,7 @@ class U2Infer():
with paddle.no_grad():
# read
audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")

@ -462,11 +462,13 @@ class U2Tester(U2Trainer):
infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.clone(),
self.args.checkpoint_path)
batch_size = 1
batch_size = 1
feat_dim = self.test_loader.feat_dim
model_size = self.config.encoder_conf.output_size
num_left_chunks = -1
logger.info(f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}")
logger.info(
f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}"
)
return infer_model, (batch_size, feat_dim, model_size, num_left_chunks)
@ -479,29 +481,29 @@ class U2Tester(U2Trainer):
assert isinstance(input_spec, (list, tuple)), type(input_spec)
batch_size, feat_dim, model_size, num_left_chunks = input_spec
######################## infer_model.forward_feature ############
input_spec = [
# (T,), int16
paddle.static.InputSpec(shape=[None], dtype='int16'),
]
infer_model.forward_feature = paddle.jit.to_static(infer_model.forward_feature, input_spec=input_spec)
infer_model.forward_feature = paddle.jit.to_static(
infer_model.forward_feature, input_spec=input_spec)
######################### infer_model.forward_encoder_chunk ############
input_spec = [
# xs, (B, T, D)
paddle.static.InputSpec(shape=[batch_size, None, feat_dim], dtype='float32'),
paddle.static.InputSpec(
shape=[batch_size, None, feat_dim], dtype='float32'),
# offset, an int, but it needs to be passed as a tensor
paddle.static.InputSpec(shape=[1], dtype='int32'),
paddle.static.InputSpec(shape=[1], dtype='int32'),
# required_cache_size, int
num_left_chunks,
# att_cache
paddle.static.InputSpec(
shape=[None, None, None, None],
dtype='float32'),
shape=[None, None, None, None], dtype='float32'),
# cnn_cache
paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32')
shape=[None, None, None, None], dtype='float32')
]
infer_model.forward_encoder_chunk = paddle.jit.to_static(
infer_model.forward_encoder_chunk, input_spec=input_spec)
@ -509,12 +511,12 @@ class U2Tester(U2Trainer):
######################### infer_model.ctc_activation ########################
input_spec = [
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)
######################### infer_model.forward_attention_decoder ########################
input_spec = [
# hyps, (B, U)
@ -522,15 +524,19 @@ class U2Tester(U2Trainer):
# hyps_lens, (B,)
paddle.static.InputSpec(shape=[None], dtype='int64'),
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec)
# jit save
logger.info(f"export save: {self.args.export_path}")
paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True)
paddle.jit.save(
infer_model,
self.args.export_path,
combine_params=True,
skip_forward=True)
# test dy2static
def flatten(out):
@ -551,7 +557,8 @@ class U2Tester(U2Trainer):
required_cache_size = num_left_chunks
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(
xs1, offset, required_cache_size, att_cache, cnn_cache)
# load static model
from paddle.jit.layer import Layer

@ -545,11 +545,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
logger.debug(f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
logger.debug(
f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
# ctc score in ln domain
decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out)
decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
encoder_out)
# Only use decoder score for rescoring
best_score = -float('inf')
@ -561,7 +563,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for last decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
logger.debug(f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}")
logger.debug(
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
)
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
@ -933,9 +937,7 @@ class U2InferModel(U2Model):
if process_type == 'fbank_kaldi':
opts.update({'n_mels': input_dim})
opts['dither'] = 0.0
self.fbank = KaldiFbank(
**opts
)
self.fbank = KaldiFbank(**opts)
logger.info(f"{self.__class__.__name__} export: {self.fbank}")
if process_type == 'cmvn_json':
# align with paddlespeech.audio.transform.cmvn:GlobalCMVN
@ -956,7 +958,8 @@ class U2InferModel(U2Model):
self.global_cmvn = GlobalCMVN(
paddle.to_tensor(mean, dtype=paddle.float),
paddle.to_tensor(istd, dtype=paddle.float))
logger.info(f"{self.__class__.__name__} export: {self.global_cmvn}")
logger.info(
f"{self.__class__.__name__} export: {self.global_cmvn}")
def forward(self,
feats,
@ -994,4 +997,4 @@ class U2InferModel(U2Model):
x = paddle.cast(x, paddle.float32)
feat = self.fbank(x)
feat = self.global_cmvn(feat)
return feat
return feat

@ -41,12 +41,11 @@ class GlobalCMVN(nn.Layer):
self.register_buffer("istd", istd)
def __repr__(self):
return (
"{name}(mean={mean}, istd={istd}, norm_var={norm_var})".format(
name=self.__class__.__name__,
mean=self.mean,
istd=self.istd,
norm_var=self.norm_var))
return ("{name}(mean={mean}, istd={istd}, norm_var={norm_var})".format(
name=self.__class__.__name__,
mean=self.mean,
istd=self.istd,
norm_var=self.norm_var))
def forward(self, x: paddle.Tensor):
"""
@ -58,4 +57,4 @@ class GlobalCMVN(nn.Layer):
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
return x

@ -256,10 +256,11 @@ class BaseEncoder(nn.Layer):
# att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
# cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i+1],
cnn_cache=cnn_cache[i:i+1],
)
xs,
att_mask,
pos_emb,
att_cache=att_cache[i:i + 1],
cnn_cache=cnn_cache[i:i + 1], )
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])

@ -1,19 +1,17 @@
import paddle
from paddle import nn
from paddlespeech.audio.compliance import kaldi
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['KaldiFbank']
class KaldiFbank(nn.Layer):
def __init__(self,
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
@ -62,7 +60,7 @@ class KaldiFbank(nn.Layer):
assert x.ndim == 1
feat = kaldi.fbank(
x.unsqueeze(0), # append channel dim, (C, Ti)
x.unsqueeze(0), # append channel dim, (C, Ti)
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
@ -70,5 +68,5 @@ class KaldiFbank(nn.Layer):
energy_floor=self.energy_floor,
sr=self.fs)
assert feat.ndim == 2 # (T,D)
assert feat.ndim == 2 # (T,D)
return feat

@ -80,7 +80,6 @@ class PaddleASRConnectionHanddler:
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model not support endpoint"

Loading…
Cancel
Save