@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 model."""
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from pathlib import Path
@@ -38,8 +38,6 @@ from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

logger = Log(__name__).getlog()

@@ -123,40 +121,20 @@ class DeepSpeech2Trainer(Trainer):
        return total_loss, num_seen_utts

    def setup_model(self):
        config = self.config
        if hasattr(self, "train_loader"):
            config = self.config.clone()
            config.defrost()
            assert (self.train_loader.collate_fn.feature_size ==
                    self.test_loader.collate_fn.feature_size)
            assert (self.train_loader.collate_fn.vocab_size ==
                    self.test_loader.collate_fn.vocab_size)
            config.model.feat_size = self.train_loader.collate_fn.feature_size
            config.model.dict_size = self.train_loader.collate_fn.vocab_size
            config.freeze()
        elif hasattr(self, "test_loader"):
            config.defrost()
            config.model.feat_size = self.test_loader.collate_fn.feature_size
            config.model.dict_size = self.test_loader.collate_fn.vocab_size
            config.freeze()
        else:
            raise Exception("Please setup the dataloader first")

        if self.args.model_type == 'offline':
            model = DeepSpeech2Model(
                feat_size=config.model.feat_size,
                dict_size=config.model.dict_size,
                num_conv_layers=config.model.num_conv_layers,
                num_rnn_layers=config.model.num_rnn_layers,
                rnn_size=config.model.rnn_layer_size,
                use_gru=config.model.use_gru,
                share_rnn_weights=config.model.share_rnn_weights)
            model = DeepSpeech2Model.from_config(config.model)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline(
                feat_size=config.model.feat_size,
                dict_size=config.model.dict_size,
                num_conv_layers=config.model.num_conv_layers,
                num_rnn_layers=config.model.num_rnn_layers,
                rnn_size=config.model.rnn_layer_size,
                rnn_direction=config.model.rnn_direction,
                num_fc_layers=config.model.num_fc_layers,
                fc_layers_size_list=config.model.fc_layers_size_list,
                use_gru=config.model.use_gru)
            model = DeepSpeech2ModelOnline.from_config(config.model)
        else:
            raise Exception("wrong model type")
        if self.parallel:
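
# Illustrative sketch (assumption, not part of the patch above): the hunk
# switches from explicit keyword-argument construction to a from_config()
# classmethod. A minimal, self-contained version of that pattern, using the
# field names visible in the constructor calls above, could look like this.
# The real implementations live on DeepSpeech2Model / DeepSpeech2ModelOnline
# in the deepspeech.models package and may differ in detail.
from yacs.config import CfgNode


class _DS2ModelSketch:
    """Stand-in used only to illustrate the from_config pattern."""

    def __init__(self, feat_size, dict_size, num_conv_layers, num_rnn_layers,
                 rnn_size, use_gru, share_rnn_weights):
        self.feat_size = feat_size
        self.dict_size = dict_size
        self.num_conv_layers = num_conv_layers
        self.num_rnn_layers = num_rnn_layers
        self.rnn_size = rnn_size
        self.use_gru = use_gru
        self.share_rnn_weights = share_rnn_weights

    @classmethod
    def from_config(cls, config: CfgNode):
        # setup_model() patches feat_size/dict_size into config.model before
        # calling this, so every hyperparameter is available on the node.
        return cls(feat_size=config.feat_size,
                   dict_size=config.dict_size,
                   num_conv_layers=config.num_conv_layers,
                   num_rnn_layers=config.num_rnn_layers,
                   rnn_size=config.rnn_layer_size,
                   use_gru=config.use_gru,
                   share_rnn_weights=config.share_rnn_weights)
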
@@ -194,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
        config.data.manifest = config.data.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        config.data.manifest = config.data.test_manifest
        test_dataset = ManifestDataset.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
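
# Illustrative sketch (assumption, not part of the patch above): the manifest
# switching in this hunk relies on what appear to be yacs CfgNode semantics --
# clone() to leave the original config untouched, defrost()/freeze() to toggle
# mutability, and plain attribute assignment in between each from_config()
# call. The schema below is a placeholder, not the project's real config.
from yacs.config import CfgNode

config = CfgNode()
config.data = CfgNode()
config.data.manifest = ""
config.data.dev_manifest = "data/manifest.dev"
config.data.test_manifest = "data/manifest.test"
config.freeze()

config = config.clone()  # work on a copy
config.defrost()         # allow in-place edits
config.data.manifest = config.data.dev_manifest   # build the dev dataset here
config.data.manifest = config.data.test_manifest  # then reuse the node for test
config.freeze()
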
@@ -217,19 +198,29 @@

        config.collator.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        collate_fn_test = SpeechCollator.from_config(config)

        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.collator.num_workers)
        print("feature_size", self.train_loader.collate_fn.feature_size)
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.collator.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)
        logger.info("Setup train/valid Dataloader!")
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_test)
        logger.info("Setup train/valid/test Dataloader!")


class DeepSpeech2Tester(DeepSpeech2Trainer):
@@ -371,20 +362,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):

        infer_model.eval()
        feat_dim = self.test_loader.collate_fn.feature_size
        if self.args.model_type == 'offline':
            static_model = paddle.jit.to_static(
                infer_model,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None, None, feat_dim],
                        dtype='float32'),  # audio, [B,T,D]
                    paddle.static.InputSpec(shape=[None],
                                            dtype='int64'),  # audio_length, [B]
                ])
        elif self.args.model_type == 'online':
            static_model = infer_model.export()
        else:
            raise Exception("wrong model type")
        logger.info(f"Export code: {static_model.forward.code}")
        paddle.jit.save(static_model, self.args.export_path)

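
# Illustrative sketch (assumption, not part of the patch above): once
# paddle.jit.save() has written the static graph, it can be reloaded and run
# without the Python model definition. The path, batch size, and feature
# dimension below are made-up placeholders; the offline model's call
# signature (audio [B, T, D], audio_length [B]) follows the InputSpec list
# in the hunk above.
import paddle

export_path = "exp/deepspeech2/checkpoints/avg_1.jit"  # placeholder prefix
feat_dim = 161                                         # placeholder feature size

static_model = paddle.jit.load(export_path)
audio = paddle.randn([1, 200, feat_dim], dtype='float32')  # [B, T, D]
audio_len = paddle.to_tensor([200], dtype='int64')         # [B]
outputs = static_model(audio, audio_len)
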
@@ -408,63 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
        self.iteration = 0
        self.epoch = 0

    '''
    def setup_model(self):
        config = self.config
        if self.args.model_type == 'offline':
            model = DeepSpeech2Model(
                feat_size=self.test_loader.collate_fn.feature_size,
                dict_size=self.test_loader.collate_fn.vocab_size,
                num_conv_layers=config.model.num_conv_layers,
                num_rnn_layers=config.model.num_rnn_layers,
                rnn_size=config.model.rnn_layer_size,
                use_gru=config.model.use_gru,
                share_rnn_weights=config.model.share_rnn_weights)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline(
                feat_size=self.test_loader.collate_fn.feature_size,
                dict_size=self.test_loader.collate_fn.vocab_size,
                num_conv_layers=config.model.num_conv_layers,
                num_rnn_layers=config.model.num_rnn_layers,
                rnn_size=config.model.rnn_layer_size,
                rnn_direction=config.model.rnn_direction,
                num_fc_layers=config.model.num_fc_layers,
                fc_layers_size_list=config.model.fc_layers_size_list,
                use_gru=config.model.use_gru)
        else:
            raise Exception("Wrong model type")

        self.model = model
        logger.info("Setup model!")
    '''

    def setup_dataloader(self):
        config = self.config.clone()
        config.defrost()
        # return raw text

        config.data.manifest = config.data.test_manifest
        # filter test examples, will cause less examples, but no mismatch with training
        # and can use large batch size , save training time, so filter test egs now.
        # config.data.min_input_len = 0.0  # second
        # config.data.max_input_len = float('inf')  # second
        # config.data.min_output_len = 0.0  # tokens
        # config.data.max_output_len = float('inf')  # tokens
        # config.data.min_output_input_ratio = 0.00
        # config.data.max_output_input_ratio = float('inf')
        test_dataset = ManifestDataset.from_config(config)

        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        # return text ord id
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=SpeechCollator.from_config(config))
        logger.info("Setup test Dataloader!")

    def setup_output_dir(self):
        """Create a directory used for output.
        """