wav2vec2_en, test=asr

pull/2637/head
tianhao zhang 3 years ago
parent 0f766848c2
commit f53598b5c8

@@ -9,9 +9,6 @@ dnn_neurons: 1024
 blank_id: 0
 ctc_dropout_rate: 0.0
 wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
-speech_augment:
-  sample_rate: 16000
-  speeds: [95, 100, 105]
 ############################################
 # Wav2Vec2.0                               #
@@ -97,6 +94,12 @@ dist_sampler: True
 shortest_first: True
 return_lens_rate: True
+############################################
+# Data Augmentation                        #
+############################################
+audio_augment:   # for raw audio
+  sample_rate: 16000
+  speeds: [95, 100, 105]
 ###########################################
 # Training                                #
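Note: the `audio_augment` block moved here is plain YAML whose keys mirror the keyword parameters of the refactored `TimeDomainSpecAugment` (last file in this diff), which is what lets the trainer unpack it with `**`. A minimal sketch of that mapping; the YAML is inlined here for illustration rather than read from the example's real config file:

# Illustrative only: shows that the audio_augment sub-config parses to a
# plain dict that can be splatted into keyword arguments.
import yaml

snippet = """
audio_augment:   # for raw audio
  sample_rate: 16000
  speeds: [95, 100, 105]
"""
cfg = yaml.safe_load(snippet)
print(cfg["audio_augment"])  # {'sample_rate': 16000, 'speeds': [95, 100, 105]}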

@@ -13,13 +13,18 @@
 # limitations under the License.
 __all__ = [
-    'asr_dynamic_pretrained_models', 'asr_static_pretrained_models',
-    'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models',
-    'cls_static_pretrained_models', 'st_dynamic_pretrained_models',
-    'st_kaldi_bins', 'text_dynamic_pretrained_models',
-    'tts_dynamic_pretrained_models', 'tts_static_pretrained_models',
-    'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models',
-    'ssl_pretrained_models'
+    'asr_dynamic_pretrained_models',
+    'asr_static_pretrained_models',
+    'asr_onnx_pretrained_models',
+    'cls_dynamic_pretrained_models',
+    'cls_static_pretrained_models',
+    'st_dynamic_pretrained_models',
+    'st_kaldi_bins',
+    'text_dynamic_pretrained_models',
+    'tts_dynamic_pretrained_models',
+    'tts_static_pretrained_models',
+    'tts_onnx_pretrained_models',
+    'vector_dynamic_pretrained_models',
 ]
 # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@@ -27,28 +32,6 @@ __all__ = [
 # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
 # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
-# ---------------------------------
-# -------------- SSL --------------
-# ---------------------------------
-ssl_pretrained_models = {
-    "wav2vec2ASR_librispeech-en-16k": {
-        '1.3': {
-            'url':
-            'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz',
-            'md5':
-            '7d9449a8103ec4b17d6a004e928e0b1f',
-            'cfg_path':
-            'model.yaml',
-            'ckpt_path':
-            'exp/wav2vec2ASR/checkpoints/avg_1',
-            'model':
-            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
-            'params':
-            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
-        },
-    },
-}
 # ---------------------------------
 # -------------- ASR --------------
 # ---------------------------------
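Note: each registry entry maps a model tag ("{model_name}[_{dataset}][-{lang}][-...]") and version to a downloadable archive plus paths inside it. A minimal sketch of how such an entry resolves, reusing the data from the entry removed above; `resolve` is an illustrative helper, not PaddleSpeech's actual loader API:

ssl_pretrained_models = {
    "wav2vec2ASR_librispeech-en-16k": {
        "1.3": {
            "url": "https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz",
            "md5": "7d9449a8103ec4b17d6a004e928e0b1f",
            "cfg_path": "model.yaml",
            "ckpt_path": "exp/wav2vec2ASR/checkpoints/avg_1",
        },
    },
}

def resolve(models: dict, tag: str, version: str) -> dict:
    # Look up one "{model_name}[_{dataset}]" tag at a given version.
    return models[tag][version]

entry = resolve(ssl_pretrained_models, "wav2vec2ASR_librispeech-en-16k", "1.3")
print(entry["url"])                           # archive to download and md5-check
print(entry["cfg_path"], entry["ckpt_path"])  # paths inside the extracted tar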

@@ -28,7 +28,6 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
-from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugmentConfig
 from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
@@ -279,11 +278,9 @@ class Wav2Vec2ASRTrainer(Trainer):
         logger.info("Setup model!")
         # setup speech augmentation for wav2vec2
-        if hasattr(config, 'speech_augment') and self.train:
-            speechaugment_config = TimeDomainSpecAugmentConfig(
-                config.speech_augment)
+        if hasattr(config, 'audio_augment') and self.train:
             self.speech_augmentation = TimeDomainSpecAugment(
-                speechaugment_config)
+                **config.audio_augment)
         if not self.train:
             return
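Note: with the intermediate config class gone, the trainer splats the YAML sub-config straight into the module; any key left out of `audio_augment` falls back to the keyword defaults in `TimeDomainSpecAugment.__init__`. A minimal sketch of the new wiring, where `SimpleNamespace` stands in for the trainer's parsed config object:

from types import SimpleNamespace

from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

# Stand-in for the trainer's config; only audio_augment matters here.
config = SimpleNamespace(
    audio_augment={"sample_rate": 16000, "speeds": [95, 100, 105]})

if hasattr(config, "audio_augment"):
    # Unspecified keys (perturb_prob, drop_freq_*, drop_chunk_*) keep
    # their keyword defaults from __init__.
    speech_augmentation = TimeDomainSpecAugment(**config.audio_augment)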

@@ -641,56 +641,12 @@ class DropChunk(nn.Layer):
 class TimeDomainSpecAugment(nn.Layer):
     """A time-domain approximation of the SpecAugment algorithm.
-    ---------
     This augmentation module implements three augmentations in
     the time-domain.
     1. Drop chunks of the audio (zero amplitude or white noise)
     2. Drop frequency bands (with band-drop filters)
     3. Speed peturbation (via resampling to slightly different rate)
+    Arguments
-    Example
-    -------
-    >>> inputs = paddle.randn([10, 16000])
-    >>> feature_maker = TimeDomainSpecAugment(speeds=[80])
-    >>> feats = feature_maker(inputs, paddle.ones(10))
-    >>> feats.shape
-    paddle.shape([10, 12800])
-    """
-    def __init__(self, config):
-        super().__init__()
-        self.speed_perturb = SpeedPerturb(
-            perturb_prob=config.perturb_prob,
-            orig_freq=config.sample_rate,
-            speeds=config.speeds)
-        self.drop_freq = DropFreq(
-            drop_prob=config.drop_freq_prob,
-            drop_count_low=config.drop_freq_count_low,
-            drop_count_high=config.drop_freq_count_high)
-        self.drop_chunk = DropChunk(
-            drop_prob=config.drop_chunk_prob,
-            drop_count_low=config.drop_chunk_count_low,
-            drop_count_high=config.drop_chunk_count_high,
-            drop_length_low=config.drop_chunk_length_low,
-            drop_length_high=config.drop_chunk_length_high,
-            noise_factor=config.drop_chunk_noise_factor)
-    def forward(self, waveforms, lengths):
-        """Returns the distorted waveforms.
-        ---------
-        waveforms : tensor
-            The waveforms to distort
-        """
-        # Augmentation
-        with paddle.no_grad():
-            waveforms = self.speed_perturb(waveforms)
-            waveforms = self.drop_freq(waveforms)
-            waveforms = self.drop_chunk(waveforms, lengths)
-        return waveforms
-class TimeDomainSpecAugmentConfig():
-    """Augmentation configuration for time domain spectrograms.
     ---------
     perturb_prob : float from 0 to 1
         The probability that a batch will have speed perturbation applied.
@@ -718,26 +674,54 @@ class TimeDomainSpecAugmentConfig():
     drop_chunk_noise_factor : float
         The noise factor used to scale the white noise inserted, relative to
         the average amplitude of the utterance. Default 0 (no noise inserted).
+    Example
+    -------
+    >>> inputs = paddle.randn([10, 16000])
+    >>> feature_maker = TimeDomainSpecAugment(speeds=[80])
+    >>> feats = feature_maker(inputs, paddle.ones(10))
+    >>> feats.shape
+    paddle.shape([10, 12800])
     """
-    def __init__(self, config):
-        # speedperturb config
-        self.perturb_prob = getattr(config, 'perturb_prob', 1.0)
-        self.sample_rate = getattr(config, 'sample_rate', 16000)
-        self.speeds = getattr(config, 'speeds', [95, 100, 105])
-        # dropfreq config
-        self.drop_freq_prob = getattr(config, 'drop_freq_prob', 1.0)
-        self.drop_freq_count_low = getattr(config, 'drop_freq_count_low', 0)
-        self.drop_freq_count_high = getattr(config, 'drop_freq_count_high', 3)
-        # dropchunk config
-        self.drop_chunk_prob = getattr(config, 'drop_chunk_prob', 1.0)
-        self.drop_chunk_count_low = getattr(config, 'drop_chunk_count_low', 0)
-        self.drop_chunk_count_high = getattr(config, 'drop_chunk_count_high', 5)
-        self.drop_chunk_length_low = getattr(config, 'drop_chunk_length_low',
-                                             1000)
-        self.drop_chunk_length_high = getattr(config, 'drop_chunk_length_high',
-                                              2000)
-        self.drop_chunk_noise_factor = getattr(config,
-                                               'drop_chunk_noise_factor', 0)
+    def __init__(
+            self,
+            perturb_prob=1.0,
+            drop_freq_prob=1.0,
+            drop_chunk_prob=1.0,
+            speeds=[95, 100, 105],
+            sample_rate=16000,
+            drop_freq_count_low=0,
+            drop_freq_count_high=3,
+            drop_chunk_count_low=0,
+            drop_chunk_count_high=5,
+            drop_chunk_length_low=1000,
+            drop_chunk_length_high=2000,
+            drop_chunk_noise_factor=0, ):
+        super().__init__()
+        self.speed_perturb = SpeedPerturb(
+            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
+        self.drop_freq = DropFreq(
+            drop_prob=drop_freq_prob,
+            drop_count_low=drop_freq_count_low,
+            drop_count_high=drop_freq_count_high, )
+        self.drop_chunk = DropChunk(
+            drop_prob=drop_chunk_prob,
+            drop_count_low=drop_chunk_count_low,
+            drop_count_high=drop_chunk_count_high,
+            drop_length_low=drop_chunk_length_low,
+            drop_length_high=drop_chunk_length_high,
+            noise_factor=drop_chunk_noise_factor, )
+
+    def forward(self, waveforms, lengths):
+        """Returns the distorted waveforms.
+        Arguments
+        ---------
+        waveforms : tensor
+            The waveforms to distort
+        """
+        # Augmentation
+        with paddle.no_grad():
+            waveforms = self.speed_perturb(waveforms)
+            waveforms = self.drop_freq(waveforms)
+            waveforms = self.drop_chunk(waveforms, lengths)
+        return waveforms
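Note: the refactored class is now constructed directly from keyword arguments instead of a config object. A usage sketch mirroring the docstring example (assumes paddle and paddlespeech are installed; a speed of 90 resamples each 16000-sample waveform to roughly 14400 samples):

import paddle

from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

inputs = paddle.randn([10, 16000])  # batch of 1 s, 16 kHz waveforms
augmenter = TimeDomainSpecAugment(sample_rate=16000, speeds=[90])
feats = augmenter(inputs, paddle.ones([10]))  # lengths are relative, in (0, 1]
print(feats.shape)  # about [10, 14400] after 90% speed perturbation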
