Merge branch 'HEAD_1' into ds2_online

pull/780/head
huangyuxin 4 years ago
commit 9068c0d4f9

@ -55,7 +55,7 @@ if __name__ == "__main__":
if args.dump_config: if args.dump_config:
with open(args.dump_config, 'w') as f: with open(args.dump_config, 'w') as f:
print(config, file=f) print(config, file=f)
if config.training.seed != None: if config.training.seed is not None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args) main(config, args)

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model.""" """Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import random import random
import time import time
from collections import defaultdict from collections import defaultdict
@ -64,7 +63,7 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
if config.training.seed != None: if config.training.seed is not None:
self.set_seed(config.training.seed) self.set_seed(config.training.seed)
def set_seed(self, seed): def set_seed(self, seed):

@ -52,10 +52,7 @@ if __name__ == "__main__":
if args.dump_config: if args.dump_config:
with open(args.dump_config, 'w') as f: with open(args.dump_config, 'w') as f:
print(config, file=f) print(config, file=f)
if config.training.seed != None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
# Setting for profiling # Setting for profiling
pr = cProfile.Profile() pr = cProfile.Profile()
pr.runcall(main, config, args) pr.runcall(main, config, args)

@ -55,7 +55,7 @@ class U2Trainer(Trainer):
log_interval=100, # steps log_interval=100, # steps
accum_grad=1, # accum grad by # steps accum_grad=1, # accum grad by # steps
global_grad_clip=5.0, # the global norm clip global_grad_clip=5.0, # the global norm clip
seed=1024, )) ))
default.optim = 'adam' default.optim = 'adam'
default.optim_conf = CfgNode( default.optim_conf = CfgNode(
dict( dict(
@ -75,12 +75,6 @@ class U2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
if config.training.seed != None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
np.random.seed(seed)
paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training train_conf = self.config.training

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4 from deepspeech.modules.subsampling import Conv2dSubsampling4

@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer): class CRNNEncoder(nn.Layer):
@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size rnn_input_size = i_size
else: else:
rnn_input_size = layernorm_size rnn_input_size = layernorm_size
if use_gru == True: if use_gru is True:
self.rnn.append( self.rnn.append(
nn.GRU( nn.GRU(
input_size=rnn_input_size, input_size=rnn_input_size,
@ -113,7 +113,7 @@ class CRNNEncoder(nn.Layer):
if init_state_h_box is not None: if init_state_h_box is not None:
init_state_list = None init_state_list = None
if self.use_gru == True: if self.use_gru is True:
init_state_h_list = paddle.split( init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0) init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list init_state_list = init_state_h_list
@ -139,7 +139,7 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x) x = self.fc_layers_list[i](x)
x = F.relu(x) x = F.relu(x)
if self.use_gru == True: if self.use_gru is True:
final_chunk_state_h_box = paddle.concat( final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0) final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box final_chunk_state_c_box = init_state_c_box

@ -46,7 +46,7 @@ model:
training: training:
n_epoch: 50 n_epoch: 50
lr: 2e-3 lr: 2e-3
lr_decay: 0.9 # 0.83 lr_decay: 0.91 # 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 3.0 global_grad_clip: 3.0
log_interval: 100 log_interval: 100

@ -143,10 +143,10 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1] decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

Loading…
Cancel
Save