wav2vec2_en, test=asr

pull/2637/head
tianhao zhang 3 years ago
parent fd73a184e7
commit a0e862d288

@@ -22,7 +22,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link |
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: |
[Wav2vec2-large-960h-lv60-self Model](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) | wav2vec2 | Librispeech and LV-60k Dataset (53,000 h) | - | 1.18 GB | Pre-trained Wav2vec2.0 Model | - | - | - |
[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (53,000 h) | Librispeech (960 h) | 1.18 GB | Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (53,000 h) | Librispeech (960 h) | 718 MB | Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions

@@ -9,6 +9,9 @@ dnn_neurons: 1024
blank_id: 0
ctc_dropout_rate: 0.0
wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
speech_augment:
  sample_rate: 16000
  speeds: [95, 100, 105]
############################################
# Wav2Vec2.0 #
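A quick way to sanity-check the new `speech_augment` block is to parse it and read it back with attribute-style access, the way `TimeDomainSpecAugmentConfig` later does with `getattr()`. A minimal sketch, assuming only PyYAML; `SimpleNamespace` here is a stand-in for the attribute-style config object the trainer actually receives, and every key not set in the YAML falls back to the defaults baked into the config class:

    # Sketch only: SimpleNamespace mimics attribute-style config access;
    # the real training config is a parsed-YAML config object, not this.
    import yaml
    from types import SimpleNamespace

    text = """
    speech_augment:
      sample_rate: 16000
      speeds: [95, 100, 105]
    """
    node = SimpleNamespace(**yaml.safe_load(text)['speech_augment'])
    print(getattr(node, 'speeds'))             # [95, 100, 105] (from YAML)
    print(getattr(node, 'perturb_prob', 1.0))  # 1.0 (falls back to the default)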
@@ -70,7 +73,6 @@ train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
###########################################
# Dataloader #
###########################################
@@ -115,6 +117,3 @@ log_interval: 1
checkpoint:
  kbest_n: 50
  latest_n: 5
augment: True

@@ -13,18 +13,13 @@
# limitations under the License.
__all__ = [
'asr_dynamic_pretrained_models',
'asr_static_pretrained_models',
'asr_onnx_pretrained_models',
'cls_dynamic_pretrained_models',
'cls_static_pretrained_models',
'st_dynamic_pretrained_models',
'st_kaldi_bins',
'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models',
'tts_static_pretrained_models',
'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models',
'asr_dynamic_pretrained_models', 'asr_static_pretrained_models',
'asr_onnx_pretrained_models', 'cls_dynamic_pretrained_models',
'cls_static_pretrained_models', 'st_dynamic_pretrained_models',
'st_kaldi_bins', 'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models', 'tts_static_pretrained_models',
'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models',
'ssl_pretrained_models'
]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@@ -32,6 +27,28 @@ __all__ = [
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
# ---------------------------------
# -------------- SSL --------------
# ---------------------------------
ssl_pretrained_models = {
    "wav2vec2ASR_librispeech-en-16k": {
        '1.3': {
            'url':
            'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz',
            'md5':
            '7d9449a8103ec4b17d6a004e928e0b1f',
            'cfg_path':
            'model.yaml',
            'ckpt_path':
            'exp/wav2vec2ASR/checkpoints/avg_1',
            'model':
            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
            'params':
            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
        },
    },
}
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
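For context, entries like the one above feed a fetch-verify-unpack step. The helper below is a hypothetical sketch of that flow (`fetch_ssl_model` and the cache path are not PaddleSpeech APIs); it relies only on the `url`, `md5`, `cfg_path`, and `ckpt_path` keys shown in the dict:

    # Hypothetical helper (not a PaddleSpeech API): resolves one entry of
    # ssl_pretrained_models by downloading, md5-checking, and unpacking it.
    import hashlib
    import tarfile
    import urllib.request
    from pathlib import Path

    def fetch_ssl_model(tag='wav2vec2ASR_librispeech-en-16k', version='1.3',
                        cache_dir='~/.paddlespeech/models'):
        meta = ssl_pretrained_models[tag][version]
        cache = Path(cache_dir).expanduser() / tag / version
        cache.mkdir(parents=True, exist_ok=True)
        archive = cache / Path(meta['url']).name
        if not archive.exists():
            urllib.request.urlretrieve(meta['url'], archive)
        # Refuse to unpack anything that does not match the published md5.
        digest = hashlib.md5(archive.read_bytes()).hexdigest()
        assert digest == meta['md5'], f'md5 mismatch for {archive}: {digest}'
        with tarfile.open(archive) as tar:
            tar.extractall(cache)
        return cache / meta['cfg_path'], cache / meta['ckpt_path']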

@@ -28,6 +28,7 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugmentConfig
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
@@ -71,6 +72,7 @@ class Wav2Vec2ASRTrainer(Trainer):
        wavs_lens_rate = wavs_lens / wav.shape[1]
        target_lens_rate = target_lens / target.shape[1]
        wav = wav[:, :, 0]
        if hasattr(train_conf, 'speech_augment'):
            wav = self.speech_augmentation(wav, wavs_lens_rate)
        loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
        # loss div by `batch_size * accum_grad`
@@ -277,7 +279,11 @@ class Wav2Vec2ASRTrainer(Trainer):
        logger.info("Setup model!")
        # setup speech augmentation for wav2vec2
        self.speech_augmentation = TimeDomainSpecAugment()
        if hasattr(config, 'speech_augment'):
            speechaugment_config = TimeDomainSpecAugmentConfig(
                config.speech_augment)
            self.speech_augmentation = TimeDomainSpecAugment(
                speechaugment_config)
        if not self.train:
            return
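The augmentation is now opt-in at both call sites: the module is constructed in setup, and invoked in train_batch, only when the config carries a `speech_augment` node. A small sketch of how the `hasattr()` gate behaves, with `SimpleNamespace` standing in for the attribute-style training config:

    # Stand-in configs; the real object is the parsed training YAML.
    from types import SimpleNamespace

    with_aug = SimpleNamespace(
        speech_augment=SimpleNamespace(speeds=[95, 100, 105]))
    without_aug = SimpleNamespace()

    print(hasattr(with_aug, 'speech_augment'))     # True  -> build and apply
    print(hasattr(without_aug, 'speech_augment'))  # False -> skip entirely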

@@ -641,15 +641,56 @@ class DropChunk(nn.Layer):
class TimeDomainSpecAugment(nn.Layer):
    """A time-domain approximation of the SpecAugment algorithm.

    This augmentation module implements three augmentations in
    the time-domain.

    1. Drop chunks of the audio (zero amplitude or white noise)
    2. Drop frequency bands (with band-drop filters)
    3. Speed perturbation (via resampling to slightly different rate)

    Arguments
    ---------
    config : TimeDomainSpecAugmentConfig
        Holds the hyperparameters of all three augmentations.

    Example
    -------
    >>> from types import SimpleNamespace
    >>> config = TimeDomainSpecAugmentConfig(SimpleNamespace(speeds=[80]))
    >>> inputs = paddle.randn([10, 16000])
    >>> feature_maker = TimeDomainSpecAugment(config)
    >>> feats = feature_maker(inputs, paddle.ones(10))
    >>> feats.shape
    paddle.shape([10, 12800])
    """

    def __init__(self, config):
        super().__init__()
        self.speed_perturb = SpeedPerturb(
            perturb_prob=config.perturb_prob,
            orig_freq=config.sample_rate,
            speeds=config.speeds)
        self.drop_freq = DropFreq(
            drop_prob=config.drop_freq_prob,
            drop_count_low=config.drop_freq_count_low,
            drop_count_high=config.drop_freq_count_high)
        self.drop_chunk = DropChunk(
            drop_prob=config.drop_chunk_prob,
            drop_count_low=config.drop_chunk_count_low,
            drop_count_high=config.drop_chunk_count_high,
            drop_length_low=config.drop_chunk_length_low,
            drop_length_high=config.drop_chunk_length_high,
            noise_factor=config.drop_chunk_noise_factor)

    def forward(self, waveforms, lengths):
        """Returns the distorted waveforms.

        Arguments
        ---------
        waveforms : tensor
            The waveforms to distort
        """
        # Augmentation
        with paddle.no_grad():
            waveforms = self.speed_perturb(waveforms)
            waveforms = self.drop_freq(waveforms)
            waveforms = self.drop_chunk(waveforms, lengths)
        return waveforms
class TimeDomainSpecAugmentConfig():
    """Augmentation configuration for time-domain SpecAugment.

    Arguments
    ---------
    perturb_prob : float from 0 to 1
        The probability that a batch will have speed perturbation applied.
@@ -677,56 +718,26 @@ class TimeDomainSpecAugment(nn.Layer):
    drop_chunk_noise_factor : float
        The noise factor used to scale the white noise inserted, relative to
        the average amplitude of the utterance. Default 0 (no noise inserted).

    Example
    -------
    >>> inputs = paddle.randn([10, 16000])
    >>> feature_maker = TimeDomainSpecAugment(speeds=[80])
    >>> feats = feature_maker(inputs, paddle.ones(10))
    >>> feats.shape
    paddle.shape([10, 12800])
    """
    def __init__(
            self,
            perturb_prob=1.0,
            drop_freq_prob=1.0,
            drop_chunk_prob=1.0,
            speeds=[95, 100, 105],
            sample_rate=16000,
            drop_freq_count_low=0,
            drop_freq_count_high=3,
            drop_chunk_count_low=0,
            drop_chunk_count_high=5,
            drop_chunk_length_low=1000,
            drop_chunk_length_high=2000,
            drop_chunk_noise_factor=0, ):
        super().__init__()
        self.speed_perturb = SpeedPerturb(
            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
        self.drop_freq = DropFreq(
            drop_prob=drop_freq_prob,
            drop_count_low=drop_freq_count_low,
            drop_count_high=drop_freq_count_high, )
        self.drop_chunk = DropChunk(
            drop_prob=drop_chunk_prob,
            drop_count_low=drop_chunk_count_low,
            drop_count_high=drop_chunk_count_high,
            drop_length_low=drop_chunk_length_low,
            drop_length_high=drop_chunk_length_high,
            noise_factor=drop_chunk_noise_factor, )
    def forward(self, waveforms, lengths):
        """Returns the distorted waveforms.

        Arguments
        ---------
        waveforms : tensor
            The waveforms to distort
        """
        # Augmentation
        with paddle.no_grad():
            waveforms = self.speed_perturb(waveforms)
            waveforms = self.drop_freq(waveforms)
            waveforms = self.drop_chunk(waveforms, lengths)
        return waveforms
    def __init__(self, config):
        # speedperturb config
        self.perturb_prob = getattr(config, 'perturb_prob', 1.0)
        self.sample_rate = getattr(config, 'sample_rate', 16000)
        self.speeds = getattr(config, 'speeds', [95, 100, 105])
        # dropfreq config
        self.drop_freq_prob = getattr(config, 'drop_freq_prob', 1.0)
        self.drop_freq_count_low = getattr(config, 'drop_freq_count_low', 0)
        self.drop_freq_count_high = getattr(config, 'drop_freq_count_high', 3)
        # dropchunk config
        self.drop_chunk_prob = getattr(config, 'drop_chunk_prob', 1.0)
        self.drop_chunk_count_low = getattr(config, 'drop_chunk_count_low', 0)
        self.drop_chunk_count_high = getattr(config, 'drop_chunk_count_high', 5)
        self.drop_chunk_length_low = getattr(config, 'drop_chunk_length_low',
                                             1000)
        self.drop_chunk_length_high = getattr(config, 'drop_chunk_length_high',
                                              2000)
        self.drop_chunk_noise_factor = getattr(config,
                                               'drop_chunk_noise_factor', 0)
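End-to-end usage sketch of the two classes added in this diff; `SimpleNamespace` stands in for the parsed `speech_augment` YAML node, and every field not set on it takes the `getattr()` defaults above:

    import paddle
    from types import SimpleNamespace

    node = SimpleNamespace(sample_rate=16000, speeds=[95, 100, 105])
    augment = TimeDomainSpecAugment(TimeDomainSpecAugmentConfig(node))

    wavs = paddle.randn([4, 16000])  # (batch, samples)
    lens = paddle.ones([4])          # relative lengths in [0, 1]
    out = augment(wavs, lens)        # speed perturb -> drop freq -> drop chunk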
