rm the os.chdir in cli asr

pull/1074/head
huangyuxin 3 years ago
parent 021311c76b
commit 1b57d05d1b

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import io
import os import os
import sys import sys
from typing import List from typing import List
@ -19,10 +20,11 @@ from typing import Optional
from typing import Union from typing import Union
import librosa import librosa
import numpy as np
import paddle import paddle
import soundfile import soundfile
import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
@ -131,8 +133,7 @@ class ASRExecutor(BaseExecutor):
lang: str='zh', lang: str='zh',
sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None ckpt_path: Optional[os.PathLike]=None):
):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
@ -140,10 +141,11 @@ class ASRExecutor(BaseExecutor):
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path, self.ckpt_path = os.path.join(
pretrained_models[tag]['ckpt_path'] + ".pdparams") res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
@ -157,10 +159,8 @@ class ASRExecutor(BaseExecutor):
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
self.config.decoding.decoding_method = "attention_rescoring" self.config.decoding.decoding_method = "attention_rescoring"
model_conf = self.config.model
logger.info(model_conf)
with UpdateConfig(model_conf): with UpdateConfig(self.config):
if model_type == "ds2_online" or model_type == "ds2_offline": if model_type == "ds2_online" or model_type == "ds2_offline":
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
@ -172,24 +172,29 @@ class ASRExecutor(BaseExecutor):
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.collate_fn_test.feature_size self.config.model.input_dim = self.collate_fn_test.feature_size
model_conf.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join(
res_path, self.config.collator.augmentation_config)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.config.collator.feat_dim self.config.model.input_dim = self.config.collator.feat_dim
model_conf.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
else: else:
raise Exception("wrong type") raise Exception("wrong type")
self.config.freeze()
# Enter the path of model root # Enter the path of model root
os.chdir(res_path)
model_class = dynamic_import(model_type, model_alias) model_class = dynamic_import(model_type, model_alias)
model_conf = self.config.model
logger.info(model_conf)
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
self.model = model self.model = model
self.model.eval() self.model.eval()
@ -222,10 +227,17 @@ class ASRExecutor(BaseExecutor):
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech": elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
logger.info("get the preprocess conf") logger.info("get the preprocess conf")
preprocess_conf = os.path.join( preprocess_conf_file = self.config.collator.augmentation_config
os.path.dirname(os.path.abspath(self.cfg_path)), # redirect the cmvn path
"preprocess.yaml") with io.open(preprocess_conf_file, encoding="utf-8") as f:
preprocess_conf = yaml.safe_load(f)
for idx, process in enumerate(preprocess_conf["process"]):
if process['type'] == "cmvn_json":
preprocess_conf["process"][idx][
"cmvn_path"] = os.path.join(
self.res_path,
preprocess_conf["process"][idx]["cmvn_path"])
break
logger.info(preprocess_conf) logger.info(preprocess_conf)
preprocess_args = {"train": False} preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf) preprocessing = Transformation(preprocess_conf)
@ -336,9 +348,7 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int): def _check(self, audio_file: str, sample_rate: int):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error( logger.error("please input --sr 8000 or --sr 16000")
"please input --sr 8000 or --sr 16000"
)
raise Exception("invalid sample rate") raise Exception("invalid sample rate")
sys.exit(-1) sys.exit(-1)
@ -364,13 +374,11 @@ class ASRExecutor(BaseExecutor):
sys.exit(-1) sys.exit(-1)
logger.info("The sample rate is %d" % audio_sample_rate) logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning( logger.warning("The sample rate of the input file is not {}.\n \
"The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \ If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
" ".format(self.sample_rate, self.sample_rate))
.format(self.sample_rate, self.sample_rate))
while (True): while (True):
logger.info( logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@ -408,16 +416,16 @@ class ASRExecutor(BaseExecutor):
device = parser_args.device device = parser_args.device
try: try:
res = self(model, lang, sample_rate, config, ckpt_path, res = self(model, lang, sample_rate, config, ckpt_path, audio_file,
audio_file, device) device)
logger.info('ASR Result: {}'.format(res)) logger.info('ASR Result: {}'.format(res))
return True return True
except Exception as e: except Exception as e:
print(e) print(e)
return False return False
def __call__(self, model, lang, sample_rate, config, ckpt_path, def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
audio_file, device): device):
""" """
Python API to call an executor. Python API to call an executor.
""" """

Loading…
Cancel
Save