[st] Distributed sampler and new dataloader with MIMO (#1239)

* update timit result, test=doc_fix

* result update

* fix bug

* add triplet loader

* empty preprocess file

* sync to u2, updating

* sync to u2 config

* fix bugs

* code refine

* update config

* customize decoding batch size

* update optimizer and lr scheduler

* minor

* minor

* minor

* fix bugs of refs

* minor

* distributed sampler

* minor

* refine the loader
Junkun Chen, committed via GitHub
parent fbe3c05137
commit 420709e5ce

@@ -1,6 +1,6 @@
 # https://yaml.org/type/float.html
 data:
-  train_manifest: data/manifest.train.tiny
+  train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
   min_input_len: 0.05 # second
@@ -15,8 +15,10 @@ collator:
   unit_type: 'spm'
   spm_model_prefix: data/lang_char/bpe_unigram_8000
   mean_std_filepath: ""
-  # augmentation_config: conf/augmentation.json
-  batch_size: 10
+  augmentation_config: conf/preprocess.yaml
+  batch_size: 16
+  maxlen_in: 5 # if input length > maxlen-in, batchsize is automatically reduced
+  maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
   raw_wav: True  # use raw_wav or kaldi feature
   spectrum_type: fbank #linear, mfcc, fbank
   feat_dim: 80
@@ -78,13 +80,13 @@ training:
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
-    lr: 0.004
+    lr: 2.5
     weight_decay: 1e-06
-  scheduler: warmuplr
+  scheduler: noam
   scheduler_conf:
     warmup_steps: 25000
     lr_decay: 1.0
-  log_interval: 5
+  log_interval: 50
   checkpoint:
     kbest_n: 50
     latest_n: 5
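The switch from `warmuplr` to `noam` also explains the seemingly huge `lr: 2.5`: under Noam decay the configured value is a dimensionless scale, not an absolute learning rate. A minimal sketch of the schedule (the `d_model=256` default here is an assumption standing in for `model.encoder_conf.output_size`):

```python
def noam_lr(step, scale=2.5, d_model=256, warmup_steps=25000):
    """Noam schedule: linear warmup, then 1/sqrt(step) decay."""
    step = max(step, 1)
    return scale * d_model ** -0.5 * min(step ** -0.5,
                                         step * warmup_steps ** -1.5)

# the effective LR peaks near step == warmup_steps
print(noam_lr(25000))  # ~9.9e-4 with these settings
```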
@@ -97,6 +99,7 @@ decoding:
   alpha: 2.5
   beta: 0.3
   beam_size: 10
+  word_reward: 0.7
   cutoff_prob: 1.0
   cutoff_top_n: 0
   num_proc_bsearch: 8
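`word_reward` counteracts beam search's bias toward short hypotheses by adding a fixed bonus per emitted token. Roughly (a sketch of the idea, not the exact scoring code in the decoder):

```python
def score_with_word_reward(log_prob, num_tokens, word_reward=0.7):
    # every extra token makes log_prob more negative, so longer
    # hypotheses receive a proportionally larger additive bonus
    return log_prob + word_reward * num_tokens
```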
@@ -107,3 +110,5 @@ decoding:
   # 0: used for training, it's prohibited here.
   num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
   simulate_streaming: False # simulate streaming inference. Defaults to False.

@@ -15,8 +15,10 @@ collator:
   unit_type: 'spm'
   spm_model_prefix: data/lang_char/bpe_unigram_8000
   mean_std_filepath: ""
-  # augmentation_config: conf/augmentation.json
-  batch_size: 10
+  augmentation_config: conf/preprocess.yaml
+  batch_size: 16
+  maxlen_in: 5 # if input length > maxlen-in, batchsize is automatically reduced
+  maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
   raw_wav: True  # use raw_wav or kaldi feature
   spectrum_type: fbank #linear, mfcc, fbank
   feat_dim: 80
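The new `maxlen_in`/`maxlen_out` knobs drive the ESPnet-style dynamic batching used by the new `BatchDataLoader` later in this diff: when the longest utterance in a bucket exceeds the caps, the bucket's batch size shrinks. Approximately (a sketch of the `'seq'` batch-count rule, not the exact `make_batchset` code):

```python
def reduced_batch_size(max_ilen, max_olen, batch_size=16,
                       maxlen_in=512, maxlen_out=150, min_batch_size=1):
    # the further a bucket's longest input/output exceeds the caps,
    # the smaller that bucket's batch becomes
    factor = max(int(max_ilen / maxlen_in), int(max_olen / maxlen_out))
    return max(min_batch_size, int(batch_size / (1 + factor)))
```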

@@ -13,14 +13,12 @@ ckpt_prefix=$2
 for type in fullsentence; do
     echo "decoding ${type}"
-    batch_size=32
     python3 -u ${BIN_DIR}/test.py \
     --ngpu ${ngpu} \
     --config ${config_path} \
     --result_file ${ckpt_prefix}.${type}.rsl \
     --checkpoint_path ${ckpt_prefix} \
     --opts decoding.decoding_method ${type} \
-    --opts decoding.batch_size ${batch_size}

     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"

@@ -12,5 +12,5 @@
 ## Transformer
 | Model | Params | Config | Val loss | Char-BLEU |
 | --- | --- | --- | --- | --- |
-| FAT + Transformer+ASR MTL | 50.26M | conf/transformer_mtl_noam.yaml | 62.86 | 19.45 |
+| FAT + Transformer+ASR MTL | 50.26M | conf/transformer_mtl_noam.yaml | 69.91 | 20.26 |
 | FAT + Transformer+ASR MTL with word reward | 50.26M | conf/transformer_mtl_noam.yaml | 62.86 | 20.80 |

@@ -1,39 +1,33 @@
 # https://yaml.org/type/float.html
 data:
-  train_manifest: data/manifest.train.tiny
+  train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  min_input_len: 5.0 # frame
-  max_input_len: 3000.0 # frame
-  min_output_len: 0.0 # tokens
-  max_output_len: 400.0 # tokens
-  min_output_input_ratio: 0.01
-  max_output_input_ratio: 20.0

 collator:
-  vocab_filepath: data/lang_char/vocab.txt
+  vocab_filepath: data/lang_char/ted_en_zh_bpe8000.txt
   unit_type: 'spm'
-  spm_model_prefix: data/lang_char/bpe_unigram_8000
+  spm_model_prefix: data/lang_char/ted_en_zh_bpe8000
   mean_std_filepath: ""
   # augmentation_config: conf/augmentation.json
-  batch_size: 10
-  raw_wav: True  # use raw_wav or kaldi feature
-  spectrum_type: fbank #linear, mfcc, fbank
+  batch_size: 20
   feat_dim: 83
-  delta_delta: False
-  dither: 1.0
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
   stride_ms: 10.0
   window_ms: 25.0
-  use_dB_normalization: True
-  target_dB: -20
-  random_seed: 0
-  keep_transcription_text: False
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 2
+  sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+  maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+  maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+  minibatches: 0 # for debug
+  batch_count: auto
+  batch_bins: 0
+  batch_frames_in: 0
+  batch_frames_out: 0
+  batch_frames_inout: 0
+  augmentation_config:
+  num_workers: 0
+  subsampling_factor: 1
+  num_encs: 1

 # network architecture
@@ -73,18 +67,18 @@ model:

 training:
-  n_epoch: 20
+  n_epoch: 40
   accum_grad: 2
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
-    lr: 0.004
-    weight_decay: 1e-06
-  scheduler: warmuplr
+    lr: 2.5
+    weight_decay: 0.
+  scheduler: noam
   scheduler_conf:
     warmup_steps: 25000
     lr_decay: 1.0
-  log_interval: 5
+  log_interval: 50
   checkpoint:
     kbest_n: 50
     latest_n: 5
@@ -107,4 +101,4 @@ decoding:
   # >0: for decoding, use fixed chunk size as set.
   # 0: used for training, it's prohibited here.
   num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
   simulate_streaming: False # simulate streaming inference. Defaults to False.

@@ -3,12 +3,6 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  min_input_len: 5.0 # frame
-  max_input_len: 3000.0 # frame
-  min_output_len: 0.0 # tokens
-  max_output_len: 400.0 # tokens
-  min_output_input_ratio: 0.01
-  max_output_input_ratio: 20.0

 collator:
   vocab_filepath: data/lang_char/ted_en_zh_bpe8000.txt
@@ -16,24 +10,24 @@ collator:
   spm_model_prefix: data/lang_char/ted_en_zh_bpe8000
   mean_std_filepath: ""
   # augmentation_config: conf/augmentation.json
-  batch_size: 10
-  raw_wav: True  # use raw_wav or kaldi feature
-  spectrum_type: fbank #linear, mfcc, fbank
+  batch_size: 20
   feat_dim: 83
-  delta_delta: False
-  dither: 1.0
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
   stride_ms: 10.0
   window_ms: 25.0
-  use_dB_normalization: True
-  target_dB: -20
-  random_seed: 0
-  keep_transcription_text: False
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 2
+  sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+  maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+  maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+  minibatches: 0 # for debug
+  batch_count: auto
+  batch_bins: 0
+  batch_frames_in: 0
+  batch_frames_out: 0
+  batch_frames_inout: 0
+  augmentation_config:
+  num_workers: 0
+  subsampling_factor: 1
+  num_encs: 1

 # network architecture
# network architecture # network architecture
@@ -73,18 +67,18 @@ model:

 training:
-  n_epoch: 20
+  n_epoch: 40
   accum_grad: 2
   global_grad_clip: 5.0
   optim: adam
   optim_conf:
     lr: 2.5
-    weight_decay: 1e-06
+    weight_decay: 0.
   scheduler: noam
   scheduler_conf:
     warmup_steps: 25000
     lr_decay: 1.0
-  log_interval: 5
+  log_interval: 50
   checkpoint:
     kbest_n: 50
     latest_n: 5

@@ -13,14 +13,12 @@ ckpt_prefix=$2
 for type in fullsentence; do
     echo "decoding ${type}"
-    batch_size=32
     python3 -u ${BIN_DIR}/test.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
    --result_file ${ckpt_prefix}.${type}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --opts decoding.decoding_method ${type} \
-    --opts decoding.batch_size ${batch_size}

     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"

@@ -16,6 +16,7 @@ import json
 import os
 import time
 from collections import defaultdict
+from collections import OrderedDict
 from contextlib import nullcontext
 from typing import Optional
@@ -23,21 +24,18 @@
 import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
-from paddle.io import DataLoader
 from yacs.config import CfgNode

-from paddlespeech.s2t.io.collator import SpeechCollator
-from paddlespeech.s2t.io.collator import TripletSpeechCollator
-from paddlespeech.s2t.io.dataset import ManifestDataset
-from paddlespeech.s2t.io.sampler import SortagradBatchSampler
-from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
+from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
+from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.models.u2_st import U2STModel
-from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
-from paddlespeech.s2t.training.scheduler import WarmupLR
+from paddlespeech.s2t.training.optimizer import OptimizerFactory
+from paddlespeech.s2t.training.reporter import ObsScope
+from paddlespeech.s2t.training.reporter import report
+from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
 from paddlespeech.s2t.training.timer import Timer
 from paddlespeech.s2t.training.trainer import Trainer
 from paddlespeech.s2t.utils import bleu_score
-from paddlespeech.s2t.utils import ctc_utils
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils import mp_tools
 from paddlespeech.s2t.utils.log import Log
@@ -96,6 +94,8 @@ class U2STTrainer(Trainer):
             # loss div by `batch_size * accum_grad`
             loss /= train_conf.accum_grad
             losses_np = {'loss': float(loss) * train_conf.accum_grad}
+            if st_loss:
+                losses_np['st_loss'] = float(st_loss)
             if attention_loss:
                 losses_np['att_loss'] = float(attention_loss)
             if ctc_loss:
@@ -125,6 +125,12 @@ class U2STTrainer(Trainer):
             iteration_time = time.time() - start

+            for k, v in losses_np.items():
+                report(k, v)
+            report("batch_size", self.config.collator.batch_size)
+            report("accum", train_conf.accum_grad)
+            report("step_cost", iteration_time)
+
             if (batch_index + 1) % train_conf.log_interval == 0:
                 msg += "train time: {:>.3f}s, ".format(iteration_time)
                 msg += "batch size: {}, ".format(self.config.collator.batch_size)
@@ -204,16 +210,34 @@ class U2STTrainer(Trainer):
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):
                     dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    self.after_train_batch()
-                    data_start_time = time.time()
+                    msg = "Train:"
+                    observation = OrderedDict()
+                    with ObsScope(observation):
+                        report("Rank", dist.get_rank())
+                        report("epoch", self.epoch)
+                        report('step', self.iteration)
+                        report("lr", self.lr_scheduler())
+                        self.train_batch(batch_index, batch, msg)
+                        self.after_train_batch()
+                        report('iter', batch_index + 1)
+                        report('total', len(self.train_loader))
+                        report('reader_cost', dataload_time)
+                        observation['batch_cost'] = observation[
+                            'reader_cost'] + observation['step_cost']
+                        observation['samples'] = observation['batch_size']
+                        observation['ips,sent./sec'] = observation[
+                            'batch_size'] / observation['batch_cost']
+                    for k, v in observation.items():
+                        msg += f" {k.split(',')[0]}: "
+                        msg += f"{v:>.8f}" if isinstance(v,
+                                                         float) else f"{v}"
+                        msg += f" {k.split(',')[1]}" if len(
+                            k.split(',')) == 2 else ""
+                        msg += ","
+                    msg = msg[:-1]  # remove the last ","
+                    if (batch_index + 1
+                        ) % self.config.training.log_interval == 0:
+                        logger.info(msg)
+                    data_start_time = time.time()
             except Exception as e:
                 logger.error(e)
                 raise e
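The refactored loop leans on a small observation/reporter utility. The following is not the actual `paddlespeech.s2t.training.reporter` implementation, just a sketch of the contract the code above assumes: `ObsScope` installs a dict as the current sink and `report()` writes into it:

```python
from collections import OrderedDict
from contextlib import contextmanager

_scopes = []

@contextmanager
def ObsScope(observation):
    # make `observation` the active sink for report() calls
    _scopes.append(observation)
    try:
        yield observation
    finally:
        _scopes.pop()

def report(key, value):
    if _scopes:
        _scopes[-1][key] = value

obs = OrderedDict()
with ObsScope(obs):
    report("epoch", 1)
    report("lr", 2.5)
print(obs)  # OrderedDict([('epoch', 1), ('lr', 2.5)])
```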
@@ -244,95 +268,87 @@ class U2STTrainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
-        config.defrost()
-        config.collator.keep_transcription_text = False
-
-        # train/valid dataset, return token ids
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        if config.model.model_conf.asr_weight > 0.:
-            Collator = TripletSpeechCollator
-            TestCollator = SpeechCollator
-        else:
-            TestCollator = Collator = SpeechCollator
-
-        collate_fn_train = Collator.from_config(config)
-        config.collator.augmentation_config = ""
-        collate_fn_dev = Collator.from_config(config)
+        load_transcript = True if config.model.model_conf.asr_weight > 0 else False

-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
-                train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-        else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers, )
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=config.collator.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers, )
-
-        # test dataset, return raw text
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size , save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-        test_dataset = ManifestDataset.from_config(config)
-        # return text ord id
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=TestCollator.from_config(config),
-            num_workers=config.collator.num_workers, )
-        # return text token id
-        config.collator.keep_transcription_text = False
-        self.align_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=TestCollator.from_config(config),
-            num_workers=config.collator.num_workers, )
-        logger.info("Setup train/valid/test/align Dataloader!")
+        if self.train:
+            # train/valid dataset, return token ids
+            self.train_loader = BatchDataLoader(
+                json_file=config.data.train_manifest,
+                train_mode=True,
+                sortagrad=False,
+                batch_size=config.collator.batch_size,
+                maxlen_in=config.collator.maxlen_in,
+                maxlen_out=config.collator.maxlen_out,
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.collator.
+                augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.collator.num_workers,
+                subsampling_factor=1,
+                load_aux_output=load_transcript,
+                num_encs=1)
+            self.valid_loader = BatchDataLoader(
+                json_file=config.data.dev_manifest,
+                train_mode=False,
+                sortagrad=False,
+                batch_size=config.collator.batch_size,
+                maxlen_in=float('inf'),
+                maxlen_out=float('inf'),
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.collator.
+                augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.collator.num_workers,
+                subsampling_factor=1,
+                load_aux_output=load_transcript,
+                num_encs=1)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            # test dataset, return raw text
+            self.test_loader = BatchDataLoader(
+                json_file=config.data.test_manifest,
+                train_mode=False,
+                sortagrad=False,
+                batch_size=config.decoding.batch_size,
+                maxlen_in=float('inf'),
+                maxlen_out=float('inf'),
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.collator.
+                augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.collator.num_workers,
+                subsampling_factor=1,
+                num_encs=1)
+
+            logger.info("Setup test Dataloader!")

     def setup_model(self):
         config = self.config
         model_conf = config.model

         with UpdateConfig(model_conf):
-            model_conf.input_dim = self.train_loader.collate_fn.feature_size
-            model_conf.output_dim = self.train_loader.collate_fn.vocab_size
+            if self.train:
+                model_conf.input_dim = self.train_loader.feat_dim
+                model_conf.output_dim = self.train_loader.vocab_size
+            else:
+                model_conf.input_dim = self.test_loader.feat_dim
+                model_conf.output_dim = self.test_loader.vocab_size

         model = U2STModel.from_config(model_conf)
@@ -348,35 +364,38 @@ class U2STTrainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
-        if scheduler_type == 'expdecaylr':
-            lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-                learning_rate=optim_conf.lr,
-                gamma=scheduler_conf.lr_decay,
-                verbose=False)
-        elif scheduler_type == 'warmuplr':
-            lr_scheduler = WarmupLR(
-                learning_rate=optim_conf.lr,
-                warmup_steps=scheduler_conf.warmup_steps,
-                verbose=False)
-        elif scheduler_type == 'noam':
-            lr_scheduler = paddle.optimizer.lr.NoamDecay(
-                learning_rate=optim_conf.lr,
-                d_model=model_conf.encoder_conf.output_size,
-                warmup_steps=scheduler_conf.warmup_steps,
-                verbose=False)
-        else:
-            raise ValueError(f"Not support scheduler: {scheduler_type}")
-
-        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
-        if optim_type == 'adam':
-            optimizer = paddle.optimizer.Adam(
-                learning_rate=lr_scheduler,
-                parameters=model.parameters(),
-                weight_decay=weight_decay,
-                grad_clip=grad_clip)
-        else:
-            raise ValueError(f"Not support optim: {optim_type}")
+        scheduler_args = {
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
+                                                    scheduler_args)
+
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            train_config = config.training
+            optim_type = train_config.optim
+            optim_conf = train_config.optim_conf
+            scheduler_type = train_config.scheduler
+            scheduler_conf = train_config.scheduler_conf
+            return {
+                "grad_clip": train_config.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler
+                if lr_scheduler else optim_conf.lr,
+                "parameters": parameters,
+                "epsilon": 1e-9 if optim_type == 'noam' else None,
+                "beta1": 0.9 if optim_type == 'noam' else None,
+                "beat2": 0.98 if optim_type == 'noam' else None,
+            }
+
+        optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)

         self.model = model
         self.optimizer = optimizer
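The Adam hyper-parameters injected for `noam` (beta1=0.9, beta2=0.98, epsilon=1e-9) are the ones from "Attention Is All You Need"; note that `"beat2"` is presumably a typo for `"beta2"`, in which case Adam would silently keep its default beta2 until the key is fixed. The factory itself is a straightforward registry; a minimal sketch (not the actual `OptimizerFactory`), assuming `None`-valued entries are dropped before construction:

```python
import paddle

_OPTIMIZERS = {"adam": paddle.optimizer.Adam}

def build_optimizer(name, args):
    # drop the None placeholders the args dict uses for non-noam optimizers
    args = {k: v for k, v in args.items() if v is not None}
    clip = args.pop("grad_clip", None)
    if clip is not None:
        args["grad_clip"] = paddle.nn.ClipGradByGlobalNorm(clip)
    return _OPTIMIZERS[name](**args)
```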
@@ -416,26 +435,30 @@ class U2STTester(U2STTrainer):
     def __init__(self, config, args):
         super().__init__(config, args)
+        self.text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
+        self.vocab_list = self.text_feature.vocab_list

-    def ordid2token(self, texts, texts_len):
+    def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
+            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
         return trans

     def translate(self, audio, audio_len):
         """"E2E translation from extracted audio feature"""
         cfg = self.config.decoding
-        text_feature = self.test_loader.collate_fn.text_feature
-
         self.model.eval()
         hyps = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=cfg.decoding_method,
             beam_size=cfg.beam_size,
             word_reward=cfg.word_reward,
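With the collators gone, id-to-text mapping now lives on the tester via `TextFeaturizer`. A usage sketch (the paths follow the ted_en_zh config above; the token ids are made up):

```python
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer

text_feature = TextFeaturizer(
    unit_type='spm',
    vocab_filepath='data/lang_char/ted_en_zh_bpe8000.txt',
    spm_model_prefix='data/lang_char/ted_en_zh_bpe8000')

ids = [13, 907, 24]                   # hypothetical token ids from the model
print(text_feature.defeaturize(ids))  # spm pieces -> detokenized text
```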
@@ -456,23 +479,20 @@ class U2STTester(U2STTrainer):
         len_refs, num_ins = 0, 0

         start_time = time.time()
-        text_feature = self.test_loader.collate_fn.text_feature

-        refs = [
-            "".join(chr(t) for t in text[:text_len])
-            for text, text_len in zip(texts, texts_len)
-        ]
+        refs = self.id2token(texts, texts_len, self.text_feature)

         hyps = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=cfg.decoding_method,
             beam_size=cfg.beam_size,
             word_reward=cfg.word_reward,
             decoding_chunk_size=cfg.decoding_chunk_size,
             num_decoding_left_chunks=cfg.num_decoding_left_chunks,
             simulate_streaming=cfg.simulate_streaming)

         decode_time = time.time() - start_time

         for utt, target, result in zip(utts, refs, hyps):
@@ -505,7 +525,7 @@ class U2STTester(U2STTrainer):
         cfg = self.config.decoding
         bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu

-        stride_ms = self.test_loader.collate_fn.stride_ms
+        stride_ms = self.config.collator.stride_ms
         hyps, refs = [], []
         len_refs, num_ins = 0, 0
         num_frames = 0.0
@@ -522,7 +542,7 @@ class U2STTester(U2STTrainer):
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
                 rtf = num_time / (num_frames * stride_ms)
-                logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu))
+                logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu))

         rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
@@ -553,13 +573,6 @@ class U2STTester(U2STTrainer):
                 })
                 f.write(data + '\n')

-    @paddle.no_grad()
-    def align(self):
-        ctc_utils.ctc_align(self.config, self.model, self.align_loader,
-                            self.config.decoding.batch_size,
-                            self.config.collator.stride_ms, self.vocab_list,
-                            self.args.result_file)
-
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -567,11 +580,11 @@ class U2STTester(U2STTrainer):
nn.Layer: inference model nn.Layer: inference model
List[paddle.static.InputSpec]: input spec. List[paddle.static.InputSpec]: input spec.
""" """
from paddlespeech.s2t.models.u2 import U2InferModel from paddlespeech.s2t.models.u2_st import U2STInferModel
infer_model = U2InferModel.from_pretrained(self.test_loader, infer_model = U2STInferModel.from_pretrained(self.test_loader,
self.config.model.clone(), self.config.model.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.feat_dim
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[1, None, feat_dim], paddle.static.InputSpec(shape=[1, None, feat_dim],
dtype='float32'), # audio, [B,T,D] dtype='float32'), # audio, [B,T,D]

@@ -31,11 +31,17 @@ class CustomConverter():
     """

-    def __init__(self, subsampling_factor=1, dtype=np.float32):
+    def __init__(self,
+                 subsampling_factor=1,
+                 dtype=np.float32,
+                 load_aux_input=False,
+                 load_aux_output=False):
         """Construct a CustomConverter object."""
         self.subsampling_factor = subsampling_factor
         self.ignore_id = -1
         self.dtype = dtype
+        self.load_aux_input = load_aux_input
+        self.load_aux_output = load_aux_output

     def __call__(self, batch):
         """Transform a batch and send it to a device.
@@ -49,34 +55,48 @@
         """
         # batch should be located in list
         assert len(batch) == 1
-        (xs, ys), utts = batch[0]
-        assert xs[0] is not None, "please check Reader and Augmentation impl."
-
-        # perform subsampling
-        if self.subsampling_factor > 1:
-            xs = [x[::self.subsampling_factor, :] for x in xs]
-
-        # get batch of lengths of input sequences
-        ilens = np.array([x.shape[0] for x in xs])
-
-        # perform padding and convert to tensor
-        # currently only support real number
-        if xs[0].dtype.kind == "c":
-            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
-            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
-            # Note(kamo):
-            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
-            # Don't create ComplexTensor and give it E2E here
-            # because torch.nn.DataParellel can't handle it.
-            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
-        else:
-            xs_pad = pad_list(xs, 0).astype(self.dtype)
+        data, utts = batch[0]
+        xs_data, ys_data = [], []
+        for ud in data:
+            if ud[0].ndim > 1:
+                # speech data (input): (speech_len, feat_dim)
+                xs_data.append(ud)
+            else:
+                # text data (output): (text_len, )
+                ys_data.append(ud)
+
+        assert xs_data[0][0] is not None, "please check Reader and Augmentation impl."
+
+        xs_pad, ilens = [], []
+        for xs in xs_data:
+            # perform subsampling
+            if self.subsampling_factor > 1:
+                xs = [x[::self.subsampling_factor, :] for x in xs]
+
+            # get batch of lengths of input sequences
+            ilens.append(np.array([x.shape[0] for x in xs]))
+
+            # perform padding and convert to tensor
+            # currently only support real number
+            xs_pad.append(pad_list(xs, 0).astype(self.dtype))
+
+            if not self.load_aux_input:
+                xs_pad, ilens = xs_pad[0], ilens[0]
+                break

         # NOTE: this is for multi-output (e.g., speech translation)
-        ys_pad = pad_list(
-            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
-            self.ignore_id)
-
-        olens = np.array(
-            [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])
+        ys_pad, olens = [], []
+        for ys in ys_data:
+            ys_pad.append(pad_list(
+                [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
+                self.ignore_id))
+            olens.append(np.array(
+                [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]))
+            if not self.load_aux_output:
+                ys_pad, olens = ys_pad[0], olens[0]
+                break

         return utts, xs_pad, ilens, ys_pad, olens
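Both branches lean on `pad_list`; its behavior is roughly the following sketch (right-pad a list of variable-length numpy arrays to the longest one with a fill value):

```python
import numpy as np

def pad_list_sketch(xs, pad_value):
    # xs: list of arrays shaped (len_i, ...) with a shared trailing shape
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
    pad = np.full((n_batch, max_len) + xs[0].shape[1:], pad_value,
                  dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        pad[i, :x.shape[0]] = x
    return pad
```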

@@ -19,6 +19,7 @@ from typing import Text
 import jsonlines
 import numpy as np
 from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler

 from paddlespeech.s2t.io.batchfy import make_batchset
 from paddlespeech.s2t.io.converter import CustomConverter
@@ -73,6 +74,8 @@ class BatchDataLoader():
                  preprocess_conf=None,
                  n_iter_processes: int=1,
                  subsampling_factor: int=1,
+                 load_aux_input: bool=False,
+                 load_aux_output: bool=False,
                  num_encs: int=1):
         self.json_file = json_file
         self.train_mode = train_mode
@@ -89,6 +92,8 @@ class BatchDataLoader():
         self.num_encs = num_encs
         self.preprocess_conf = preprocess_conf
         self.n_iter_processes = n_iter_processes
+        self.load_aux_input = load_aux_input
+        self.load_aux_output = load_aux_output

         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -126,21 +131,29 @@ class BatchDataLoader():
         # Setup a converter
         if num_encs == 1:
             self.converter = CustomConverter(
-                subsampling_factor=subsampling_factor, dtype=np.float32)
+                subsampling_factor=subsampling_factor,
+                dtype=np.float32,
+                load_aux_input=load_aux_input,
+                load_aux_output=load_aux_output)
         else:
             assert NotImplementedError("not impl CustomConverterMulEnc.")

         # hack to make batchsize argument as 1
         # actual bathsize is included in a list
-        # default collate function converts numpy array to pytorch tensor
+        # default collate function converts numpy array to paddle tensor
         # we used an empty collate function instead which returns list
         self.dataset = TransformDataset(self.minibaches, self.converter,
                                         self.reader)
-        self.dataloader = DataLoader(
+
+        self.sampler = DistributedBatchSampler(
             dataset=self.dataset,
             batch_size=1,
             shuffle=not self.use_sortagrad if self.train_mode else False,
+        )
+
+        self.dataloader = DataLoader(
+            dataset=self.dataset,
+            batch_sampler=self.sampler,
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )
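Because each dataset item is already a fully formed minibatch (hence `batch_size=1`), the `DistributedBatchSampler`'s only job is to hand each rank a disjoint, equal-sized subset of those minibatches, padding with reused indices so all ranks step the same number of times. A strided sketch of that property (paddle's actual slicing strategy may differ):

```python
def shard(num_items, num_replicas, rank):
    per_rank = -(-num_items // num_replicas)          # ceil division
    idx = list(range(num_items))
    idx += idx[:per_rank * num_replicas - num_items]  # pad by reuse
    return idx[rank::num_replicas]

print(shard(10, 4, 0))  # [0, 4, 8] -- 3 minibatches on every rank
```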

@@ -68,7 +68,7 @@ class LoadInputsAndTargets():
         if mode not in ["asr"]:
             raise ValueError("Only asr are allowed: mode={}".format(mode))

-        if preprocess_conf is not None:
+        if preprocess_conf:
             self.preprocessing = Transformation(preprocess_conf)
             logger.warning(
                 "[Experimental feature] Some preprocessing will be done "
@@ -82,12 +82,11 @@ class LoadInputsAndTargets():
         self.load_output = load_output
         self.load_input = load_input
         self.sort_in_input_length = sort_in_input_length
-        if preprocess_args is None:
-            self.preprocess_args = {}
-        else:
+        if preprocess_args:
             assert isinstance(preprocess_args, dict), type(preprocess_args)
             self.preprocess_args = dict(preprocess_args)
+        else:
+            self.preprocess_args = {}

         self.keep_all_data_on_mem = keep_all_data_on_mem

     def __call__(self, batch, return_uttid=False):
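The change from an is-not-None test to plain truthiness matters because the trainer now passes `config.collator.augmentation_config` straight through, and that value may be an empty string. A quick illustration:

```python
for conf in (None, "", "conf/preprocess.yaml"):
    old_branch = conf is not None  # "" previously built Transformation("")
    new_branch = bool(conf)        # "" now means: skip preprocessing
    print(repr(conf), old_branch, new_branch)
```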
