|
|
|
@ -27,7 +27,9 @@ from paddle import distributed as dist
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
|
|
|
|
|
from paddlespeech.s2t.io.collator import SpeechCollator
|
|
|
|
|
from paddlespeech.s2t.io.dataloader import BatchDataLoader
|
|
|
|
|
from paddlespeech.s2t.io.dataset import ManifestDataset
|
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
|
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
|
|
|
|
@ -247,92 +249,103 @@ class U2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
def setup_dataloader(self):
|
|
|
|
|
config = self.config.clone()
|
|
|
|
|
config.defrost()
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
|
|
|
|
|
|
if self.train:
|
|
|
|
|
# train/valid dataset, return token ids
|
|
|
|
|
config.data.manifest = config.data.train_manifest
|
|
|
|
|
train_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
config.data.manifest = config.data.dev_manifest
|
|
|
|
|
dev_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
collate_fn_train = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
collate_fn_dev = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
|
batch_sampler = SortagradDistributedBatchSampler(
|
|
|
|
|
train_dataset,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
num_replicas=None,
|
|
|
|
|
rank=None,
|
|
|
|
|
shuffle=True,
|
|
|
|
|
drop_last=True,
|
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
|
else:
|
|
|
|
|
batch_sampler = SortagradBatchSampler(
|
|
|
|
|
train_dataset,
|
|
|
|
|
shuffle=True,
|
|
|
|
|
self.train_loader = BatchDataLoader(
|
|
|
|
|
json_file=config.data.train_manifest,
|
|
|
|
|
train_mode=True,
|
|
|
|
|
sortagrad=False,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
drop_last=True,
|
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
|
train_dataset,
|
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
|
|
collate_fn=collate_fn_train,
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
self.valid_loader = DataLoader(
|
|
|
|
|
dev_dataset,
|
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
|
minibatches=0,
|
|
|
|
|
mini_batch_size=self.args.nprocs,
|
|
|
|
|
batch_count='auto',
|
|
|
|
|
batch_bins=0,
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.collator.
|
|
|
|
|
augmentation_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=config.collator.num_workers,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
num_encs=1)
|
|
|
|
|
|
|
|
|
|
self.valid_loader = BatchDataLoader(
|
|
|
|
|
json_file=config.data.dev_manifest,
|
|
|
|
|
train_mode=False,
|
|
|
|
|
sortagrad=False,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=collate_fn_dev,
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
|
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
|
minibatches=0,
|
|
|
|
|
mini_batch_size=self.args.nprocs,
|
|
|
|
|
batch_count='auto',
|
|
|
|
|
batch_bins=0,
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.collator.
|
|
|
|
|
augmentation_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=config.collator.num_workers,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
num_encs=1)
|
|
|
|
|
logger.info("Setup train/valid Dataloader!")
|
|
|
|
|
else:
|
|
|
|
|
# test dataset, return raw text
|
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
|
# filter test examples, will cause less examples, but no mismatch with training
|
|
|
|
|
# and can use large batch size , save training time, so filter test egs now.
|
|
|
|
|
config.data.min_input_len = 0.0 # second
|
|
|
|
|
config.data.max_input_len = float('inf') # second
|
|
|
|
|
config.data.min_output_len = 0.0 # tokens
|
|
|
|
|
config.data.max_output_len = float('inf') # tokens
|
|
|
|
|
config.data.min_output_input_ratio = 0.00
|
|
|
|
|
config.data.max_output_input_ratio = float('inf')
|
|
|
|
|
|
|
|
|
|
test_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
# return text ord id
|
|
|
|
|
config.collator.keep_transcription_text = True
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
self.test_loader = DataLoader(
|
|
|
|
|
test_dataset,
|
|
|
|
|
self.test_loader = BatchDataLoader(
|
|
|
|
|
json_file=config.data.test_manifest,
|
|
|
|
|
train_mode=False,
|
|
|
|
|
sortagrad=False,
|
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config),
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
# return text token id
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
|
self.align_loader = DataLoader(
|
|
|
|
|
test_dataset,
|
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
|
minibatches=0,
|
|
|
|
|
mini_batch_size=1,
|
|
|
|
|
batch_count='auto',
|
|
|
|
|
batch_bins=0,
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.collator.
|
|
|
|
|
augmentation_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=1,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
num_encs=1)
|
|
|
|
|
|
|
|
|
|
self.align_loader = BatchDataLoader(
|
|
|
|
|
json_file=config.data.test_manifest,
|
|
|
|
|
train_mode=False,
|
|
|
|
|
sortagrad=False,
|
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config),
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
logger.info("Setup train/valid/test/align Dataloader!")
|
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
|
minibatches=0,
|
|
|
|
|
mini_batch_size=1,
|
|
|
|
|
batch_count='auto',
|
|
|
|
|
batch_bins=0,
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.collator.
|
|
|
|
|
augmentation_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=1,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
num_encs=1)
|
|
|
|
|
logger.info("Setup test/align Dataloader!")
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
|
model_conf = config.model
|
|
|
|
|
|
|
|
|
|
with UpdateConfig(model_conf):
|
|
|
|
|
model_conf.input_dim = self.train_loader.collate_fn.feature_size
|
|
|
|
|
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
|
|
|
|
|
if self.train:
|
|
|
|
|
model_conf.input_dim = self.train_loader.feat_dim
|
|
|
|
|
model_conf.output_dim = self.train_loader.vocab_size
|
|
|
|
|
else:
|
|
|
|
|
model_conf.input_dim = self.test_loader.feat_dim
|
|
|
|
|
model_conf.output_dim = self.test_loader.vocab_size
|
|
|
|
|
|
|
|
|
|
model = U2Model.from_config(model_conf)
|
|
|
|
|
|
|
|
|
@ -341,6 +354,11 @@ class U2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
logger.info(f"{model}")
|
|
|
|
|
layer_tools.print_params(model, logger.info)
|
|
|
|
|
self.model = model
|
|
|
|
|
logger.info("Setup model!")
|
|
|
|
|
|
|
|
|
|
if not self.train:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
train_config = config.training
|
|
|
|
|
optim_type = train_config.optim
|
|
|
|
@ -381,10 +399,9 @@ class U2Trainer(Trainer):
|
|
|
|
|
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
|
|
|
|
|
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
|
|
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.lr_scheduler = lr_scheduler
|
|
|
|
|
logger.info("Setup model/optimizer/lr_scheduler!")
|
|
|
|
|
logger.info("Setup optimizer/lr_scheduler!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class U2Tester(U2Trainer):
|
|
|
|
@ -419,14 +436,19 @@ class U2Tester(U2Trainer):
|
|
|
|
|
|
|
|
|
|
def __init__(self, config, args):
|
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab_filepath=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
self.vocab_list = self.text_feature.vocab_list
|
|
|
|
|
|
|
|
|
|
def ordid2token(self, texts, texts_len):
|
|
|
|
|
def id2token(self, texts, texts_len, text_feature):
|
|
|
|
|
""" ord() id to chr() chr """
|
|
|
|
|
trans = []
|
|
|
|
|
for text, n in zip(texts, texts_len):
|
|
|
|
|
n = n.numpy().item()
|
|
|
|
|
ids = text[:n]
|
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
|
trans.append(text_feature.defeaturize(ids.numpy().tolist()))
|
|
|
|
|
return trans
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self,
|
|
|
|
@ -442,12 +464,11 @@ class U2Tester(U2Trainer):
|
|
|
|
|
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
|
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
text_feature = self.test_loader.collate_fn.text_feature
|
|
|
|
|
target_transcripts = self.ordid2token(texts, texts_len)
|
|
|
|
|
target_transcripts = self.id2token(texts, texts_len, self.text_feature)
|
|
|
|
|
result_transcripts, result_tokenids = self.model.decode(
|
|
|
|
|
audio,
|
|
|
|
|
audio_len,
|
|
|
|
|
text_feature=text_feature,
|
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
lang_model_path=cfg.lang_model_path,
|
|
|
|
|
beam_alpha=cfg.alpha,
|
|
|
|
@ -497,7 +518,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
self.model.eval()
|
|
|
|
|
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
stride_ms = self.test_loader.collate_fn.stride_ms
|
|
|
|
|
stride_ms = self.config.collator.stride_ms
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
num_frames = 0.0
|
|
|
|
@ -556,8 +577,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
def align(self):
|
|
|
|
|
ctc_utils.ctc_align(
|
|
|
|
|
self.model, self.align_loader, self.config.decoding.batch_size,
|
|
|
|
|
self.align_loader.collate_fn.stride_ms,
|
|
|
|
|
self.align_loader.collate_fn.vocab_list, self.args.result_file)
|
|
|
|
|
self.config.collator.stride_ms,
|
|
|
|
|
self.vocab_list, self.args.result_file)
|
|
|
|
|
|
|
|
|
|
def load_inferspec(self):
|
|
|
|
|
"""infer model and input spec.
|
|
|
|
|