pull/780/head
huangyuxin 4 years ago
parent 08b68e4b8f
commit 718407b77d

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Trainer for DeepSpeech2 model.""" """Trainer for DeepSpeech2 model."""
import os
from paddle import distributed as dist from paddle import distributed as dist
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
@ -53,5 +55,7 @@ if __name__ == "__main__":
if args.dump_config: if args.dump_config:
with open(args.dump_config, 'w') as f: with open(args.dump_config, 'w') as f:
print(config, file=f) print(config, file=f)
# When a training seed is configured, also request deterministic cuDNN
# behaviour via Paddle's FLAGS_cudnn_deterministic flag. setdefault() keeps
# any value the user has already exported in the environment.
if config.training.seed is not None:
    os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args) main(config, args)

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model.""" """Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import random
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@ -53,6 +55,7 @@ class DeepSpeech2Trainer(Trainer):
weight_decay=1e-6, # the coeff of weight decay weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs n_epoch=50, # train epochs
seed=1024, #train seed
)) ))
if config is not None: if config is not None:
@ -61,6 +64,13 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
# Seed the random number generators only when a seed is configured;
# a seed of None skips seeding entirely.
if config.training.seed is not None:
    self.set_seed(config.training.seed)
def set_seed(self, seed):
    """Seed the NumPy, Python and Paddle random number generators.

    Args:
        seed: value handed unchanged to ``np.random.seed``,
            ``random.seed`` and ``paddle.seed``, in that order.
    """
    for seed_fn in (np.random.seed, random.seed, paddle.seed):
        seed_fn(seed)
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
start = time.time() start = time.time()

@ -52,7 +52,10 @@ if __name__ == "__main__":
if args.dump_config: if args.dump_config:
with open(args.dump_config, 'w') as f: with open(args.dump_config, 'w') as f:
print(config, file=f) print(config, file=f)
# When a training seed is configured, also request deterministic cuDNN
# behaviour via Paddle's FLAGS_cudnn_deterministic flag. setdefault() keeps
# any value the user has already exported in the environment.
if config.training.seed is not None:
    os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
# NOTE: main(config, args) is intentionally not called directly here. The
# profiling section that follows already executes it through
# pr.runcall(main, config, args); a direct call would run training twice.
# Setting for profiling # Setting for profiling
pr = cProfile.Profile() pr = cProfile.Profile()
pr.runcall(main, config, args) pr.runcall(main, config, args)

@ -55,7 +55,7 @@ class U2Trainer(Trainer):
log_interval=100, # steps log_interval=100, # steps
accum_grad=1, # accum grad by # steps accum_grad=1, # accum grad by # steps
global_grad_clip=5.0, # the global norm clip global_grad_clip=5.0, # the global norm clip
)) seed=1024, ))
default.optim = 'adam' default.optim = 'adam'
default.optim_conf = CfgNode( default.optim_conf = CfgNode(
dict( dict(
@ -75,6 +75,12 @@ class U2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
# Seed the random number generators only when a seed is configured;
# a seed of None skips seeding entirely.
if config.training.seed is not None:
    self.set_seed(config.training.seed)
def set_seed(self, seed):
    """Seed the NumPy, Python and Paddle random number generators.

    Python's ``random`` module is seeded as well, so this trainer seeds
    the same three generators as ``DeepSpeech2Trainer.set_seed``.

    Args:
        seed: value handed unchanged to ``np.random.seed``,
            ``random.seed`` and ``paddle.seed``.
    """
    # Function-scope import: this module is not known to import ``random``
    # at file level.
    import random

    np.random.seed(seed)
    random.seed(seed)
    paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training train_conf = self.config.training

@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
Args: Args:
x (Tensor): [B, feature_size, D] x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B] x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Returns: Return:
x (Tensor): encoder outputs, [B, size, D] x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B] x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
""" """
if init_state_h_box is not None: if init_state_h_box is not None:
init_state_list = None init_state_list = None
@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
if self.use_gru == True: if self.use_gru == True:
final_chunk_state_h_box = paddle.concat( final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0) final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) final_chunk_state_c_box = init_state_c_box
else: else:
final_chunk_state_h_list = [ final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B] x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder decoder_chunk_size: The chunk size of decoder
Returns: Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
""" """
subsampling_rate = self.conv.subsampling_rate subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length receptive_field_length = self.conv.receptive_field_length
@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class DeepSpeech2ModelOnline(nn.Layer): class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online. """The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer. :param audio: Audio spectrogram data layer.
:type audio_data: Variable :type audio: Variable
:param text_data: Transcription text data layer. :param text: Transcription text data layer.
:type text_data: Variable :type text: Variable
:param audio_len: Valid sequence length data layer. :param audio_len: Valid sequence length data layer.
:type audio_len: Variable :type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription. :param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int :type dict_size: int
:param num_conv_layers: Number of stacking convolution layers. :param num_conv_layers: Number of stacking convolution layers.

@ -143,7 +143,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1] decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru == False:

Loading…
Cancel
Save