@@ -77,7 +77,7 @@ class U2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
-        train_conf = self.config.training
+        train_conf = self.config
         start = time.time()
 
         # forward
@@ -120,7 +120,7 @@ class U2Trainer(Trainer):
 
         for k, v in losses_np.items():
             report(k, v)
-        report("batch_size", self.config.collator.batch_size)
+        report("batch_size", self.config.batch_size)
         report("accum", train_conf.accum_grad)
         report("step_cost", iteration_time)
 
@@ -153,7 +153,7 @@ class U2Trainer(Trainer):
                 if ctc_loss:
                     valid_losses['val_ctc_loss'].append(float(ctc_loss))
 
-            if (i + 1) % self.config.training.log_interval == 0:
+            if (i + 1) % self.config.log_interval == 0:
                 valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                 valid_dump['val_history_loss'] = total_loss / num_seen_utts
 
@@ -182,7 +182,7 @@ class U2Trainer(Trainer):
         self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch < self.config.training.n_epoch:
+        while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
                 try:
@@ -214,8 +214,7 @@ class U2Trainer(Trainer):
                                 k.split(',')) == 2 else ""
                             msg += ","
                         msg = msg[:-1]  # remove the last ","
-                        if (batch_index + 1
-                            ) % self.config.training.log_interval == 0:
+                        if (batch_index + 1) % self.config.log_interval == 0:
                             logger.info(msg)
                         data_start_time = time.time()
                 except Exception as e:
@@ -252,29 +251,29 @@ class U2Trainer(Trainer):
         if self.train:
             # train/valid dataset, return token ids
             self.train_loader = BatchDataLoader(
-                json_file=config.data.train_manifest,
+                json_file=config.train_manifest,
                 train_mode=True,
-                sortagrad=config.collator.sortagrad,
-                batch_size=config.collator.batch_size,
-                maxlen_in=config.collator.maxlen_in,
-                maxlen_out=config.collator.maxlen_out,
-                minibatches=config.collator.minibatches,
+                sortagrad=config.sortagrad,
+                batch_size=config.batch_size,
+                maxlen_in=config.maxlen_in,
+                maxlen_out=config.maxlen_out,
+                minibatches=config.minibatches,
                 mini_batch_size=self.args.ngpu,
-                batch_count=config.collator.batch_count,
-                batch_bins=config.collator.batch_bins,
-                batch_frames_in=config.collator.batch_frames_in,
-                batch_frames_out=config.collator.batch_frames_out,
-                batch_frames_inout=config.collator.batch_frames_inout,
-                preprocess_conf=config.collator.augmentation_config,
-                n_iter_processes=config.collator.num_workers,
+                batch_count=config.batch_count,
+                batch_bins=config.batch_bins,
+                batch_frames_in=config.batch_frames_in,
+                batch_frames_out=config.batch_frames_out,
+                batch_frames_inout=config.batch_frames_inout,
+                preprocess_conf=config.augmentation_config,
+                n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1)
 
             self.valid_loader = BatchDataLoader(
-                json_file=config.data.dev_manifest,
+                json_file=config.dev_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.collator.batch_size,
+                batch_size=config.batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -284,18 +283,18 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
-                n_iter_processes=config.collator.num_workers,
+                preprocess_conf=config.augmentation_config,
+                n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
             self.test_loader = BatchDataLoader(
-                json_file=config.data.test_manifest,
+                json_file=config.test_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.decoding.batch_size,
+                batch_size=config.decoding.decode_batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -305,16 +304,16 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
+                preprocess_conf=config.augmentation_config,
                 n_iter_processes=1,
                 subsampling_factor=1,
                 num_encs=1)
 
             self.align_loader = BatchDataLoader(
-                json_file=config.data.test_manifest,
+                json_file=config.test_manifest,
                 train_mode=False,
                 sortagrad=False,
-                batch_size=config.decoding.batch_size,
+                batch_size=config.decoding.decode_batch_size,
                 maxlen_in=float('inf'),
                 maxlen_out=float('inf'),
                 minibatches=0,
@@ -324,7 +323,7 @@ class U2Trainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.collator.augmentation_config,
+                preprocess_conf=config.augmentation_config,
                 n_iter_processes=1,
                 subsampling_factor=1,
                 num_encs=1)
@@ -332,7 +331,7 @@ class U2Trainer(Trainer):
 
     def setup_model(self):
         config = self.config
-        model_conf = config.model
+        model_conf = config
 
         with UpdateConfig(model_conf):
             if self.train:
@@ -355,7 +354,7 @@ class U2Trainer(Trainer):
         if not self.train:
             return
 
-        train_config = config.training
+        train_config = config
         optim_type = train_config.optim
         optim_conf = train_config.optim_conf
         scheduler_type = train_config.scheduler
@@ -375,7 +374,7 @@ class U2Trainer(Trainer):
                 config,
                 parameters,
                 lr_scheduler=None, ):
-            train_config = config.training
+            train_config = config
             optim_type = train_config.optim
             optim_conf = train_config.optim_conf
             scheduler_type = train_config.scheduler
@@ -415,7 +414,7 @@ class U2Tester(U2Trainer):
                 error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
                 num_proc_bsearch=8,  # # of CPUs for beam search.
                 beam_size=10,  # Beam search width.
-                batch_size=16,  # decoding batch size
+                decode_batch_size=16,  # decoding batch size
                 ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
                 decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
                 # <0: for decoding, use full chunk.
@@ -432,9 +431,9 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
         self.text_feature = TextFeaturizer(
-            unit_type=self.config.collator.unit_type,
-            vocab=self.config.collator.vocab_filepath,
-            spm_model_prefix=self.config.collator.spm_model_prefix)
+            unit_type=self.config.unit_type,
+            vocab=self.config.vocab_filepath,
+            spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
 
     def id2token(self, texts, texts_len, text_feature):
@@ -453,10 +452,10 @@ class U2Tester(U2Trainer):
                         texts,
                         texts_len,
                         fout=None):
-        cfg = self.config.decoding
+        decode_config = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
-        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
+        errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer
 
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
@@ -464,12 +463,12 @@ class U2Tester(U2Trainer):
             audio,
             audio_len,
             text_feature=self.text_feature,
-            decoding_method=cfg.decoding_method,
-            beam_size=cfg.beam_size,
-            ctc_weight=cfg.ctc_weight,
-            decoding_chunk_size=cfg.decoding_chunk_size,
-            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-            simulate_streaming=cfg.simulate_streaming)
+            decoding_method=decode_config.decoding_method,
+            beam_size=decode_config.beam_size,
+            ctc_weight=decode_config.ctc_weight,
+            decoding_chunk_size=decode_config.decoding_chunk_size,
+            num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
+            simulate_streaming=decode_config.simulate_streaming)
         decode_time = time.time() - start_time
 
         for utt, target, result, rec_tids in zip(
@@ -488,15 +487,15 @@ class U2Tester(U2Trainer):
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
-            logger.info("One example error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
+            logger.info("One example error rate [%s] = %f" % (
+                decode_config.error_rate_type, error_rate_func(target, result)))
 
         return dict(
             errors_sum=errors_sum,
             len_refs=len_refs,
             num_ins=num_ins,  # num examples
             error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type,
+            error_rate_type=decode_config.error_rate_type,
             num_frames=audio_len.sum().numpy().item(),
             decode_time=decode_time)
 
@@ -507,7 +506,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
 
-        stride_ms = self.config.collator.stride_ms
+        stride_ms = self.config.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
@@ -558,15 +557,15 @@ class U2Tester(U2Trainer):
                 "ref_len":
                 len_refs,
                 "decode_method":
-                self.config.decoding.decoding_method,
+                self.config.decoding_method,
             })
             f.write(data + '\n')
 
     @paddle.no_grad()
     def align(self):
         ctc_utils.ctc_align(self.config, self.model, self.align_loader,
-                            self.config.decoding.batch_size,
-                            self.config.collator.stride_ms, self.vocab_list,
+                            self.config.decoding.decode_batch_size,
+                            self.config.stride_ms, self.vocab_list,
                             self.args.result_file)
 
     def load_inferspec(self):
@@ -577,10 +576,10 @@ class U2Tester(U2Trainer):
             List[paddle.static.InputSpec]: input spec.
         """
         from paddlespeech.s2t.models.u2 import U2InferModel
-        infer_model = U2InferModel.from_pretrained(self.test_loader,
-                                                   self.config.model.clone(),
+        infer_model = U2InferModel.from_pretrained(self.train_loader,
+                                                   self.config.clone(),
                                                    self.args.checkpoint_path)
-        feat_dim = self.test_loader.feat_dim
+        feat_dim = self.train_loader.feat_dim
         input_spec = [
             paddle.static.InputSpec(shape=[1, None, feat_dim],
                                     dtype='float32'),  # audio, [B,T,D]
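Note on the change pattern (illustration only, not part of the patch): every hunk above applies the same mechanical substitution. Options that used to live under nested config sections (config.data.*, config.collator.*, config.training.*, and cfg = self.config.decoding for decoder settings) are now read from a flattened config, and the decoding batch_size key is renamed to decode_batch_size. A minimal sketch of what that means for attribute access, assuming yacs-style CfgNode config objects (an assumption; the diff itself does not show the config class), with made-up values:

    from yacs.config import CfgNode

    # Old layout: options grouped under nested sections.
    old_cfg = CfgNode(dict(collator=dict(batch_size=16),
                           training=dict(n_epoch=240, log_interval=100)))

    # New layout: the same options promoted to top-level keys; the decoding
    # batch size is exposed as decode_batch_size.
    new_cfg = CfgNode(dict(batch_size=16, n_epoch=240, log_interval=100,
                           decode_batch_size=16))

    # Call sites change from nested to flat lookups, values stay the same.
    assert old_cfg.collator.batch_size == new_cfg.batch_size
    assert old_cfg.training.n_epoch == new_cfg.n_epoch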