Merge pull request #735 from Jackwaterveg/ds2_online

Ds2 online
Hui Zhang committed 4 years ago (via GitHub)
commit 38d95784e2

@@ -30,11 +30,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args)
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
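The three inserted lines fall back to 'offline' whenever --model_type is omitted. A minimal sketch of the same behavior expressed directly as an argparse default (the choices list is an assumption, inferred from the branches in setup_model later in this diff):

# sketch, not the PR's code: declare the fallback as an argparse default
parser.add_argument(
    "--model_type",
    default='offline',  # same fallback as the post-parse check above
    choices=['offline', 'online'],  # assumed: the two types setup_model accepts
    help="deepspeech2 model type")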

@@ -30,11 +30,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
     print_arguments(args, globals())
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:

@@ -35,11 +35,15 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument("--model_type")
     args = parser.parse_args()
+    if args.model_type is None:
+        args.model_type = 'offline'
+    print("model_type:{}".format(args.model_type))
     print_arguments(args, globals())
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:

@@ -18,21 +18,19 @@ from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 
-_C = CfgNode()
-
-_C.data = ManifestDataset.params()
-
-_C.collator = SpeechCollator.params()
-
-_C.model = DeepSpeech2Model.params()
-
-_C.training = DeepSpeech2Trainer.params()
-
-_C.decoding = DeepSpeech2Tester.params()
-
-def get_cfg_defaults():
+def get_cfg_defaults(model_type='offline'):
+    _C = CfgNode()
+    _C.data = ManifestDataset.params()
+    _C.collator = SpeechCollator.params()
+    _C.training = DeepSpeech2Trainer.params()
+    _C.decoding = DeepSpeech2Tester.params()
+    if model_type == 'offline':
+        _C.model = DeepSpeech2Model.params()
+    else:
+        _C.model = DeepSpeech2ModelOnline.params()
     """Get a yacs CfgNode object with default values for my_project."""
     # Return a clone so that the defaults will not be altered
     # This is for the "local variable" use pattern
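get_cfg_defaults now builds the CfgNode per call and installs whichever model params match model_type. A usage sketch with yacs (the YAML path is illustrative, not from this PR):

# sketch: build online defaults, then overlay an experiment YAML
config = get_cfg_defaults(model_type='online')  # installs DeepSpeech2ModelOnline.params()
config.merge_from_file("conf/deepspeech2_online.yaml")  # hypothetical path
config.freeze()  # lock the config before handing it to the trainer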

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains DeepSpeech2 model."""
+"""Contains DeepSpeech2 and DeepSpeech2Online model."""
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.ds2 import DeepSpeech2InferModel
 from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
@@ -119,16 +121,22 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts
 
     def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.train_loader.collate_fn.feature_size,
-            dict_size=self.train_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        config = self.config.clone()
+        config.defrost()
+        assert (self.train_loader.collate_fn.feature_size ==
+                self.test_loader.collate_fn.feature_size)
+        assert (self.train_loader.collate_fn.vocab_size ==
+                self.test_loader.collate_fn.vocab_size)
+        config.model.feat_size = self.train_loader.collate_fn.feature_size
+        config.model.dict_size = self.train_loader.collate_fn.vocab_size
+        config.freeze()
+        if self.args.model_type == 'offline':
+            model = DeepSpeech2Model.from_config(config.model)
+        elif self.args.model_type == 'online':
+            model = DeepSpeech2ModelOnline.from_config(config.model)
+        else:
+            raise Exception("wrong model type")
 
         if self.parallel:
             model = paddle.DataParallel(model)
@@ -164,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)
 
+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)
+
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
@@ -187,6 +198,11 @@ class DeepSpeech2Trainer(Trainer):
         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)
 
+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        collate_fn_test = SpeechCollator.from_config(config)
+
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
@@ -198,7 +214,13 @@ class DeepSpeech2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
-        logger.info("Setup train/valid Dataloader!")
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_test)
+        logger.info("Setup train/valid/test Dataloader!")
 
 
 class DeepSpeech2Tester(DeepSpeech2Trainer):
@@ -329,19 +351,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             exit(-1)
 
     def export(self):
-        infer_model = DeepSpeech2InferModel.from_pretrained(
-            self.test_loader, self.config, self.args.checkpoint_path)
+        if self.args.model_type == 'offline':
+            infer_model = DeepSpeech2InferModel.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        elif self.args.model_type == 'online':
+            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
+                self.test_loader, self.config, self.args.checkpoint_path)
+        else:
+            raise Exception("wrong model type")
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        static_model = paddle.jit.to_static(
-            infer_model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, None, feat_dim],
-                    dtype='float32'),  # audio, [B,T,D]
-                paddle.static.InputSpec(shape=[None],
-                                        dtype='int64'),  # audio_length, [B]
-            ])
+        static_model = infer_model.export()
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
@@ -365,46 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0
 
-    def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.test_loader.collate_fn.feature_size,
-            dict_size=self.test_loader.collate_fn.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
-        self.model = model
-        logger.info("Setup model!")
-
-    def setup_dataloader(self):
-        config = self.config.clone()
-        config.defrost()
-        # return raw text
-
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size, save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-
-        test_dataset = ManifestDataset.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        # return text ord id
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup test Dataloader!")
-
     def setup_output_dir(self):
         """Create a directory used for output.
         """

@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
         layer_tools.summary(model)
         return model
 
+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2Model from config.
+
+        Parameters
+        ----------
+        config: yacs.config.CfgNode
+            config.model
+
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    use_gru=config.use_gru,
+                    share_rnn_weights=config.share_rnn_weights)
+        return model
+
 
 class DeepSpeech2InferModel(DeepSpeech2Model):
     def __init__(self,
@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def export(self):
+        static_model = paddle.jit.to_static(
+            self,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, None, self.encoder.feat_size],
+                    dtype='float32'),  # audio, [B,T,D]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
+        return static_model
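With export() on the infer model, the Tester body above shrinks to static_model = infer_model.export(). A round-trip sketch for the saved program (shapes follow the input_spec above; the path is illustrative):

import paddle

# assumed: infer_model was built via DeepSpeech2InferModel.from_pretrained(...)
static_model = infer_model.export()
paddle.jit.save(static_model, "exp/ds2.jit")  # illustrative export path

loaded = paddle.jit.load("exp/ds2.jit")       # reload the static graph
loaded.eval()
# probs = loaded(audio, audio_len)  # audio: [B, T, feat_dim] float32; audio_len: [B] int64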

@@ -0,0 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']

@@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4


class Conv2dSubsampling4Online(Conv2dSubsampling4):
    def __init__(self, idim: int, odim: int, dropout_rate: float):
        super().__init__(idim, odim, dropout_rate, None)
        self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
        self.receptive_field_length = 2 * (
            3 - 1) + 3  # stride_1 * (kernel_size_2 - 1) + kernel_size_1

    def forward(self, x: paddle.Tensor,
                x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        #b, c, t, f = paddle.shape(x)  # does not work under jit
        x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
        x_len = ((x_len - 1) // 2 - 1) // 2
        return x, x_len
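Both conv layers inherited from Conv2dSubsampling4 use kernel 3 and stride 2 with no time padding, so lengths shrink twice by (n - 1) // 2, and the hard-coded 7-frame receptive field follows. A quick check of the arithmetic the class hard-codes:

def subsampled_len(n: int) -> int:
    # two conv layers, kernel 3, stride 2: n -> (n - 1) // 2, applied twice
    return ((n - 1) // 2 - 1) // 2

idim, odim = 161, 32                      # the sizes CRNNEncoder passes in below
output_dim = subsampled_len(idim) * odim  # flattened channel x freq dimension
receptive_field = 2 * (3 - 1) + 3         # stride_1 * (kernel_size_2 - 1) + kernel_size_1
assert output_dim == 39 * 32 == 1248
assert receptive_field == 7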

@@ -0,0 +1,427 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional

import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode

from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.rnn_size = rnn_size
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size
        self.num_rnn_layers = num_rnn_layers
        self.num_fc_layers = num_fc_layers
        self.rnn_direction = rnn_direction
        self.fc_layers_size_list = fc_layers_size_list
        self.use_gru = use_gru
        self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)

        self.output_dim = self.conv.output_dim

        i_size = self.conv.output_dim
        self.rnn = nn.LayerList()
        self.layernorm_list = nn.LayerList()
        self.fc_layers_list = nn.LayerList()
        if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional':
            layernorm_size = 2 * rnn_size
        elif rnn_direction == 'forward':
            layernorm_size = rnn_size
        else:
            raise Exception("Wrong rnn direction")
        for i in range(0, num_rnn_layers):
            if i == 0:
                rnn_input_size = i_size
            else:
                rnn_input_size = layernorm_size
            if use_gru == True:
                self.rnn.append(
                    nn.GRU(
                        input_size=rnn_input_size,
                        hidden_size=rnn_size,
                        num_layers=1,
                        direction=rnn_direction))
            else:
                self.rnn.append(
                    nn.LSTM(
                        input_size=rnn_input_size,
                        hidden_size=rnn_size,
                        num_layers=1,
                        direction=rnn_direction))
            self.layernorm_list.append(nn.LayerNorm(layernorm_size))
            self.output_dim = layernorm_size

        fc_input_size = layernorm_size
        for i in range(self.num_fc_layers):
            self.fc_layers_list.append(
                nn.Linear(fc_input_size, fc_layers_size_list[i]))
            fc_input_size = fc_layers_size_list[i]
            self.output_dim = fc_layers_size_list[i]

    @property
    def output_size(self):
        return self.output_dim
    def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, T, D]
            x_lens (Tensor): [B]
            init_state_h_box (Tensor): init h states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
            init_state_c_box (Tensor): init c states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
        Returns:
            x (Tensor): encoder outputs, [B, T, D]
            x_lens (Tensor): encoder length, [B]
            final_state_h_box (Tensor): final h states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
            final_state_c_box (Tensor): final c states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
        """
        if init_state_h_box is not None:
            init_state_list = None
            if self.use_gru == True:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_list = init_state_h_list
            else:
                init_state_h_list = paddle.split(
                    init_state_h_box, self.num_rnn_layers, axis=0)
                init_state_c_list = paddle.split(
                    init_state_c_box, self.num_rnn_layers, axis=0)
                init_state_list = [(init_state_h_list[i], init_state_c_list[i])
                                   for i in range(self.num_rnn_layers)]
        else:
            init_state_list = [None] * self.num_rnn_layers

        x, x_lens = self.conv(x, x_lens)
        final_chunk_state_list = []
        for i in range(0, self.num_rnn_layers):
            x, final_state = self.rnn[i](x, init_state_list[i],
                                         x_lens)  #[B, T, D]
            final_chunk_state_list.append(final_state)
            x = self.layernorm_list[i](x)

        for i in range(self.num_fc_layers):
            x = self.fc_layers_list[i](x)
            x = F.relu(x)

        if self.use_gru == True:
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_list, axis=0)
            final_chunk_state_c_box = init_state_c_box  #paddle.zeros_like(final_chunk_state_h_box)
        else:
            final_chunk_state_h_list = [
                final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_c_list = [
                final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
            ]
            final_chunk_state_h_box = paddle.concat(
                final_chunk_state_h_list, axis=0)
            final_chunk_state_c_box = paddle.concat(
                final_chunk_state_c_list, axis=0)

        return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box
    def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
        """Compute Encoder outputs

        Args:
            x (Tensor): [B, T, D]
            x_lens (Tensor): [B]
            decoder_chunk_size: The chunk size of decoder
        Returns:
            eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
            eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
            final_state_h_box (Tensor): final h states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
            final_state_c_box (Tensor): final c states for RNN layers, [num_rnn_layers * num_directions, batch_size, hidden_size]
        """
        subsampling_rate = self.conv.subsampling_rate
        receptive_field_length = self.conv.receptive_field_length
        chunk_size = (decoder_chunk_size - 1
                      ) * subsampling_rate + receptive_field_length
        chunk_stride = subsampling_rate * decoder_chunk_size
        max_len = x.shape[1]
        assert (chunk_size <= max_len)

        eouts_chunk_list = []
        eouts_chunk_lens_list = []

        padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
        padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
        padded_x = paddle.concat([x, padding], axis=1)
        num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
        num_chunk = int(num_chunk)
        chunk_state_h_box = None
        chunk_state_c_box = None
        final_state_h_box = None
        final_state_c_box = None
        for i in range(0, num_chunk):
            start = i * chunk_stride
            end = start + chunk_size
            x_chunk = padded_x[:, start:end, :]

            x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
                                      paddle.zeros_like(x_lens),
                                      x_lens - i * chunk_stride)
            x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
            x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
                                        x_len_left, x_chunk_len_tmp)

            eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
                x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)

            eouts_chunk_list.append(eouts_chunk)
            eouts_chunk_lens_list.append(eouts_chunk_lens)
        final_state_h_box = chunk_state_h_box
        final_state_c_box = chunk_state_c_box
        return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
class DeepSpeech2ModelOnline(nn.Layer):
    """The DeepSpeech2 network structure for online.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param num_fc_layers: Number of stacking FC layers.
    :type num_fc_layers: int
    :param fc_layers_size_list: The list of FC layer sizes.
    :type fc_layers_size_list: [int,]
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                num_conv_layers=2,  #Number of stacking convolution layers.
                num_rnn_layers=4,  #Number of stacking RNN layers.
                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
                num_fc_layers=2,
                fc_layers_size_list=[512, 256],
                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
            ))
        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__()
        self.encoder = CRNNEncoder(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            rnn_size=rnn_size,
            use_gru=use_gru)

        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=0,  # first token is <blank>
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True)  # sum / batch_size
    def forward(self, audio, audio_len, text, text_len):
        """Compute Model loss

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]

        Returns:
            loss (Tensor): [1]
        """
        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        loss = self.decoder(eouts, eouts_len, text, text_len)
        return loss

    @paddle.no_grad()
    def decode(self, audio, audio_len, vocab_list, decoding_method,
               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
               cutoff_top_n, num_processes):
        # init once
        # decoders only accept string encoded in utf-8
        self.decoder.init_decode(
            beam_alpha=beam_alpha,
            beam_beta=beam_beta,
            lang_model_path=lang_model_path,
            vocab_list=vocab_list,
            decoding_method=decoding_method)

        eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
            audio, audio_len, None, None)
        probs = self.decoder.softmax(eouts)
        return self.decoder.decode_probs(
            probs.numpy(), eouts_len, vocab_list, decoding_method,
            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
            cutoff_top_n, num_processes)

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2Model model from a pretrained model.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader

        config: yacs.config.CfgNode
            model configs

        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name

        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from pretrained result.
        """
        model = cls(feat_size=dataloader.collate_fn.feature_size,
                    dict_size=dataloader.collate_fn.vocab_size,
                    num_conv_layers=config.model.num_conv_layers,
                    num_rnn_layers=config.model.num_rnn_layers,
                    rnn_size=config.model.rnn_layer_size,
                    rnn_direction=config.model.rnn_direction,
                    num_fc_layers=config.model.num_fc_layers,
                    fc_layers_size_list=config.model.fc_layers_size_list,
                    use_gru=config.model.use_gru)
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model
    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2ModelOnline from config.

        Parameters
        ----------
        config: yacs.config.CfgNode
            config.model

        Returns
        -------
        DeepSpeech2ModelOnline
            The model built from config.
        """
        model = cls(feat_size=config.feat_size,
                    dict_size=config.dict_size,
                    num_conv_layers=config.num_conv_layers,
                    num_rnn_layers=config.num_rnn_layers,
                    rnn_size=config.rnn_layer_size,
                    rnn_direction=config.rnn_direction,
                    num_fc_layers=config.num_fc_layers,
                    fc_layers_size_list=config.fc_layers_size_list,
                    use_gru=config.use_gru)
        return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=4,
                 rnn_size=1024,
                 rnn_direction='forward',
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=False):
        super().__init__(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            rnn_direction=rnn_direction,
            num_fc_layers=num_fc_layers,
            fc_layers_size_list=fc_layers_size_list,
            use_gru=use_gru)

    def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                chunk_state_c_box):
        eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
            audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
        probs_chunk = self.decoder.softmax(eouts_chunk)
        return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, None,
                           self.encoder.feat_size],  # audio chunk, [B, chunk_size, feat_dim]
                    dtype='float32'),
                paddle.static.InputSpec(shape=[None],
                                        dtype='int64'),  # audio_length, [B]
                paddle.static.InputSpec(
                    shape=[None, None, None], dtype='float32'),  # chunk_state_h_box
                paddle.static.InputSpec(
                    shape=[None, None, None], dtype='float32')  # chunk_state_c_box
            ])
        return static_model
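forward_chunk_by_chunk sizes each slice so that one chunk yields exactly decoder_chunk_size encoder frames after the 4x subsampling. Working the arithmetic with the values this file implies (subsampling rate 4 and receptive field 7 come from Conv2dSubsampling4Online above):

subsampling_rate = 4         # Conv2dSubsampling4 time reduction
receptive_field_length = 7   # 2 * (3 - 1) + 3
decoder_chunk_size = 8

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
assert (chunk_size, chunk_stride) == (35, 32)

max_len = 210                # matches the unit tests later in this diff
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
assert num_chunk == 7        # seven 35-frame windows, hopping 32 frames at a time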

@@ -0,0 +1,67 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test
  min_input_len: 0.0
  max_input_len: 27.0 # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 32 # one gpu
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear #linear, mfcc, fbank
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 1024
  rnn_direction: forward # [forward, bidirect]
  num_fc_layers: 1
  fc_layers_size_list: 512,
  use_gru: True

training:
  n_epoch: 50
  lr: 2e-3
  lr_decay: 0.83
  weight_decay: 1e-06
  global_grad_clip: 3.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 32
  error_rate_type: cer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
  alpha: 1.9
  beta: 5.0
  beam_size: 300
  cutoff_prob: 0.99
  cutoff_top_n: 40
  num_proc_bsearch: 10

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in export!"
@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 
 config_path=$1
 ckpt_prefix=$2
+model_type=$3
 
 # download language model
 bash local/download_lm_ch.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"

@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi

@@ -0,0 +1,67 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.train
  dev_manifest: data/manifest.dev-clean
  test_manifest: data/manifest.test-clean
  min_input_len: 0.0
  max_input_len: 27.0 # second
  min_output_len: 0.0
  max_output_len: .inf
  min_output_input_ratio: 0.00
  max_output_input_ratio: .inf

collator:
  batch_size: 20
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  target_sample_rate: 16000
  max_freq: None
  n_fft: None
  stride_ms: 10.0
  window_ms: 20.0
  delta_delta: False
  dither: 1.0
  use_dB_normalization: True
  target_dB: -20
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0

model:
  num_conv_layers: 2
  num_rnn_layers: 3
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: False

training:
  n_epoch: 50
  lr: 1e-3
  lr_decay: 0.83
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 100
  checkpoint:
    kbest_n: 50
    latest_n: 5

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 1.9
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in export!"

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 
 config_path=$1
 ckpt_prefix=$2
+model_type=$3
 
 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"

@@ -6,6 +6,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=30
+model_type=offline
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
 avg_ckpt=avg_${avg_num}
@@ -19,7 +20,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -29,10 +30,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi

@@ -0,0 +1,69 @@
# https://yaml.org/type/float.html
data:
  train_manifest: data/manifest.tiny
  dev_manifest: data/manifest.tiny
  test_manifest: data/manifest.tiny
  min_input_len: 0.0
  max_input_len: 27.0
  min_output_len: 0.0
  max_output_len: 400.0
  min_output_input_ratio: 0.05
  max_output_input_ratio: 10.0

collator:
  mean_std_filepath: data/mean_std.json
  unit_type: char
  vocab_filepath: data/vocab.txt
  augmentation_config: conf/augmentation.json
  random_seed: 0
  spm_model_prefix:
  specgram_type: linear
  feat_dim:
  delta_delta: False
  stride_ms: 10.0
  window_ms: 20.0
  n_fft: None
  max_freq: None
  target_sample_rate: 16000
  use_dB_normalization: True
  target_dB: -20
  dither: 1.0
  keep_transcription_text: False
  sortagrad: True
  shuffle_method: batch_shuffle
  num_workers: 0
  batch_size: 4

model:
  num_conv_layers: 2
  num_rnn_layers: 4
  rnn_layer_size: 2048
  rnn_direction: forward
  num_fc_layers: 2
  fc_layers_size_list: 512, 256
  use_gru: True

training:
  n_epoch: 10
  lr: 1e-5
  lr_decay: 1.0
  weight_decay: 1e-06
  global_grad_clip: 5.0
  log_interval: 1
  checkpoint:
    kbest_n: 3
    latest_n: 2

decoding:
  batch_size: 128
  error_rate_type: wer
  decoding_method: ctc_beam_search
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  alpha: 2.5
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  num_proc_bsearch: 8

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 3 ];then
-    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+if [ $# != 4 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
     exit -1
 fi
@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_path_prefix=$2
 jit_model_export_path=$3
+model_type=$4
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
 --nproc ${ngpu} \
 --config ${config_path} \
 --checkpoint_path ${ckpt_path_prefix} \
---export_path ${jit_model_export_path}
+--export_path ${jit_model_export_path} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in export!"
@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type"
     exit -1
 fi
@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
 fi
 
 config_path=$1
 ckpt_prefix=$2
+model_type=$3
 
 # download language model
 bash local/download_lm_en.sh
@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
 --nproc 1 \
 --config ${config_path} \
 --result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"

@@ -1,7 +1,7 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# != 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
     exit -1
 fi
@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+model_type=$3
 
 device=gpu
 if [ ${ngpu} == 0 ];then
@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
 --device ${device} \
 --nproc ${ngpu} \
 --config ${config_path} \
---output exp/${ckpt_name}
+--output exp/${ckpt_name} \
+--model_type ${model_type}
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"

@@ -7,6 +7,7 @@ stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
+model_type=offline
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -31,10 +32,10 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
 fi

@@ -16,7 +16,7 @@ import unittest
 import numpy as np
 import paddle
 
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 
 
 class TestDeepSpeech2Model(unittest.TestCase):

@@ -0,0 +1,186 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import paddle

from deepspeech.models.ds2_online import DeepSpeech2ModelOnline


class TestDeepSpeech2ModelOnline(unittest.TestCase):
    def setUp(self):
        paddle.set_device('cpu')
        self.batch_size = 2
        self.feat_dim = 161
        max_len = 210

        # (B, T, D)
        audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
        audio_len = np.random.randint(max_len, size=self.batch_size)
        audio_len[-1] = max_len
        # (B, U)
        text = np.array([[1, 2], [1, 2]])
        text_len = np.array([2] * self.batch_size)

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')
    def test_ds2_1(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_2(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_3(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_4(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=True)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_5(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_6(self):
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            rnn_direction='bidirect',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=False)
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)
    def test_ds2_7(self):
        use_gru = False
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

    def test_ds2_8(self):
        use_gru = True
        model = DeepSpeech2ModelOnline(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=1,
            rnn_size=1024,
            rnn_direction='forward',
            num_fc_layers=2,
            fc_layers_size_list=[512, 256],
            use_gru=use_gru)
        model.eval()
        paddle.device.set_device("cpu")
        de_ch_size = 8

        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
            self.audio, self.audio_len)
        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
            self.audio, self.audio_len, de_ch_size)
        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
        decode_max_len = eouts.shape[1]
        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
        self.assertEqual(
            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
        if use_gru == False:
            self.assertEqual(
                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)


if __name__ == '__main__':
    unittest.main()