# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Optional

import numpy as np
from yacs.config import CfgNode

from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
from paddlespeech.s2t.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import TarLocalData
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.io.utility import pad_list
from paddlespeech.s2t.utils.log import Log

__all__ = ["SpeechCollator", "TripletSpeechCollator"]

logger = Log(__name__).getlog()


def _tokenids(text, keep_transcription_text):
    """Convert a transcript into a 1-D int64 array of token ids.

    In training mode (``keep_transcription_text`` is False), ``text`` is
    already a sequence of token ids and is only converted to an array.
    Otherwise ``text`` is a raw string and every character is mapped to its
    Unicode code point with ``ord``, e.g.
    ``_tokenids("ab", True).tolist() == [97, 98]``.
    """
    # for training, text is already token ids
    tokens = text  # token ids

    if keep_transcription_text:
        # text is a raw string: convert each character to its Unicode code point
        assert isinstance(text, str), (type(text), text)
        tokens = [ord(t) for t in text]

    tokens = np.array(tokens, dtype=np.int64)
    return tokens


class SpeechCollatorBase():
    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """Base class of the speech collators.

        Args:
            unit_type (str): token unit type, e.g. char, word, spm.
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): path of the mean and std (CMVN) file, with suffix *.npy.
                If empty, no feature normalization is applied.
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            aug_file (file): opened augmentation (preprocess) config file; its content
                is read and passed to ``AugmentationPipeline``.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): number of fft points for rfft. Defaults to None.
            max_freq (int, optional): max cutoff frequency; None means samplerate/2. Defaults to None.
            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
            spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to 0.
            delta_delta (bool, optional): whether to add delta-delta features, used by 'fbank' or 'mfcc'. Defaults to False.
            use_dB_normalization (bool, optional): whether to apply dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            dither (float, optional): dither value passed to the speech featurizer. Defaults to 1.0.
            random_seed (int, optional): seed for the random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (i.e. not in training mode), the
                transcript is not tokenized and is kept as a raw string; if False, it is
                converted to token ids. Defaults to True.

        The collator applies augmentations and pads audio features with zeros so
        that they have the same shape (or a user-defined shape) within one batch.
        """
        self.keep_transcription_text = keep_transcription_text
        self.train_mode = not keep_transcription_text

        self.stride_ms = stride_ms
        self.window_ms = window_ms
        self.feat_dim = feat_dim

        self.loader = LoadInputsAndTargets()

        # only for tar filetype
        self._local_data = TarLocalData(tar2info={}, tar2object={})

        self.augmentation = AugmentationPipeline(
            preprocess_conf=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            spectrum_type=spectrum_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

        self.feature_size = self._speech_featurizer.audio_feature.feature_size
        self.text_feature = self._speech_featurizer.text_feature
        self.vocab_dict = self.text_feature.vocab_dict
        self.vocab_list = self.text_feature.vocab_list
        self.vocab_size = self.text_feature.vocab_size

    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        filetype = self.loader.file_type(audio_file)

        if filetype != 'sound':
            # non-audio input: load the stored features and check their dimension
            spectrum = self.loader._get_from_loader(audio_file, filetype)
            feat_dim = spectrum.shape[1]
            assert feat_dim == self.feat_dim, f"expect feat dim {self.feat_dim}, but got {feat_dim}"

            if self.keep_transcription_text:
                transcript_part = transcript
            else:
                text_ids = self.text_feature.featurize(transcript)
                transcript_part = text_ids
        else:
            # read audio
            speech_segment = SpeechSegment.from_file(
                audio_file, transcript, infos=self._local_data)
            # audio augment
            self.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self._speech_featurizer.featurize(
                speech_segment, self.keep_transcription_text)
            # CMVN spectrum
            if self._normalizer:
                spectrum = self._normalizer.apply(spectrum)

        # spectrum augment
        spectrum = self.augmentation.transform_feature(spectrum)
        return spectrum, transcript_part

    def __call__(self, batch):
        """Collate a batch of examples.

        Args:
            batch (List[Dict]): batch is [dict(audio, text, ...)]
                audio (np.ndarray): shape (T, D)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(utts, xs_pad, ilens, ys_pad, olens): batched data.
                utts: (B,)
                xs_pad: (B, Tmax, D)
                ilens: (B,)
                ys_pad: (B, Umax)
                olens: (B,)
        """
        audios = []
        audio_lens = []
* add vimrc
* refactor tiny script, add transformer and stream conf
* spm demo; librisppech scripts and confs
* fix log
* add librispeech scripts
* refactor data pipe; fix conf; fix u2 default params
* fix bugs
* refactor aishell scripts
* fix test
* fix cmvn
* fix s0 scripts
* fix ds2 scripts and bugs
* fix dev & test dataset filter
* fix dataset filter
* filter dev
* fix ckpt path
* filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test
* add comment
* add syllable doc
* fix ds2 configs
* add doc
* add pypinyin tools
* fix decoder using blank_id=0
* mmseg with pybind11
* format code
4 years ago
|
|
|
texts = []
|
|
|
|
text_lens = []
|
|
|
|
utts = []
|
|
|
|
tids = [] # tokenids
|
|
|
|
|
|
|
|
for idx, item in enumerate(batch):
|
|
|
|
utts.append(item['utt'])
|
|
|
|
|
|
|
|
audio = item['input'][0]['feat']
|
|
|
|
text = item['output'][0]['text']
|
|
|
|
audio, text = self.process_utterance(audio, text)
|
|
|
|
|
|
|
|
audios.append(audio) # [T, D]
|
|
|
|
audio_lens.append(audio.shape[0])
|
|
|
|
|
|
|
|
tokens = _tokenids(text, self.keep_transcription_text)
|
E2E/Streaming Transformer/Conformer ASR (#578)
* add cmvn and label smoothing loss layer
* add layer for transformer
* add glu and conformer conv
* add torch compatiable hack, mask funcs
* not hack size since it exists
* add test; attention
* add attention, common utils, hack paddle
* add audio utils
* conformer batch padding mask bug fix #223
* fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2
* fix ci
* fix ci
* add encoder
* refactor egs
* add decoder
* refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils
* refactor docs
* add fix
* fix readme
* fix bugs, refactor collator, add pad_sequence, fix ckpt bugs
* fix docstring
* refactor data feed order
* add u2 model
* refactor cmvn, test
* add utils
* add u2 config
* fix bugs
* fix bugs
* fix autograd maybe has problem when using inplace operation
* refactor data, build vocab; add format data
* fix text featurizer
* refactor build vocab
* add fbank, refactor feature of speech
* refactor audio feat
* refactor data preprare
* refactor data
* model init from config
* add u2 bins
* flake8
* can train
* fix bugs, add coverage, add scripts
* test can run
* fix data
* speed perturb with sox
* add spec aug
* fix for train
* fix train logitc
* fix logger
* log valid loss, time dataset process
* using np for speed perturb, remove some debug log of grad clip
* fix logger
* fix build vocab
* fix logger name
* using module logger as default
* fix
* fix install
* reorder imports
* fix board logger
* fix logger
* kaldi fbank and mfcc
* fix cmvn and print prarams
* fix add_eos_sos and cmvn
* fix cmvn compute
* fix logger and cmvn
* fix subsampling, label smoothing loss, remove useless
* add notebook test
* fix log
* fix tb logger
* multi gpu valid
* fix log
* fix log
* fix config
* fix compute cmvn, need paddle 2.1
* add cmvn notebook
* fix layer tools
* fix compute cmvn
* add rtf
* fix decoding
* fix layer tools
* fix log, add avg script
* more avg and test info
* fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh;
* add vimrc
* refactor tiny script, add transformer and stream conf
* spm demo; librisppech scripts and confs
* fix log
* add librispeech scripts
* refactor data pipe; fix conf; fix u2 default params
* fix bugs
* refactor aishell scripts
* fix test
* fix cmvn
* fix s0 scripts
* fix ds2 scripts and bugs
* fix dev & test dataset filter
* fix dataset filter
* filter dev
* fix ckpt path
* filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test
* add comment
* add syllable doc
* fix ds2 configs
* add doc
* add pypinyin tools
* fix decoder using blank_id=0
* mmseg with pybind11
* format code
4 years ago
|
|
|
texts.append(tokens)
|
|
|
|
text_lens.append(tokens.shape[0])
|
Support paddle 2.x (#538)
* 2.x model
* model test pass
* fix data
* fix soundfile with flac support
* one thread dataloader test pass
* export feasture size
add trainer and utils
add setup model and dataloader
update travis using Bionic dist
* add venv; test under venv
* fix unittest; train and valid
* add train and config
* add config and train script
* fix ctc cuda memcopy error
* fix imports
* fix train valid log
* fix dataset batch shuffle shift start from 1
fix rank_zero_only decreator error
close tensorboard when train over
add decoding config and code
* test process can run
* test with decoding
* test and infer with decoding
* fix infer
* fix ctc loss
lr schedule
sortagrad
logger
* aishell egs
* refactor train
add aishell egs
* fix dataset batch shuffle and add batch sampler log
print model parameter
* fix model and ctc
* sequence_mask make all inputs zeros, which cause grad be zero, this is a bug of LessThanOp
add grad clip by global norm
add model train test notebook
* ctc loss
remove run prefix
using ord value as text id
* using unk when training
compute_loss need text ids
ord id using in test mode, which compute wer/cer
* fix tester
* add lr_deacy
refactor code
* fix tools
* fix ci
add tune
fix gru model bugs
add dataset and model test
* fix decoding
* refactor repo
fix decoding
* fix musan and rir dataset
* refactor io, loss, conv, rnn, gradclip, model, utils
* fix ci and import
* refactor model
add export jit model
* add deploy bin and test it
* rm uselss egs
* add layer tools
* refactor socket server
new model from pretrain
* remve useless
* fix instability loss and grad nan or inf for librispeech training
* fix sampler
* fix libri train.sh
* fix doc
* add license on cpp
* fix doc
* fix libri script
* fix install
* clip 5 wer 7.39, clip 400 wer 7.54, 1.8 clip 400 baseline 7.49
4 years ago
|
|
|
|
|
|
|
#[B, T, D]
|
|
|
|
xs_pad = pad_list(audios, 0.0).astype(np.float32)
|
|
|
|
ilens = np.array(audio_lens).astype(np.int64)
|
|
|
|
ys_pad = pad_list(texts, IGNORE_ID).astype(np.int64)
|
|
|
|
olens = np.array(text_lens).astype(np.int64)
|
|
|
|
return utts, xs_pad, ilens, ys_pad, olens
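

# Illustrative sketch (hypothetical helper, not part of this module's API):
# a pure-Python stand-in for the padding step in SpeechCollatorBase.__call__.
# It assumes that pad_list() right-pads every sequence to the longest one in
# the batch with the given value, and that IGNORE_ID is -1; neither is
# defined in this section. The true lengths are kept separately (ilens/olens)
# so downstream code can presumably mask the padded positions.
def _sketch_pad_batch(seqs, pad_value):
    """Right-pad lists of ints to a common length; return (padded, lens)."""
    lens = [len(s) for s in seqs]
    max_len = max(lens) if lens else 0
    # Each row is extended with pad_value up to the batch maximum.
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lens


# Usage, with -1 standing in for the assumed IGNORE_ID:
#   _sketch_pad_batch([[5, 6, 7], [8]], -1)
#   -> ([[5, 6, 7], [8, -1, -1]], [3, 1])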


class SpeechCollator(SpeechCollatorBase):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): config object.

        Returns:
            SpeechCollator: collator object.
        """
        assert config.collator
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'spectrum_type' in config.collator
        assert 'n_fft' in config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            mean_std_filepath=config.collator.mean_std_filepath,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            spectrum_type=config.collator.spectrum_type,
            feat_dim=config.collator.feat_dim,
            delta_delta=config.collator.delta_delta,
            stride_ms=config.collator.stride_ms,
            window_ms=config.collator.window_ms,
            n_fft=config.collator.n_fft,
            max_freq=config.collator.max_freq,
            target_sample_rate=config.collator.target_sample_rate,
            use_dB_normalization=config.collator.use_dB_normalization,
            target_dB=config.collator.target_dB,
            dither=config.collator.dither,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator
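

# Illustrative sketch (hypothetical helper, not part of this module's API):
# the augmentation_config resolution done in SpeechCollator.from_config,
# isolated so it can be exercised without a full yacs config. A non-empty
# str/bytes value is treated as a path and opened as UTF-8 text. An empty
# string falls back to an in-memory '{}' (presumably "no augmentation").
# Any other value must already be an io.StringIO.
def _sketch_resolve_aug_file(augmentation_config):
    import io

    if isinstance(augmentation_config, (str, bytes)):
        if augmentation_config:
            # Path to an augmentation config file; the caller owns the handle.
            return io.open(augmentation_config, mode='r', encoding='utf8')
        return io.StringIO(initial_value='{}', newline='')
    # Non-string values are expected to be in-memory text buffers.
    assert isinstance(augmentation_config, io.StringIO)
    return augmentation_config


# Usage:
#   _sketch_resolve_aug_file("").read()  -> '{}'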


class TripletSpeechCollator(SpeechCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :param transcript: transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor, data of the translation part
                 and data of the transcription part, where each text part
                 could be token ids or text.
        :rtype: tuple of (2darray, list, list)
        """
        spectrum, translation_part = super().process_utterance(
            audio_file, translation)
        transcript_part = self._speech_featurizer.text_featurize(
            transcript, self.keep_transcription_text)
        return spectrum, translation_part, transcript_part

    def __call__(self, batch):
        """Batch examples that carry both a translation and a transcription.

        Args:
            batch (List[Dict]): batch is [dict(audio, text, ...)], where
                item['output'][0] holds the translation and
                item['output'][1] holds the transcription.
                audio (np.ndarray) shape (T, D)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(utts, xs_pad, ilens, ys_pad, olens): batched data.
                utts: (B,)
                xs_pad: (B, Tmax, D)
                ilens: (B,)
                ys_pad: ((B, Umax_translation), (B, Umax_transcription)),
                    each target padded to its own maximum length
                olens: ((B,), (B,))
        """
        utts = []
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []

        for idx, item in enumerate(batch):
            utts.append(item['utt'])

            audio = item['input'][0]['feat']
            translation = item['output'][0]['text']
            transcription = item['output'][1]['text']

            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)

            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])

            # tokens[0]: translation ids, tokens[1]: transcription ids
            tokens = [
                _tokenids(text, self.keep_transcription_text)
                for text in (translation, transcription)
            ]

            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        xs_pad = pad_list(audios, 0.0).astype(np.float32)  # [B, T, D]
        ilens = np.array(audio_lens).astype(np.int64)

        padded_translation = pad_list(translation_text,
                                      IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)

        padded_transcription = pad_list(transcription_text,
                                        IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)

        ys_pad = (padded_translation, padded_transcription)
        olens = (translation_lens, transcription_lens)
        return utts, xs_pad, ilens, ys_pad, olens
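

# Illustrative sketch (hypothetical helper, not part of this module's API):
# TripletSpeechCollator.__call__ pads the translation and the transcription
# token lists independently, so the two arrays in ys_pad can have different
# widths. This pure-Python stand-in assumes pad_list() right-pads with the
# given value and that IGNORE_ID is -1 (the default pad_value here).
def _sketch_pad_two_targets(translations, transcriptions, pad_value=-1):
    """Return ((padded_translation, padded_transcription), (lens, lens))."""

    def pad(seqs):
        # Pad each list only up to the maximum length within this target.
        max_len = max((len(s) for s in seqs), default=0)
        return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]

    ys_pad = (pad(translations), pad(transcriptions))
    olens = ([len(s) for s in translations], [len(s) for s in transcriptions])
    return ys_pad, olens


# Usage: widths differ per target (2 for translation, 4 for transcription).
#   _sketch_pad_two_targets([[1, 2], [3]], [[4, 5, 6, 7], [8, 9]])
#   -> (([[1, 2], [3, -1]], [[4, 5, 6, 7], [8, 9, -1, -1]]),
#       ([2, 1], [4, 2]))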