|
|
|
@ -86,7 +86,6 @@ class HubertASR(nn.Layer):
|
|
|
|
|
if self.normalize_wav:
|
|
|
|
|
wav = F.layer_norm(wav, wav.shape)
|
|
|
|
|
|
|
|
|
|
self.hubert.eval()
|
|
|
|
|
# Extract HuBERT output
|
|
|
|
|
out = self.hubert.extract_features(wav)[0]
|
|
|
|
|
# We normalize the output if required
|
|
|
|
@ -205,7 +204,7 @@ class HubertASR(nn.Layer):
|
|
|
|
|
if self.normalize_wav:
|
|
|
|
|
wav = F.layer_norm(wav, wav.shape[1:])
|
|
|
|
|
# Extract HuBERT output
|
|
|
|
|
out = self.wav2vec2(wav)[0]
|
|
|
|
|
out = self.hubert.extract_features(wav)[0]
|
|
|
|
|
# We normalize the output if required
|
|
|
|
|
if self.output_norm:
|
|
|
|
|
out = F.layer_norm(out, out.shape[1:])
|
|
|
|
@ -247,7 +246,7 @@ class HubertASR(nn.Layer):
|
|
|
|
|
if self.normalize_wav:
|
|
|
|
|
wav = F.layer_norm(wav, wav.shape[1:])
|
|
|
|
|
# Extract HuBERT output
|
|
|
|
|
out = self.wav2vec2(wav)[0]
|
|
|
|
|
out = self.hubert.extract_features(wav)[0]
|
|
|
|
|
# We normalize the output if required
|
|
|
|
|
if self.output_norm:
|
|
|
|
|
out = F.layer_norm(out, out.shape[1:])
|
|
|
|
@ -323,14 +322,17 @@ class HubertASR(nn.Layer):
|
|
|
|
|
return hyps[0][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Wav2vec2Base(nn.Layer):
|
|
|
|
|
class HubertBase(nn.Layer):
    """Thin wrapper around a bare HuBERT model.

    The constructor builds a ``HubertModel`` from the ``task_cfg`` and
    ``model_cfg`` sections of the given config plus the symbol list read from
    ``config.vocab_filepath``; ``forward`` delegates to that model.
    """

    def __init__(self, config: dict):
        """Build the HuBERT model described by ``config``.

        Args:
            config: Configuration object. It is read through attribute
                access: ``vocab_filepath`` (text file with one symbol per
                line), ``task_cfg`` and ``model_cfg`` (each converted with
                ``dict(...)`` before being merged).
        """
        super().__init__()
        # The config must be stored on the instance: the merges below read
        # ``self.config`` and would otherwise fail with AttributeError.
        self.config = config

        # One symbol per line; surrounding whitespace is stripped.
        with open(config.vocab_filepath) as f:
            dicts = [symbol.strip() for symbol in f]

        # NOTE(review): ``merge_with_parent`` is not defined in the visible
        # part of this class -- confirm it is provided on HubertBase.
        task_cfg = self.merge_with_parent(HubertPretrainingConfig,
                                          dict(self.config.task_cfg))
        model_cfg = self.merge_with_parent(HubertConfig,
                                           dict(self.config.model_cfg))
        self.hubert = HubertModel(model_cfg, task_cfg, dicts)

    @classmethod
    def from_config(cls, configs: dict):
        """Create a model instance from a config.

        Args:
            configs (dict): config passed unchanged to the constructor.

        Returns:
            nn.Layer: HubertBase
        """
        return cls(configs)

    def forward(self, wav):
        """Run the wrapped HuBERT model on ``wav`` and return its output as is."""
        return self.hubert(wav)
|
|
|
|
|