pull/2324/head
tianhao zhang 3 years ago
parent 80fc0ef71a
commit c314f8b769

@@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
     paddle.Tensor.new_full = new_full
     paddle.static.Variable.new_full = new_full
 
+
 def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
     return xs
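
Reviewer note: `contiguous` is a deliberate no-op compatibility shim; Paddle dense tensors are always contiguous, so torch-style `.contiguous()` calls have nothing to do. A minimal sketch of the monkey-patching pattern this file uses, assuming the same hasattr guard as the `new_full` patch above:

    import paddle

    def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
        # Paddle dense tensors are already contiguous; return unchanged.
        return xs

    if not hasattr(paddle.Tensor, 'contiguous'):
        paddle.Tensor.contiguous = contiguous
        paddle.static.Variable.contiguous = contiguous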

@@ -26,8 +26,8 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
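
Reviewer note: every hunk in this commit is a style-only pass (import ordering plus line re-wrapping); no behavior change is intended. A hedged sketch reproducing the import ordering with isort's Python API; the repo's actual isort settings are an assumption, but `force_single_line=True` matches the one-import-per-line style seen here:

    import isort

    src = ("from paddlespeech.s2t.io.dataloader import BatchDataLoader\n"
           "from paddlespeech.s2t.io.dataloader import StreamDataLoader\n"
           "from paddlespeech.s2t.io.dataloader import DataLoaderFactory\n")
    # Sorts alphabetically while keeping one import per line.
    print(isort.code(src, force_single_line=True))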
@@ -109,7 +109,8 @@ class U2Trainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
-            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+            logger.info(
+                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
@@ -136,7 +137,8 @@ class U2Trainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                    msg += "batch: {}/{}, ".format(i + 1,
+                                                   len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -157,7 +159,8 @@ class U2Trainer(Trainer):
         self.before_train()
         if not self.use_streamdata:
-            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+            logger.info(
+                f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -225,14 +228,18 @@ class U2Trainer(Trainer):
         config = self.config.clone()
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
                 'decode_batch_size', 1)
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
-            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+                                                                self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader(
+                'align', config, self.args)
             logger.info("Setup test/align Dataloader!")
 
     def setup_model(self):
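
Reviewer note: the trainer hunks above only re-wrap existing `DataLoaderFactory.get_dataloader` calls. A hedged usage sketch of the factory contract visible in this diff (`config` and `args` are placeholders for the trainer's own objects):

    # The mode string picks the manifest and loader settings; the
    # dispatch is shown in the DataLoaderFactory hunks later in this diff.
    train_loader = DataLoaderFactory.get_dataloader('train', config, args)
    valid_loader = DataLoaderFactory.get_dataloader('valid', config, args)
    for batch in train_loader:
        pass  # batch layout depends on the underlying loader type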

@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
-            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+            logger.info(
+                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
@@ -133,7 +134,8 @@ class U2Trainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                    msg += "batch: {}/{}, ".format(i + 1,
+                                                   len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -153,7 +155,8 @@ class U2Trainer(Trainer):
         self.before_train()
         if not self.use_streamdata:
-            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+            logger.info(
+                f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -165,8 +168,8 @@ class U2Trainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
+                    msg += "batch : {}/{}, ".format(
+                        batch_index + 1, len(self.train_loader))
                 msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                 msg += "data time: {:>.3f}s, ".format(dataload_time)
                 self.train_batch(batch_index, batch, msg)
@@ -204,21 +207,24 @@ class U2Trainer(Trainer):
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
             config = self.config.clone()
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
             config = self.config.clone()
             config['preprocess_config'] = None
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             config = self.config.clone()
             config['preprocess_config'] = None
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+                                                                self.args)
             config = self.config.clone()
             config['preprocess_config'] = None
-            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader(
+                'align', config, self.args)
             logger.info("Setup test/align Dataloader!")
 
     def setup_model(self):
         config = self.config

@@ -121,7 +121,8 @@ class U2STTrainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
-            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+            logger.info(
+                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
@@ -155,7 +156,8 @@ class U2STTrainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                    msg += "batch: {}/{}, ".format(i + 1,
+                                                   len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -175,7 +177,8 @@ class U2STTrainer(Trainer):
         self.before_train()
         if not self.use_streamdata:
-            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+            logger.info(
+                f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -248,14 +251,16 @@ class U2STTrainer(Trainer):
         config['load_transcript'] = load_transcript
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+                                                                self.args)
             logger.info("Setup test Dataloader!")
 
     def setup_model(self):
         config = self.config
         model_conf = config

@@ -22,17 +22,16 @@ import paddle
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
-from yacs.config import CfgNode
-
-import paddlespeech.audio.streamdata as streamdata
-from paddlespeech.audio.text.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.io.batchfy import make_batchset
 from paddlespeech.s2t.io.converter import CustomConverter
 from paddlespeech.s2t.io.dataset import TransformDataset
 from paddlespeech.s2t.io.reader import LoadInputsAndTargets
 from paddlespeech.s2t.utils.log import Log
+import paddlespeech.audio.streamdata as streamdata
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
+from yacs.config import CfgNode
 
 __all__ = ["BatchDataLoader", "StreamDataLoader"]
 
 logger = Log(__name__).getlog()
@@ -61,6 +60,7 @@ def batch_collate(x):
     """
     return x[0]
 
+
 def read_preprocess_cfg(preprocess_conf_file):
     augment_conf = dict()
     preprocess_cfg = CfgNode(new_allowed=True)
@@ -84,6 +84,7 @@ def read_preprocess_cfg(preprocess_conf_file):
             augment_conf['t_replace_with_zero'] = process['replace_with_zero']
     return augment_conf
 
+
 class StreamDataLoader():
     def __init__(self,
                  manifest_file: str,
@@ -131,10 +132,14 @@ class StreamDataLoader():
             world_size = paddle.distributed.get_world_size()
         except Exception as e:
             logger.warninig(e)
-            logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
-        assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...")
-        update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
+            logger.warninig(
+                "can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
+            )
+        assert (len(shardlist) >= world_size,
+                "the length of shard list should >= number of gpus/xpus/...")
+        update_n_iter_processes = int(
+            max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
         logger.info(f"update_n_iter_processes {update_n_iter_processes}")
         if update_n_iter_processes != self.n_iter_processes:
             self.n_iter_processes = update_n_iter_processes
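
Reviewer note: the re-wrap preserves two pre-existing bugs verbatim. `logger.warninig` is a typo for `logger.warning` (it would raise AttributeError if the except branch ever ran), and `assert (cond, msg)` asserts a non-empty tuple, which is always truthy, so the shard-count check can never fire; yapf has merely re-wrapped the dead assert. A sketch of the presumably intended code:

    assert len(shardlist) >= world_size, \
        "the length of shard list should >= number of gpus/xpus/..."
    logger.warning(
        "can not get world_size using paddle.distributed.get_world_size(), "
        "use world_size=1")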
@@ -142,44 +147,50 @@ class StreamDataLoader():
         if self.dist_sampler:
             base_dataset = streamdata.DataPipeline(
-                streamdata.SimpleShardList(shardlist),
-                streamdata.split_by_node if train_mode else streamdata.placeholder(),
+                streamdata.SimpleShardList(shardlist), streamdata.split_by_node
+                if train_mode else streamdata.placeholder(),
                 streamdata.split_by_worker,
-                streamdata.tarfile_to_samples(streamdata.reraise_exception)
-            )
+                streamdata.tarfile_to_samples(streamdata.reraise_exception))
         else:
             base_dataset = streamdata.DataPipeline(
                 streamdata.SimpleShardList(shardlist),
                 streamdata.split_by_worker,
-                streamdata.tarfile_to_samples(streamdata.reraise_exception)
-            )
+                streamdata.tarfile_to_samples(streamdata.reraise_exception))
 
         self.dataset = base_dataset.append_list(
             streamdata.audio_tokenize(symbol_table),
-            streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
+            streamdata.audio_data_filter(
+                frame_shift=frame_shift,
+                max_length=maxlen_in,
+                min_length=minlen_in,
+                token_max_length=maxlen_out,
+                token_min_length=minlen_out),
             streamdata.audio_resample(resample_rate=resample_rate),
-            streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
-            streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
+            streamdata.audio_compute_fbank(
+                num_mel_bins=num_mel_bins,
+                frame_length=frame_length,
+                frame_shift=frame_shift,
+                dither=dither),
+            streamdata.audio_spec_aug(**augment_conf)
+            if train_mode else streamdata.placeholder(
+            ),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
             streamdata.shuffle(shuffle_size),
             streamdata.sort(sort_size=sort_size),
             streamdata.batched(batch_size),
             streamdata.audio_padding(),
-            streamdata.audio_cmvn(cmvn_file)
-        )
+            streamdata.audio_cmvn(cmvn_file))
 
         if paddle.__version__ >= '2.3.2':
             self.loader = streamdata.WebLoader(
                 self.dataset,
                 num_workers=self.n_iter_processes,
                 prefetch_factor=self.prefetch_factor,
-                batch_size=None
-            )
+                batch_size=None)
         else:
             self.loader = streamdata.WebLoader(
                 self.dataset,
                 num_workers=self.n_iter_processes,
-                batch_size=None
-            )
+                batch_size=None)
 
     def __iter__(self):
         return self.loader.__iter__()
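
Reviewer note: `paddle.__version__ >= '2.3.2'` in the hunk above (which gates whether `prefetch_factor` is passed to WebLoader) is a lexicographic string comparison, so for example '2.10.0' compares as smaller than '2.3.2'. A hedged sketch of a robust check, assuming the third-party packaging module is acceptable as a dependency:

    import paddle
    from packaging.version import Version

    # Version() parses release segments numerically, so 2.10.0 > 2.3.2.
    supports_prefetch = Version(paddle.__version__) >= Version("2.3.2")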
@@ -188,7 +199,9 @@ class StreamDataLoader():
         return self.__iter__()
 
     def __len__(self):
-        logger.info("Stream dataloader does not support calculate the length of the dataset")
+        logger.info(
+            "Stream dataloader does not support calculate the length of the dataset"
+        )
         return -1
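
Reviewer note: because `__len__` returns -1, `len(stream_loader)` is meaningless, which is why every trainer hunk earlier in this diff guards its length-based progress strings the same way:

    # pattern used by the trainers above
    if not self.use_streamdata:
        msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))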
@@ -358,7 +371,9 @@ class DataLoaderFactory():
                 config['maxlen_out'] = float('inf')
                 config['dist_sampler'] = False
             else:
-                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
+                raise KeyError(
+                    "not valid mode type!!, please input one of 'train, valid, test, align'"
+                )
             return StreamDataLoader(
                 manifest_file=config.manifest,
                 train_mode=config.train_mode,
@@ -380,8 +395,7 @@ class DataLoaderFactory():
                 prefetch_factor=config.prefetch_factor,
                 dist_sampler=config.dist_sampler,
                 cmvn_file=config.cmvn_file,
-                vocab_filepath=config.vocab_filepath,
-            )
+                vocab_filepath=config.vocab_filepath, )
         else:
             if mode == 'train':
                 config['manifest'] = config.train_manifest
@@ -427,7 +441,9 @@ class DataLoaderFactory():
                 config['dist_sampler'] = False
                 config['shortest_first'] = False
             else:
-                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
+                raise KeyError(
+                    "not valid mode type!!, please input one of 'train, valid, test, align'"
+                )
 
             return BatchDataLoader(
                 json_file=config.manifest,
@@ -450,4 +466,3 @@ class DataLoaderFactory():
                 num_encs=config.num_encs,
                 dist_sampler=config.dist_sampler,
                 shortest_first=config.shortest_first)
-
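
Hedged usage sketch of the dispatch reformatted above: the factory accepts exactly the modes 'train', 'valid', 'test' and 'align', raising the KeyError shown for anything else, and a config flag appears to choose between StreamDataLoader and BatchDataLoader (`config` and `args` are placeholders):

    for mode in ('train', 'valid', 'test', 'align'):
        loader = DataLoaderFactory.get_dataloader(mode, config, args)

    # any other mode string raises:
    # KeyError: "not valid mode type!!, please input one of 'train, valid, test, align'"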

@@ -26,6 +26,8 @@ import paddle
 from paddle import jit
 from paddle import nn
 
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
 from paddlespeech.s2t.frontend.utility import load_cmvn
 from paddlespeech.s2t.modules.cmvn import GlobalCMVN
@@ -38,8 +40,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.audio.utils.tensor_utils import add_sos_eos
-from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import UpdateConfig
 
 __all__ = ["U2STModel", "U2STInferModel"]
@@ -435,8 +435,8 @@ class U2STBaseModel(nn.Layer):
             paddle.Tensor: new conformer cnn cache required for next chunk, with
                 same shape as the original cnn_cache.
         """
-        return self.encoder.forward_chunk(
-            xs, offset, required_cache_size, att_cache, cnn_cache)
+        return self.encoder.forward_chunk(xs, offset, required_cache_size,
+                                          att_cache, cnn_cache)
 
     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
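
Reviewer note: the wrap above does not change `forward_chunk`'s contract: it takes a feature chunk, the current offset, a required-cache-size hint, and the attention/conv caches from the previous call, and (per the docstring fragment above) returns the chunk's encoder output together with updated caches. A hedged streaming sketch; the empty-cache initialization and the chunk iterable are assumptions, not the library's documented API:

    offset = 0
    att_cache = paddle.zeros([0, 0, 0, 0])  # assumed empty-cache convention
    cnn_cache = paddle.zeros([0, 0, 0, 0])
    for chunk in feature_chunks:  # hypothetical chunked feature stream
        ys, att_cache, cnn_cache = model.forward_chunk(
            chunk, offset, required_cache_size, att_cache, cnn_cache)
        offset += ys.shape[1]  # advance offset by the emitted frames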
