wav2vec2_en, test=asr

pull/2637/head
tianhao zhang 3 years ago
parent 0f766848c2
commit f53598b5c8

@ -9,9 +9,6 @@ dnn_neurons: 1024
blank_id: 0
ctc_dropout_rate: 0.0
wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
speech_augment:
sample_rate: 16000
speeds: [95, 100, 105]
############################################
# Wav2Vec2.0 #
@ -97,6 +94,12 @@ dist_sampler: True
shortest_first: True
return_lens_rate: True
############################################
# Data Augmentation #
############################################
audio_augment: # for raw audio
sample_rate: 16000
speeds: [95, 100, 105]
###########################################
# Training #

@ -13,13 +13,18 @@
# limitations under the License.
__all__ = [
'asr_dynamic_pretrained_models', 'asr_static_pretrained_models',
'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models',
'cls_static_pretrained_models', 'st_dynamic_pretrained_models',
'st_kaldi_bins', 'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models', 'tts_static_pretrained_models',
'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models',
'ssl_pretrained_models'
'asr_dynamic_pretrained_models',
'asr_static_pretrained_models',
'asr_onnx_pretrained_models',
'cls_dynamic_pretrained_models',
'cls_static_pretrained_models',
'st_dynamic_pretrained_models',
'st_kaldi_bins',
'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models',
'tts_static_pretrained_models',
'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models',
]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@ -27,28 +32,6 @@ __all__ = [
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
# ---------------------------------
# -------------- SSL --------------
# ---------------------------------
ssl_pretrained_models = {
"wav2vec2ASR_librispeech-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz',
'md5':
'7d9449a8103ec4b17d6a004e928e0b1f',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
}
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------

@ -28,7 +28,6 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugmentConfig
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
@ -279,11 +278,9 @@ class Wav2Vec2ASRTrainer(Trainer):
logger.info("Setup model!")
# setup speech augmentation for wav2vec2
if hasattr(config, 'speech_augment') and self.train:
speechaugment_config = TimeDomainSpecAugmentConfig(
config.speech_augment)
if hasattr(config, 'audio_augment') and self.train:
self.speech_augmentation = TimeDomainSpecAugment(
speechaugment_config)
**config.audio_augment)
if not self.train:
return

@ -641,56 +641,12 @@ class DropChunk(nn.Layer):
class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
---------
This augmentation module implements three augmentations in
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
2. Drop frequency bands (with band-drop filters)
3. Speed peturbation (via resampling to slightly different rate)
Example
-------
>>> inputs = paddle.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, paddle.ones(10))
>>> feats.shape
paddle.shape([10, 12800])
"""
def __init__(self, config):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=config.perturb_prob,
orig_freq=config.sample_rate,
speeds=config.speeds)
self.drop_freq = DropFreq(
drop_prob=config.drop_freq_prob,
drop_count_low=config.drop_freq_count_low,
drop_count_high=config.drop_freq_count_high)
self.drop_chunk = DropChunk(
drop_prob=config.drop_chunk_prob,
drop_count_low=config.drop_chunk_count_low,
drop_count_high=config.drop_chunk_count_high,
drop_length_low=config.drop_chunk_length_low,
drop_length_high=config.drop_chunk_length_high,
noise_factor=config.drop_chunk_noise_factor)
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
---------
waveforms : tensor
The waveforms to distort
"""
# Augmentation
with paddle.no_grad():
waveforms = self.speed_perturb(waveforms)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
return waveforms
class TimeDomainSpecAugmentConfig():
"""Augmentation configuration for time domain spectrograms.
Arguments
---------
perturb_prob : float from 0 to 1
The probability that a batch will have speed perturbation applied.
@ -718,26 +674,54 @@ class TimeDomainSpecAugmentConfig():
drop_chunk_noise_factor : float
The noise factor used to scale the white noise inserted, relative to
the average amplitude of the utterance. Default 0 (no noise inserted).
Example
-------
>>> inputs = paddle.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, paddle.ones(10))
>>> feats.shape
paddle.shape([10, 12800])
"""
def __init__(self, config):
# speedperturb config
self.perturb_prob = getattr(config, 'perturb_prob', 1.0)
self.sample_rate = getattr(config, 'sample_rate', 16000)
self.speeds = getattr(config, 'speeds', [95, 100, 105])
# dropfreq config
self.drop_freq_prob = getattr(config, 'drop_freq_prob', 1.0)
self.drop_freq_count_low = getattr(config, 'drop_freq_count_low', 0)
self.drop_freq_count_high = getattr(config, 'drop_freq_count_high', 3)
# dropchunk config
self.drop_chunk_prob = getattr(config, 'drop_chunk_prob', 1.0)
self.drop_chunk_count_low = getattr(config, 'drop_chunk_count_low', 0)
self.drop_chunk_count_high = getattr(config, 'drop_chunk_count_high', 5)
self.drop_chunk_length_low = getattr(config, 'drop_chunk_length_low',
1000)
self.drop_chunk_length_high = getattr(config, 'drop_chunk_length_high',
2000)
self.drop_chunk_noise_factor = getattr(config,
'drop_chunk_noise_factor', 0)
def __init__(
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0, ):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
self.drop_freq = DropFreq(
drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high, )
self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor, )
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
Arguments
---------
waveforms : tensor
The waveforms to distort
"""
# Augmentation
with paddle.no_grad():
waveforms = self.speed_perturb(waveforms)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
return waveforms

Loading…
Cancel
Save