@@ -69,8 +69,8 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
-        batch_size = self.config.collator.batch_size
-        accum_grad = self.config.training.accum_grad
+        batch_size = self.config.batch_size
+        accum_grad = self.config.accum_grad
 
         start = time.time()
@@ -133,7 +133,7 @@ class DeepSpeech2Trainer(Trainer):
                 total_loss += float(loss) * num_utts
                 valid_losses['val_loss'].append(float(loss))
 
-            if (i + 1) % self.config.training.log_interval == 0:
+            if (i + 1) % self.config.log_interval == 0:
                 valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                 valid_dump['val_history_loss'] = total_loss / num_seen_utts
@@ -154,16 +154,16 @@ class DeepSpeech2Trainer(Trainer):
         config = self.config.clone()
         with UpdateConfig(config):
             if self.train:
-                config.model.input_dim = self.train_loader.collate_fn.feature_size
-                config.model.output_dim = self.train_loader.collate_fn.vocab_size
+                config.input_dim = self.train_loader.collate_fn.feature_size
+                config.output_dim = self.train_loader.collate_fn.vocab_size
             else:
-                config.model.input_dim = self.test_loader.collate_fn.feature_size
-                config.model.output_dim = self.test_loader.collate_fn.vocab_size
+                config.input_dim = self.test_loader.collate_fn.feature_size
+                config.output_dim = self.test_loader.collate_fn.vocab_size
 
         if self.args.model_type == 'offline':
-            model = DeepSpeech2Model.from_config(config.model)
+            model = DeepSpeech2Model.from_config(config)
         elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline.from_config(config.model)
+            model = DeepSpeech2ModelOnline.from_config(config)
         else:
             raise Exception("wrong model type")
         if self.parallel:
@@ -177,17 +177,13 @@ class DeepSpeech2Trainer(Trainer):
         if not self.train:
             return
 
-        grad_clip = ClipGradByGlobalNormWithLog(
-            config.training.global_grad_clip)
+        grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
         lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-            learning_rate=config.training.lr,
-            gamma=config.training.lr_decay,
-            verbose=True)
+            learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
         optimizer = paddle.optimizer.Adam(
             learning_rate=lr_scheduler,
             parameters=model.parameters(),
-            weight_decay=paddle.regularizer.L2Decay(
-                config.training.weight_decay),
+            weight_decay=paddle.regularizer.L2Decay(config.weight_decay),
             grad_clip=grad_clip)
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
@@ -198,66 +194,67 @@ class DeepSpeech2Trainer(Trainer):
         config.defrost()
         if self.train:
             # train
-            config.data.manifest = config.data.train_manifest
+            config.manifest = config.train_manifest
             train_dataset = ManifestDataset.from_config(config)
             if self.parallel:
                 batch_sampler = SortagradDistributedBatchSampler(
                     train_dataset,
-                    batch_size=config.collator.batch_size,
+                    batch_size=config.batch_size,
                     num_replicas=None,
                     rank=None,
                     shuffle=True,
                     drop_last=True,
-                    sortagrad=config.collator.sortagrad,
-                    shuffle_method=config.collator.shuffle_method)
+                    sortagrad=config.sortagrad,
+                    shuffle_method=config.shuffle_method)
             else:
                 batch_sampler = SortagradBatchSampler(
                     train_dataset,
                     shuffle=True,
-                    batch_size=config.collator.batch_size,
+                    batch_size=config.batch_size,
                     drop_last=True,
-                    sortagrad=config.collator.sortagrad,
-                    shuffle_method=config.collator.shuffle_method)
+                    sortagrad=config.sortagrad,
+                    shuffle_method=config.shuffle_method)
 
-            config.collator.keep_transcription_text = False
+            config.keep_transcription_text = False
             collate_fn_train = SpeechCollator.from_config(config)
             self.train_loader = DataLoader(
                 train_dataset,
                 batch_sampler=batch_sampler,
                 collate_fn=collate_fn_train,
-                num_workers=config.collator.num_workers)
+                num_workers=config.num_workers)
 
             # dev
-            config.data.manifest = config.data.dev_manifest
+            config.manifest = config.dev_manifest
             dev_dataset = ManifestDataset.from_config(config)
 
-            config.collator.augmentation_config = ""
-            config.collator.keep_transcription_text = False
+            config.augmentation_config = ""
+            config.keep_transcription_text = False
             collate_fn_dev = SpeechCollator.from_config(config)
             self.valid_loader = DataLoader(
                 dev_dataset,
-                batch_size=int(config.collator.batch_size),
+                batch_size=int(config.batch_size),
                 shuffle=False,
                 drop_last=False,
                 collate_fn=collate_fn_dev,
-                num_workers=config.collator.num_workers)
+                num_workers=config.num_workers)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test
-            config.data.manifest = config.data.test_manifest
+            config.manifest = config.test_manifest
             test_dataset = ManifestDataset.from_config(config)
 
-            config.collator.augmentation_config = ""
-            config.collator.keep_transcription_text = True
+            config.augmentation_config = ""
+            config.keep_transcription_text = True
             collate_fn_test = SpeechCollator.from_config(config)
+            decode_batch_size = config.get('decode', dict()).get(
+                'decode_batch_size', 1)
             self.test_loader = DataLoader(
                 test_dataset,
-                batch_size=config.decoding.batch_size,
+                batch_size=decode_batch_size,
                 shuffle=False,
                 drop_last=False,
                 collate_fn=collate_fn_test,
-                num_workers=config.collator.num_workers)
+                num_workers=config.num_workers)
             logger.info("Setup test Dataloader!")
@@ -286,7 +283,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
         self._text_featurizer = TextFeaturizer(
-            unit_type=config.collator.unit_type, vocab=None)
+            unit_type=config.unit_type, vocab=None)
 
     def ordid2token(self, texts, texts_len):
         """ ord() id to chr() chr """
@@ -304,17 +301,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
                         texts,
                         texts_len,
                         fout=None):
-        cfg = self.config.decoding
+        decode_cfg = self.config.decode
         errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
-        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
+        errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
 
         vocab_list = self.test_loader.collate_fn.vocab_list
 
         target_transcripts = self.ordid2token(texts, texts_len)
 
-        result_transcripts = self.compute_result_transcripts(audio, audio_len,
-                                                             vocab_list, cfg)
+        result_transcripts = self.compute_result_transcripts(
+            audio, audio_len, vocab_list, decode_cfg)
 
         for utt, target, result in zip(utts, target_transcripts,
                                        result_transcripts):
@@ -327,29 +324,31 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
-            logger.info("Current error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
+            logger.info(
+                "Current error rate [%s] = %f" %
+                (decode_cfg.error_rate_type, error_rate_func(target, result)))
 
         return dict(
             errors_sum=errors_sum,
             len_refs=len_refs,
             num_ins=num_ins,
             error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type)
+            error_rate_type=decode_cfg.error_rate_type)
 
-    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
+    def compute_result_transcripts(self, audio, audio_len, vocab_list,
+                                   decode_cfg):
         result_transcripts = self.model.decode(
             audio,
             audio_len,
             vocab_list,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch)
+            decoding_method=decode_cfg.decoding_method,
+            lang_model_path=decode_cfg.lang_model_path,
+            beam_alpha=decode_cfg.alpha,
+            beam_beta=decode_cfg.beta,
+            beam_size=decode_cfg.beam_size,
+            cutoff_prob=decode_cfg.cutoff_prob,
+            cutoff_top_n=decode_cfg.cutoff_top_n,
+            num_processes=decode_cfg.num_proc_bsearch)
 
         return result_transcripts
@@ -358,7 +357,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def test(self):
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
-        cfg = self.config
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         with jsonlines.open(self.args.result_file, 'w') as fout:
@@ -412,11 +410,10 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         if self.args.enable_auto_log is True:
             from paddlespeech.s2t.utils.log import Autolog
             self.autolog = Autolog(
-                batch_size=self.config.decoding.batch_size,
+                batch_size=self.config.decode.decode_batch_size,
                 model_name="deepspeech2",
                 model_precision="fp32").getlog()
         self.model.eval()
-        cfg = self.config
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         with jsonlines.open(self.args.result_file, 'w') as fout:
@@ -441,7 +438,8 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         if self.args.enable_auto_log is True:
             self.autolog.report()
 
-    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
+    def compute_result_transcripts(self, audio, audio_len, vocab_list,
+                                   decode_cfg):
         if self.args.model_type == "online":
             output_probs, output_lens = self.static_forward_online(audio,
                                                                    audio_len)
@@ -454,13 +452,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         self.predictor.clear_intermediate_tensor()
         self.predictor.try_shrink_memory()
 
-        self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
-                                       vocab_list, cfg.decoding_method)
+        self.model.decoder.init_decode(decode_cfg.alpha, decode_cfg.beta,
+                                       decode_cfg.lang_model_path, vocab_list,
+                                       decode_cfg.decoding_method)
 
         result_transcripts = self.model.decoder.decode_probs(
-            output_probs, output_lens, vocab_list, cfg.decoding_method,
-            cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
-            cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
+            output_probs, output_lens, vocab_list, decode_cfg.decoding_method,
+            decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
+            decode_cfg.beam_size, decode_cfg.cutoff_prob,
+            decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
         #replace the <space> with ' '
         result_transcripts = [
             self._text_featurizer.detokenize(sentence)
@@ -531,12 +531,10 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
         num_chunk = int(num_chunk)
 
         chunk_state_h_box = np.zeros(
-            (self.config.model.num_rnn_layers, 1,
-             self.config.model.rnn_layer_size),
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
             dtype=x.dtype)
         chunk_state_c_box = np.zeros(
-            (self.config.model.num_rnn_layers, 1,
-             self.config.model.rnn_layer_size),
+            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
             dtype=x.dtype)
 
         input_names = self.predictor.get_input_names()