|
|
|
@ -517,11 +517,11 @@ class DeepSpeech2(nn.Layer):
|
|
|
|
|
#ctcdecoder need probs, not log_probs
|
|
|
|
|
probs = F.softmax(logits)
|
|
|
|
|
|
|
|
|
|
return logits, probs
|
|
|
|
|
return logits, probs, audio_len
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def infer(self, audio, audio_len):
|
|
|
|
|
_, probs = self.predict(audio, audio_len)
|
|
|
|
|
_, probs, audio_len = self.predict(audio, audio_len)
|
|
|
|
|
return probs
|
|
|
|
|
|
|
|
|
|
def forward(self, audio, text, audio_len, text_len):
|
|
|
|
@ -531,8 +531,8 @@ class DeepSpeech2(nn.Layer):
|
|
|
|
|
audio_len: shape [B]
|
|
|
|
|
text_len: shape [B]
|
|
|
|
|
"""
|
|
|
|
|
logits, _ = self.predict(audio, audio_len)
|
|
|
|
|
return logits
|
|
|
|
|
logits, _, audio_len = self.predict(audio, audio_len)
|
|
|
|
|
return logits, audio_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2Loss(nn.Layer):
|
|
|
|
|