optional tokenizer and fix some doc (#3046)

pull/3057/head
zxcd 3 years ago committed by GitHub
parent 1bf1c3ab92
commit 4d1787dcf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -6,7 +6,7 @@ set -e
gpus=0 gpus=0
stage=0 stage=0
stop_stage=0 stop_stage=4
conf_path=conf/wav2vec2ASR.yaml conf_path=conf/wav2vec2ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml decode_conf_path=conf/tuning/decode.yaml

@@ -19,12 +19,13 @@ from pathlib import Path
import paddle import paddle
import soundfile import soundfile
from paddlenlp.transformers import AutoTokenizer from paddlenlp.transformers import AutoTokenizer
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from yacs.config import CfgNode
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@@ -33,8 +34,9 @@ class Wav2vec2Infer():
self.args = args self.args = args
self.config = config self.config = config
self.audio_file = args.audio_file self.audio_file = args.audio_file
self.tokenizer = config.get("tokenizer", None)
if self.config.tokenizer: if self.tokenizer:
self.text_feature = AutoTokenizer.from_pretrained( self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer) self.config.tokenizer)
else: else:
@@ -71,7 +73,7 @@ class Wav2vec2Infer():
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=decode_config.decoding_method, decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size, beam_size=decode_config.beam_size,
tokenizer=self.config.tokenizer, ) tokenizer=self.tokenizer, )
rsl = result_transcripts[0] rsl = result_transcripts[0]
utt = Path(self.audio_file).name utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}") logger.info(f"hyp: {utt} {rsl}")

Loading…
Cancel
Save