hubert decode

pull/3088/head
th.zhang 3 years ago
parent ce75f8eb19
commit 0a8d95c6b6

@@ -31,7 +31,7 @@ python3 utils/format_rsl.py \
for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=16
batch_size=8
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \

@@ -7,7 +7,7 @@ MODEL=hubert
gpus=2
stage=1
stop_stage=1
stop_stage=3
conf_path=conf/${MODEL}ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
@@ -20,7 +20,7 @@ audio_file=data/demo_002_en.wav
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=test3
ckpt=test6
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then

@@ -186,8 +186,8 @@ class HubertASRTrainer(Trainer):
wavs_lens_rate = wavs_lens / wav.shape[1]
wav = wav[:, :, 0]
# if hasattr(train_conf, 'audio_augment'):
# wav = self.speech_augmentation(wav, wavs_lens_rate)
if hasattr(train_conf, 'audio_augment'):
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens)

@@ -86,7 +86,6 @@ class HubertASR(nn.Layer):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
self.hubert.eval()
# Extract wav2vec output
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
@@ -205,7 +204,7 @@ class HubertASR(nn.Layer):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
@@ -247,7 +246,7 @@ class HubertASR(nn.Layer):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
@@ -323,14 +322,17 @@ class HubertASR(nn.Layer):
return hyps[0][0]
class Wav2vec2Base(nn.Layer):
class HubertBase(nn.Layer):
"""Wav2vec2 model"""
def __init__(self, config: dict):
super().__init__()
wav2vec2_config = Wav2Vec2ConfigPure(config)
wav2vec2 = Wav2Vec2Model(wav2vec2_config)
self.wav2vec2 = wav2vec2
with open(config.vocab_filepath) as f:
dicts = [symbol.strip() for symbol in f.readlines()]
task_cfg = self.merge_with_parent(HubertPretrainingConfig, dict(self.config.task_cfg))
model_cfg = self.merge_with_parent(HubertConfig, dict(self.config.model_cfg))
hubert = HubertModel(model_cfg, task_cfg, dicts)
self.hubert= hubert
@classmethod
def from_config(cls, configs: dict):
@@ -340,11 +342,11 @@ class Wav2vec2Base(nn.Layer):
Raises:
ValueError: raise when using not support encoder type.
Returns:
nn.Layer: Wav2Vec2Base
nn.Layer: HubertBase
"""
model = cls(configs)
return model
def forward(self, wav):
out = self.wav2vec2(wav)
out = self.hubert(wav)
return out

Loading…
Cancel
Save