|
|
@ -19,12 +19,13 @@ from pathlib import Path
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
import soundfile
|
|
|
|
import soundfile
|
|
|
|
from paddlenlp.transformers import AutoTokenizer
|
|
|
|
from paddlenlp.transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
|
|
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
|
|
|
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
|
|
|
|
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
|
|
|
|
from paddlespeech.s2t.training.cli import default_argument_parser
|
|
|
|
from paddlespeech.s2t.training.cli import default_argument_parser
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
|
|
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -33,8 +34,9 @@ class Wav2vec2Infer():
|
|
|
|
self.args = args
|
|
|
|
self.args = args
|
|
|
|
self.config = config
|
|
|
|
self.config = config
|
|
|
|
self.audio_file = args.audio_file
|
|
|
|
self.audio_file = args.audio_file
|
|
|
|
|
|
|
|
self.tokenizer = config.get("tokenizer", None)
|
|
|
|
|
|
|
|
|
|
|
|
if self.config.tokenizer:
|
|
|
|
if self.tokenizer:
|
|
|
|
self.text_feature = AutoTokenizer.from_pretrained(
|
|
|
|
self.text_feature = AutoTokenizer.from_pretrained(
|
|
|
|
self.config.tokenizer)
|
|
|
|
self.config.tokenizer)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -71,7 +73,7 @@ class Wav2vec2Infer():
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
decoding_method=decode_config.decoding_method,
|
|
|
|
decoding_method=decode_config.decoding_method,
|
|
|
|
beam_size=decode_config.beam_size,
|
|
|
|
beam_size=decode_config.beam_size,
|
|
|
|
tokenizer=self.config.tokenizer, )
|
|
|
|
tokenizer=self.tokenizer, )
|
|
|
|
rsl = result_transcripts[0]
|
|
|
|
rsl = result_transcripts[0]
|
|
|
|
utt = Path(self.audio_file).name
|
|
|
|
utt = Path(self.audio_file).name
|
|
|
|
logger.info(f"hyp: {utt} {rsl}")
|
|
|
|
logger.info(f"hyp: {utt} {rsl}")
|
|
|
|