|
|
@ -12,6 +12,7 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
from collections import defaultdict
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
from turtle import Turtle
|
|
|
|
from typing import Dict
|
|
|
|
from typing import Dict
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
from typing import Tuple
|
|
|
|
from typing import Tuple
|
|
|
@ -83,6 +84,7 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
text_feature: Dict[str, int],
|
|
|
|
text_feature: Dict[str, int],
|
|
|
|
decoding_method: str,
|
|
|
|
decoding_method: str,
|
|
|
|
beam_size: int,
|
|
|
|
beam_size: int,
|
|
|
|
|
|
|
|
tokenizer: str=None,
|
|
|
|
sb_pipeline=False):
|
|
|
|
sb_pipeline=False):
|
|
|
|
batch_size = feats.shape[0]
|
|
|
|
batch_size = feats.shape[0]
|
|
|
|
|
|
|
|
|
|
|
@ -93,12 +95,15 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
logger.error(f"current batch_size is {batch_size}")
|
|
|
|
logger.error(f"current batch_size is {batch_size}")
|
|
|
|
|
|
|
|
|
|
|
|
if decoding_method == 'ctc_greedy_search':
|
|
|
|
if decoding_method == 'ctc_greedy_search':
|
|
|
|
if not sb_pipeline:
|
|
|
|
if tokenizer is None and sb_pipeline is False:
|
|
|
|
hyps = self.ctc_greedy_search(feats)
|
|
|
|
hyps = self.ctc_greedy_search(feats)
|
|
|
|
res = [text_feature.defeaturize(hyp) for hyp in hyps]
|
|
|
|
res = [text_feature.defeaturize(hyp) for hyp in hyps]
|
|
|
|
res_tokenids = [hyp for hyp in hyps]
|
|
|
|
res_tokenids = [hyp for hyp in hyps]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
if sb_pipeline is True:
|
|
|
|
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
|
|
|
|
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
hyps = self.ctc_greedy_search(feats)
|
|
|
|
res = []
|
|
|
|
res = []
|
|
|
|
res_tokenids = []
|
|
|
|
res_tokenids = []
|
|
|
|
for sequence in hyps:
|
|
|
|
for sequence in hyps:
|
|
|
@ -123,13 +128,16 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
# with other batch decoding mode
|
|
|
|
# with other batch decoding mode
|
|
|
|
elif decoding_method == 'ctc_prefix_beam_search':
|
|
|
|
elif decoding_method == 'ctc_prefix_beam_search':
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
assert feats.shape[0] == 1
|
|
|
|
if not sb_pipeline:
|
|
|
|
if tokenizer is None and sb_pipeline is False:
|
|
|
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
|
|
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
|
|
|
res = [text_feature.defeaturize(hyp)]
|
|
|
|
res = [text_feature.defeaturize(hyp)]
|
|
|
|
res_tokenids = [hyp]
|
|
|
|
res_tokenids = [hyp]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
if sb_pipeline is True:
|
|
|
|
hyp = self.ctc_prefix_beam_search(
|
|
|
|
hyp = self.ctc_prefix_beam_search(
|
|
|
|
feats.unsqueeze(-1), beam_size)
|
|
|
|
feats.unsqueeze(-1), beam_size)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
|
|
|
res = []
|
|
|
|
res = []
|
|
|
|
res_tokenids = []
|
|
|
|
res_tokenids = []
|
|
|
|
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
|
|
|
|
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
|
|
|
|