|
|
@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
|
|
|
|
num_fc_layers=2,
|
|
|
|
num_fc_layers=2,
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
use_gru=True, #Use gru if set True. Use simple rnn if set False.
|
|
|
|
use_gru=True, #Use gru if set True. Use simple rnn if set False.
|
|
|
|
|
|
|
|
blank_id=0, # index of blank in vocob.txt
|
|
|
|
))
|
|
|
|
))
|
|
|
|
if config is not None:
|
|
|
|
if config is not None:
|
|
|
|
config.merge_from_other_cfg(default)
|
|
|
|
config.merge_from_other_cfg(default)
|
|
|
@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
|
|
|
|
rnn_direction='forward',
|
|
|
|
rnn_direction='forward',
|
|
|
|
num_fc_layers=2,
|
|
|
|
num_fc_layers=2,
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
use_gru=False):
|
|
|
|
use_gru=False,
|
|
|
|
|
|
|
|
blank_id=0):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.encoder = CRNNEncoder(
|
|
|
|
self.encoder = CRNNEncoder(
|
|
|
|
feat_size=feat_size,
|
|
|
|
feat_size=feat_size,
|
|
|
@ -284,7 +286,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
odim=dict_size, # <blank> is in vocab
|
|
|
|
odim=dict_size, # <blank> is in vocab
|
|
|
|
enc_n_units=self.encoder.output_size,
|
|
|
|
enc_n_units=self.encoder.output_size,
|
|
|
|
blank_id=0, # first token is <blank>
|
|
|
|
blank_id=blank_id,
|
|
|
|
dropout_rate=0.0,
|
|
|
|
dropout_rate=0.0,
|
|
|
|
reduction=True, # sum
|
|
|
|
reduction=True, # sum
|
|
|
|
batch_average=True) # sum / batch_size
|
|
|
|
batch_average=True) # sum / batch_size
|
|
|
@ -353,7 +355,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
|
|
|
|
rnn_direction=config.model.rnn_direction,
|
|
|
|
rnn_direction=config.model.rnn_direction,
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
num_fc_layers=config.model.num_fc_layers,
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
fc_layers_size_list=config.model.fc_layers_size_list,
|
|
|
|
use_gru=config.model.use_gru)
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
|
|
|
|
blank_id=config.model.blank_id)
|
|
|
|
infos = Checkpoint().load_parameters(
|
|
|
|
infos = Checkpoint().load_parameters(
|
|
|
|
model, checkpoint_path=checkpoint_path)
|
|
|
|
model, checkpoint_path=checkpoint_path)
|
|
|
|
logger.info(f"checkpoint info: {infos}")
|
|
|
|
logger.info(f"checkpoint info: {infos}")
|
|
|
@ -380,7 +383,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
|
|
|
|
rnn_direction=config.rnn_direction,
|
|
|
|
rnn_direction=config.rnn_direction,
|
|
|
|
num_fc_layers=config.num_fc_layers,
|
|
|
|
num_fc_layers=config.num_fc_layers,
|
|
|
|
fc_layers_size_list=config.fc_layers_size_list,
|
|
|
|
fc_layers_size_list=config.fc_layers_size_list,
|
|
|
|
use_gru=config.use_gru)
|
|
|
|
use_gru=config.use_gru,
|
|
|
|
|
|
|
|
blank_id=config.blank_id)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -394,7 +398,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
|
|
|
|
rnn_direction='forward',
|
|
|
|
rnn_direction='forward',
|
|
|
|
num_fc_layers=2,
|
|
|
|
num_fc_layers=2,
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
fc_layers_size_list=[512, 256],
|
|
|
|
use_gru=False):
|
|
|
|
use_gru=False,
|
|
|
|
|
|
|
|
blank_id=0):
|
|
|
|
super().__init__(
|
|
|
|
super().__init__(
|
|
|
|
feat_size=feat_size,
|
|
|
|
feat_size=feat_size,
|
|
|
|
dict_size=dict_size,
|
|
|
|
dict_size=dict_size,
|
|
|
@ -404,7 +409,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
|
|
|
|
rnn_direction=rnn_direction,
|
|
|
|
rnn_direction=rnn_direction,
|
|
|
|
num_fc_layers=num_fc_layers,
|
|
|
|
num_fc_layers=num_fc_layers,
|
|
|
|
fc_layers_size_list=fc_layers_size_list,
|
|
|
|
fc_layers_size_list=fc_layers_size_list,
|
|
|
|
use_gru=use_gru)
|
|
|
|
use_gru=use_gru,
|
|
|
|
|
|
|
|
blank_id=blank_id)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
|
|
|
|
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
|
|
|
|
chunk_state_c_box):
|
|
|
|
chunk_state_c_box):
|
|
|
|