修改了deepspeech2.py部分LSTM和GRU的代码,增加了LayerNorm

pull/735/head
huangyuxin 4 years ago
parent ce1e8ab5b6
commit 2cacbaf48e

@ -127,7 +127,8 @@ class DeepSpeech2Trainer(Trainer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -374,7 +375,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")

@ -25,6 +25,11 @@ from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from paddle.nn import LSTM, GRU
from paddle.nn import LayerNorm
from paddle.nn import LayerList
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode'] __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode']
@ -38,25 +43,50 @@ class CRNNEncoder(nn.Layer):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online=True):
super().__init__() super().__init__()
self.rnn_size = rnn_size self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.apply_online = apply_online
self.conv = ConvStack(feat_size, num_conv_layers) self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack i_size = self.conv.output_height # H after conv stack
self.rnn = LayerList()
self.layernorm_list = LayerList()
if (apply_online == True):
rnn_direction = 'forward'
else:
rnn_direction = 'bidirect'
if use_gru == True:
self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
for i in range(1, num_rnn_layers):
self.rnn.append(GRU(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
else:
self.rnn.append(LSTM(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
for i in range(1, num_rnn_layers):
self.rnn.append(LSTM(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
"""
self.rnn = RNNStack( self.rnn = RNNStack(
i_size=i_size, i_size=i_size,
h_size=rnn_size, h_size=rnn_size,
num_stacks=num_rnn_layers, num_stacks=num_rnn_layers,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights)
"""
@property @property
def output_size(self): def output_size(self):
return self.rnn_size * 2 return self.rnn_size
def forward(self, audio, audio_len): def forward(self, audio, audio_len):
"""Compute Encoder outputs """Compute Encoder outputs
@ -86,7 +116,15 @@ class CRNNEncoder(nn.Layer):
x = x.reshape([0, 0, -1]) #[B, T, C*D] x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part # remove padding part
x, x_lens = self.rnn(x, x_lens) #[B, T, D] print ("x.shape:", x.shape)
x, output_state = self.rnn[0](x, None, x_lens)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D]
x = self.layernorm_list[i](x)
"""
x, x_lens = self.rnn(x, x_lens)
"""
return x, x_lens return x, x_lens
@ -141,7 +179,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online = True):
super().__init__() super().__init__()
self.encoder = CRNNEncoder( self.encoder = CRNNEncoder(
feat_size=feat_size, feat_size=feat_size,
@ -150,8 +189,9 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=num_rnn_layers, num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size, rnn_size=rnn_size,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights,
assert (self.encoder.output_size == rnn_size * 2) apply_online=apply_online)
assert (self.encoder.output_size == rnn_size)
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab odim=dict_size, # <blank> is in vocab
@ -221,7 +261,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
infos = Checkpoint().load_parameters( infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path) model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}") logger.info(f"checkpoint info: {infos}")
@ -237,7 +278,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online = True):
super().__init__( super().__init__(
feat_size=feat_size, feat_size=feat_size,
dict_size=dict_size, dict_size=dict_size,
@ -245,7 +287,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=num_rnn_layers, num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size, rnn_size=rnn_size,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights,
apply_online=apply_online)
def forward(self, audio, audio_len): def forward(self, audio, audio_len):
"""export model function """export model function

@ -36,10 +36,11 @@ collator:
model: model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 4
rnn_layer_size: 1024 rnn_layer_size: 1024
use_gru: True use_gru: True
share_rnn_weights: False share_rnn_weights: False
apply_online: False
training: training:
n_epoch: 50 n_epoch: 50

@ -40,6 +40,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
apply_online: False
training: training:
n_epoch: 50 n_epoch: 50

@ -41,6 +41,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
apply_online: True
training: training:
n_epoch: 10 n_epoch: 10

Loading…
Cancel
Save