From 0a8d95c6b6c3ed087661fd87a48edb4075dd73ac Mon Sep 17 00:00:00 2001 From: "th.zhang" <15600919271@163.com> Date: Fri, 24 Mar 2023 22:50:26 +0800 Subject: [PATCH] hubert decode --- examples/librispeech/asr3/local/test.sh | 2 +- examples/librispeech/asr3/run.sh | 4 ++-- paddlespeech/s2t/exps/hubert/model.py | 4 ++-- paddlespeech/s2t/models/hubert/hubert_ASR.py | 20 +++++++++++--------- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/librispeech/asr3/local/test.sh b/examples/librispeech/asr3/local/test.sh index ccc0d84de..0a5104f1c 100755 --- a/examples/librispeech/asr3/local/test.sh +++ b/examples/librispeech/asr3/local/test.sh @@ -31,7 +31,7 @@ python3 utils/format_rsl.py \ for type in ctc_greedy_search; do echo "decoding ${type}" - batch_size=16 + batch_size=8 python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh index 53b885e6d..8ebab30d0 100755 --- a/examples/librispeech/asr3/run.sh +++ b/examples/librispeech/asr3/run.sh @@ -7,7 +7,7 @@ MODEL=hubert gpus=2 stage=1 -stop_stage=1 +stop_stage=3 conf_path=conf/${MODEL}ASR.yaml ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml @@ -20,7 +20,7 @@ audio_file=data/demo_002_en.wav avg_ckpt=avg_${avg_num} ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') -ckpt=test3 +ckpt=test6 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then diff --git a/paddlespeech/s2t/exps/hubert/model.py b/paddlespeech/s2t/exps/hubert/model.py index 1e5496f41..8e2ab5745 100644 --- a/paddlespeech/s2t/exps/hubert/model.py +++ b/paddlespeech/s2t/exps/hubert/model.py @@ -186,8 +186,8 @@ class HubertASRTrainer(Trainer): wavs_lens_rate = wavs_lens / wav.shape[1] wav = wav[:, :, 0] - # if hasattr(train_conf, 'audio_augment'): - # wav = self.speech_augmentation(wav, wavs_lens_rate) + if hasattr(train_conf, 'audio_augment'): + wav = self.speech_augmentation(wav, wavs_lens_rate) loss = self.model(wav, wavs_lens_rate, target, target_lens) diff --git a/paddlespeech/s2t/models/hubert/hubert_ASR.py b/paddlespeech/s2t/models/hubert/hubert_ASR.py index b31d10c81..2f45cd1ff 100644 --- a/paddlespeech/s2t/models/hubert/hubert_ASR.py +++ b/paddlespeech/s2t/models/hubert/hubert_ASR.py @@ -86,7 +86,6 @@ class HubertASR(nn.Layer): if self.normalize_wav: wav = F.layer_norm(wav, wav.shape) - self.hubert.eval() # Extract wav2vec output out = self.hubert.extract_features(wav)[0] # We normalize the output if required @@ -205,7 +204,7 @@ class HubertASR(nn.Layer): if self.normalize_wav: wav = F.layer_norm(wav, wav.shape[1:]) # Extract wav2vec output - out = self.wav2vec2(wav)[0] + out = self.hubert.extract_features(wav)[0] # We normalize the output if required if self.output_norm: out = F.layer_norm(out, out.shape[1:]) @@ -247,7 +246,7 @@ class HubertASR(nn.Layer): if self.normalize_wav: wav = F.layer_norm(wav, wav.shape[1:]) # Extract wav2vec output - out = self.wav2vec2(wav)[0] + out = self.hubert.extract_features(wav)[0] # We normalize the output if required if self.output_norm: out = F.layer_norm(out, out.shape[1:]) @@ -323,14 +322,17 @@ class HubertASR(nn.Layer): return hyps[0][0] -class Wav2vec2Base(nn.Layer): +class HubertBase(nn.Layer): """Wav2vec2 model""" def __init__(self, config: dict): super().__init__() - 
wav2vec2_config = Wav2Vec2ConfigPure(config) - wav2vec2 = Wav2Vec2Model(wav2vec2_config) - self.wav2vec2 = wav2vec2 + with open(config.vocab_filepath) as f: + dicts = [symbol.strip() for symbol in f.readlines()] + task_cfg = self.merge_with_parent(HubertPretrainingConfig, dict(config.task_cfg)) + model_cfg = self.merge_with_parent(HubertConfig, dict(config.model_cfg)) + hubert = HubertModel(model_cfg, task_cfg, dicts) + self.hubert = hubert @@ -340,11 +342,11 @@ class Wav2vec2Base(nn.Layer): Raises: ValueError: raise when using not support encoder type. Returns: - nn.Layer: Wav2Vec2Base + nn.Layer: HubertBase """ model = cls(configs) return model def forward(self, wav): - out = self.wav2vec2(wav) + out = self.hubert(wav) return out