add from_config function to ds2_online and ds2

pull/735/head
huangyuxin 3 years ago
parent 7a3d164122
commit 718ae52e3f

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains DeepSpeech2 model."""
+"""Contains DeepSpeech2 and DeepSpeech2Online model."""
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -38,8 +38,6 @@ from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Autolog
 from deepspeech.utils.log import Log
-#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
-#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

 logger = Log(__name__).getlog()
@@ -123,40 +121,20 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts

     def setup_model(self):
-        config = self.config
-        if hasattr(self, "train_loader"):
-            config.defrost()
-            config.model.feat_size = self.train_loader.collate_fn.feature_size
-            config.model.dict_size = self.train_loader.collate_fn.vocab_size
-            config.freeze()
-        elif hasattr(self, "test_loader"):
-            config.defrost()
-            config.model.feat_size = self.test_loader.collate_fn.feature_size
-            config.model.dict_size = self.test_loader.collate_fn.vocab_size
-            config.freeze()
-        else:
-            raise Exception("Please setup the dataloader first")
+        config = self.config.clone()
+        config.defrost()
+        assert (self.train_loader.collate_fn.feature_size ==
+                self.test_loader.collate_fn.feature_size)
+        assert (self.train_loader.collate_fn.vocab_size ==
+                self.test_loader.collate_fn.vocab_size)
+        config.model.feat_size = self.train_loader.collate_fn.feature_size
+        config.model.dict_size = self.train_loader.collate_fn.vocab_size
+        config.freeze()
         if self.args.model_type == 'offline':
-            model = DeepSpeech2Model(
-                feat_size=config.model.feat_size,
-                dict_size=config.model.dict_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                use_gru=config.model.use_gru,
-                share_rnn_weights=config.model.share_rnn_weights)
+            model = DeepSpeech2Model.from_config(config.model)
         elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline(
-                feat_size=config.model.feat_size,
-                dict_size=config.model.dict_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                rnn_direction=config.model.rnn_direction,
-                num_fc_layers=config.model.num_fc_layers,
-                fc_layers_size_list=config.model.fc_layers_size_list,
-                use_gru=config.model.use_gru)
+            model = DeepSpeech2ModelOnline.from_config(config.model)
         else:
             raise Exception("wrong model type")
         if self.parallel:
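The clone/defrost/freeze sequence above is the standard yacs idiom for editing an otherwise read-only config. A minimal sketch of the pattern, with illustrative values (the key names mirror the diff):

from yacs.config import CfgNode as CN

config = CN()
config.model = CN()
config.model.feat_size = -1  # placeholder, filled in from the collator
config.model.dict_size = -1
config.freeze()  # the shared config stays read-only

local = config.clone()  # private copy, safe to edit
local.defrost()  # temporarily allow mutation
local.model.feat_size = 161  # e.g. linear-spectrogram dimension
local.model.dict_size = 4300  # vocabulary size from the collator
local.freeze()  # re-freeze before handing it to from_config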
@@ -194,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)

+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)
+
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,

@@ -217,19 +198,29 @@
         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)

+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        collate_fn_test = SpeechCollator.from_config(config)
+
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
             collate_fn=collate_fn_train,
             num_workers=config.collator.num_workers)
+        print("feature_size", self.train_loader.collate_fn.feature_size)
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
-        logger.info("Setup train/valid Dataloader!")
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_test)
+        logger.info("Setup train/valid/test Dataloader!")
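The test collator flips two switches relative to the train and dev collators: augmentation is disabled and keep_transcription_text is set, so test batches carry raw reference text for CER/WER scoring instead of token ids. A condensed sketch of the idea; make_test_collator is a hypothetical helper, the real construction is inlined in setup_dataloader above:

def make_test_collator(config):
    # assumes SpeechCollator is imported as in this module
    cfg = config.clone()
    cfg.defrost()
    cfg.collator.augmentation_config = ""  # no augmentation at eval time
    cfg.collator.keep_transcription_text = True  # raw text out, not token ids
    cfg.freeze()
    return SpeechCollator.from_config(cfg)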
@@ -371,20 +362,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        if self.args.model_type == 'offline':
-            static_model = paddle.jit.to_static(
-                infer_model,
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None, None, feat_dim],
-                        dtype='float32'),  # audio, [B, T, D]
-                    paddle.static.InputSpec(shape=[None],
-                                            dtype='int64'),  # audio_length, [B]
-                ])
-        elif self.args.model_type == 'online':
-            static_model = infer_model.export()
-        else:
-            raise Exception("wrong model type")
+        static_model = infer_model.export()
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
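Since export() now lives on the infer-model classes, the tester no longer branches on model_type here. Once saved, the static graph reloads without the Python model definition; a minimal usage sketch, with an illustrative path and feature dimension:

import paddle

paddle.jit.save(static_model, "exports/deepspeech2")  # as done above

# Later, e.g. in a serving script:
loaded = paddle.jit.load("exports/deepspeech2")
loaded.eval()
audio = paddle.randn([1, 100, 161])  # [B, T, D]; 161 stands in for feat_dim
audio_len = paddle.to_tensor([100], dtype='int64')  # valid frames per utterance
probs = loaded(audio, audio_len)  # per-frame token posteriors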
@@ -408,63 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0

-    '''
-    def setup_model(self):
-        config = self.config
-        if self.args.model_type == 'offline':
-            model = DeepSpeech2Model(
-                feat_size=self.test_loader.collate_fn.feature_size,
-                dict_size=self.test_loader.collate_fn.vocab_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                use_gru=config.model.use_gru,
-                share_rnn_weights=config.model.share_rnn_weights)
-        elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline(
-                feat_size=self.test_loader.collate_fn.feature_size,
-                dict_size=self.test_loader.collate_fn.vocab_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                rnn_direction=config.model.rnn_direction,
-                num_fc_layers=config.model.num_fc_layers,
-                fc_layers_size_list=config.model.fc_layers_size_list,
-                use_gru=config.model.use_gru)
-        else:
-            raise Exception("Wrong model type")
-        self.model = model
-        logger.info("Setup model!")
-    '''
-
-    def setup_dataloader(self):
-        config = self.config.clone()
-        config.defrost()
-        # return raw text
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size, save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-        test_dataset = ManifestDataset.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        # return text ord id
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup test Dataloader!")

     def setup_output_dir(self):
         """Create a directory used for output.
         """

@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2Model from config.
+
+        Parameters
+        ----------
+        config: yacs.config.CfgNode
+            config.model
+
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    use_gru=config.use_gru,
+                    share_rnn_weights=config.share_rnn_weights)
+        return model
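With the classmethod in place, call sites shrink to a single line and read every constructor argument off config.model. A usage sketch with illustrative values (the field names come from the diff; note that rnn_layer_size maps to the rnn_size argument):

from yacs.config import CfgNode as CN

model_conf = CN()
model_conf.feat_size = 161
model_conf.dict_size = 4300
model_conf.num_conv_layers = 2
model_conf.num_rnn_layers = 3
model_conf.rnn_layer_size = 1024
model_conf.use_gru = True
model_conf.share_rnn_weights = False
model_conf.freeze()

model = DeepSpeech2Model.from_config(model_conf)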
@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def export(self):
+        static_model = paddle.jit.to_static(
+            self,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, None, self.encoder.feat_size],
+                    dtype='float32'),  # audio, [B, T, D]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
+        return static_model
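The None entries in the InputSpec shapes leave the batch and time dimensions dynamic, so a single exported graph serves any batch size and utterance length; only the trailing feature dimension is pinned, and it is taken from self.encoder.feat_size rather than passed in by the caller.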

@@ -51,8 +51,9 @@ class CRNNEncoder(nn.Layer):
         self.use_gru = use_gru
         self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
-        i_size = self.conv.output_dim
+        self.output_dim = self.conv.output_dim
+        i_size = self.conv.output_dim
         self.rnn = nn.LayerList()
         self.layernorm_list = nn.LayerList()
         self.fc_layers_list = nn.LayerList()
@@ -82,16 +83,18 @@
                     num_layers=1,
                     direction=rnn_direction))
             self.layernorm_list.append(nn.LayerNorm(layernorm_size))
+            self.output_dim = layernorm_size

         fc_input_size = layernorm_size
         for i in range(self.num_fc_layers):
             self.fc_layers_list.append(
                 nn.Linear(fc_input_size, fc_layers_size_list[i]))
             fc_input_size = fc_layers_size_list[i]
+            self.output_dim = fc_layers_size_list[i]

     @property
     def output_size(self):
-        return self.fc_layers_size_list[-1]
+        return self.output_dim
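Tracking self.output_dim as each stage is constructed keeps output_size correct for every topology: it ends up as the conv output dimension if no RNN layers are stacked, the layernorm size when num_fc_layers is 0, and the width of the last FC layer otherwise. The old return fc_layers_size_list[-1] assumed at least one FC layer and would fail on an empty list.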
     def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
         """Compute Encoder outputs

@@ -190,9 +193,6 @@
         for i in range(0, num_chunk):
             start = i * chunk_stride
             end = start + chunk_size
-            # end = min(start + chunk_size, max_len)
-            # if (end - start < receptive_field_length):
-            #     break
             x_chunk = padded_x[:, start:end, :]

             x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
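For reference, the chunk geometry used by this loop (also visible in the helper deleted from the tests below) is chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length and chunk_stride = subsampling_rate * decoder_chunk_size. Assuming the usual 4x subsampling and a 7-frame receptive field for the stacked stride-2 convolutions, de_ch_size = 8 gives chunks of (8 - 1) * 4 + 7 = 35 input frames advancing 32 frames per step, i.e. consecutive chunks overlap by 3 frames.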
@@ -221,8 +221,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
     :type text_data: Variable
     :param audio_len: Valid sequence length data layer.
     :type audio_len: Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
     :param dict_size: Dictionary size for tokenized transcription.
     :type dict_size: int
     :param num_conv_layers: Number of stacking convolution layers.
@@ -231,6 +229,10 @@
     :type num_rnn_layers: int
     :param rnn_size: RNN layer size (dimension of RNN cells).
     :type rnn_size: int
+    :param num_fc_layers: Number of stacking FC layers.
+    :type num_fc_layers: int
+    :param fc_layers_size_list: The list of FC layer sizes.
+    :type fc_layers_size_list: [int,]
     :param use_gru: Use gru if set True. Use simple rnn if set False.
     :type use_gru: bool
     :return: A tuple of an output unnormalized log probability layer (
@@ -274,7 +276,6 @@
             fc_layers_size_list=fc_layers_size_list,
             rnn_size=rnn_size,
             use_gru=use_gru)
-        assert (self.encoder.output_size == fc_layers_size_list[-1])

         self.decoder = CTCDecoder(
             odim=dict_size,  # <blank> is in vocab
@@ -337,7 +338,7 @@
         Returns
         -------
-        DeepSpeech2Model
+        DeepSpeech2ModelOnline
             The model built from pretrained result.
         """
         model = cls(feat_size=dataloader.collate_fn.feature_size,
@@ -355,6 +356,29 @@
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2ModelOnline from config.
+
+        Parameters
+        ----------
+        config: yacs.config.CfgNode
+            config.model
+
+        Returns
+        -------
+        DeepSpeech2ModelOnline
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    rnn_direction=config.rnn_direction,
+                    num_fc_layers=config.num_fc_layers,
+                    fc_layers_size_list=config.fc_layers_size_list,
+                    use_gru=config.use_gru)
+        return model
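A matching usage sketch for the online variant, mirroring the offline example above plus the streaming-specific fields (values are illustrative and echo the updated YAML below):

from yacs.config import CfgNode as CN

model_conf = CN()
model_conf.feat_size = 161
model_conf.dict_size = 4300
model_conf.num_conv_layers = 2
model_conf.num_rnn_layers = 3
model_conf.rnn_layer_size = 1024
model_conf.rnn_direction = 'forward'  # [forward, bidirect]
model_conf.num_fc_layers = 1
model_conf.fc_layers_size_list = [512]
model_conf.use_gru = True
model_conf.freeze()

model = DeepSpeech2ModelOnline.from_config(model_conf)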
@@ -392,7 +416,7 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
                 paddle.static.InputSpec(
                     shape=[None, None,
                            self.encoder.feat_size],  # [B, chunk_size, feat_dim]
-                    dtype='float32'),  # audio, [B,T,D]
+                    dtype='float32'),
                 paddle.static.InputSpec(shape=[None],
                                         dtype='int64'),  # audio_length, [B]
                 paddle.static.InputSpec(

@@ -36,17 +36,17 @@ collator:

 model:
   num_conv_layers: 2
-  num_rnn_layers: 4
+  num_rnn_layers: 3
   rnn_layer_size: 1024
-  rnn_direction: bidirect
+  rnn_direction: forward  # [forward, bidirect]
-  num_fc_layers: 2
+  num_fc_layers: 1
-  fc_layers_size_list: 512, 256
+  fc_layers_size_list: 512,
   use_gru: True

 training:
   n_epoch: 50
   lr: 2e-3
-  lr_decay: 0.83
+  lr_decay: 0.83  # 0.83
   weight_decay: 1e-06
   global_grad_clip: 3.0
   log_interval: 100

@@ -55,7 +55,7 @@ training:
     latest_n: 5

 decoding:
-  batch_size: 64
+  batch_size: 32
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
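To connect this YAML to the new from_config classmethods, the file is merged into a CfgNode before model construction. A hedged sketch; the empty permissive node and the file path are illustrative, since the repo assembles its real defaults per experiment:

from yacs.config import CfgNode as CN

config = CN(new_allowed=True)
config.merge_from_file("conf/deepspeech2_online.yaml")  # hypothetical path
config.freeze()

model = DeepSpeech2ModelOnline.from_config(config.model)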

@@ -106,18 +106,34 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         self.assertEqual(loss.numel(), 1)

     def test_ds2_6(self):
+        model = DeepSpeech2ModelOnline(
+            feat_size=self.feat_dim,
+            dict_size=10,
+            num_conv_layers=2,
+            num_rnn_layers=3,
+            rnn_size=1024,
+            rnn_direction='bidirect',
+            num_fc_layers=2,
+            fc_layers_size_list=[512, 256],
+            use_gru=False)
+        loss = model(self.audio, self.audio_len, self.text, self.text_len)
+        self.assertEqual(loss.numel(), 1)
+
+    def test_ds2_7(self):
+        use_gru = False
         model = DeepSpeech2ModelOnline(
             feat_size=self.feat_dim,
             dict_size=10,
             num_conv_layers=2,
             num_rnn_layers=1,
             rnn_size=1024,
+            rnn_direction='forward',
             num_fc_layers=2,
             fc_layers_size_list=[512, 256],
-            use_gru=True)
+            use_gru=use_gru)
         model.eval()
         paddle.device.set_device("cpu")
-        de_ch_size = 9
+        de_ch_size = 8

         eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
             self.audio, self.audio_len)
@@ -126,99 +142,44 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
         eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
         decode_max_len = eouts.shape[1]
-        print("dml", decode_max_len)
         eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
-        self.assertEqual(
-            paddle.sum(
-                paddle.abs(paddle.subtract(eouts_lens, eouts_lens_by_chk))), 0)
-        self.assertEqual(
-            paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0)
         self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
         self.assertEqual(
             paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
-        self.assertEqual(
-            paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
+        if use_gru == False:
+            self.assertEqual(
+                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
-        """
-        print ("conv_x", conv_x)
-        print ("conv_x_by_chk", conv_x_by_chk)
-        print ("final_state_list", final_state_list)
-        #print ("final_state_list_by_chk", final_state_list_by_chk)
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))))
-        print (paddle.allclose(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))))
-        print (paddle.allclose(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))))
-        print (paddle.allclose(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
-        print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:]))
-        """
-
-    """
-    def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate,
-                         receptive_field_length):
-        chunk_size = (decoder_chunk_size - 1
-                      ) * subsampling_rate + receptive_field_length
-        chunk_stride = subsampling_rate * decoder_chunk_size
-        max_len = x.shape[1]
-        assert (chunk_size <= max_len)
-        x_chunk_list = []
-        x_chunk_lens_list = []
-        padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
-        padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
-        padded_x = paddle.concat([x, padding], axis=1)
-        num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
-        num_chunk = int(num_chunk)
-        for i in range(0, num_chunk):
-            start = i * chunk_stride
-            end = start + chunk_size
-            x_chunk = padded_x[:, start:end, :]
-            x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
-                                      paddle.zeros_like(x_lens),
-                                      x_lens - i * chunk_stride)
-            x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
-            x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
-                                        x_len_left, x_chunk_len_tmp)
-            x_chunk_list.append(x_chunk)
-            x_chunk_lens_list.append(x_chunk_lens)
-        return x_chunk_list, x_chunk_lens_list

-    def test_ds2_7(self):
+    def test_ds2_8(self):
+        use_gru = True
         model = DeepSpeech2ModelOnline(
             feat_size=self.feat_dim,
             dict_size=10,
             num_conv_layers=2,
             num_rnn_layers=1,
             rnn_size=1024,
+            rnn_direction='forward',
             num_fc_layers=2,
             fc_layers_size_list=[512, 256],
-            use_gru=True)
+            use_gru=use_gru)
         model.eval()
         paddle.device.set_device("cpu")
-        de_ch_size = 9
-        audio_chunk_list, audio_chunk_lens_list = self.split_into_chunk(
-            self.audio, self.audio_len, de_ch_size,
-            model.encoder.conv.subsampling_rate,
-            model.encoder.conv.receptive_field_length)
-        eouts_prefix = None
-        eouts_lens_prefix = None
-        chunk_state_list = [None] * model.encoder.num_rnn_layers
-        for i, audio_chunk in enumerate(audio_chunk_list):
-            audio_chunk_lens = audio_chunk_lens_list[i]
-            eouts_prefix, eouts_lens_prefix, chunk_state_list = model.decode_prob_by_chunk(
-                audio_chunk, audio_chunk_lens, eouts_prefix, eouts_lens_prefix,
-                chunk_state_list)
-        # print (i, probs_pre_chunks.shape)
-        probs, eouts, eouts_lens, final_state_list = model.decode_prob(
-            self.audio, self.audio_len)
-        decode_max_len = probs.shape[1]
-        probs_pre_chunks = probs_pre_chunks[:, :decode_max_len, :]
-        self.assertEqual(paddle.allclose(probs, probs_pre_chunks), True)
-        """
+        de_ch_size = 8
+        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
+            self.audio, self.audio_len)
+        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
+            self.audio, self.audio_len, de_ch_size)
+        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
+        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
+        decode_max_len = eouts.shape[1]
+        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
+        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
+        self.assertEqual(
+            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
+        if use_gru == False:
+            self.assertEqual(
+                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
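The use_gru guard in both tests reflects a real asymmetry rather than a looser check: GRU cells have no cell state, so only the hidden state final_state_h_box is comparable for them, while the final_state_c_box comparison is meaningful only in the LSTM case (use_gru == False). In PaddlePaddle, paddle.nn.LSTM returns its final states as an (h, c) pair, whereas paddle.nn.GRU returns h alone.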
if __name__ == '__main__': if __name__ == '__main__':
