Merge pull request #1074 from Jackwaterveg/fix_cli

[CLI] remove the os.chdir
pull/1077/head
Hui Zhang 3 years ago committed by GitHub
commit 664ee6167e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import os
import sys
from typing import List
@ -19,10 +20,11 @@ from typing import Optional
from typing import Union
import librosa
import numpy as np
import paddle
import soundfile
import yaml
from yacs.config import CfgNode
import numpy as np
from ..executor import BaseExecutor
from ..utils import cli_register
@ -46,6 +48,16 @@ pretrained_models = {
'conf/conformer.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"transformer_zh_16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
'md5':
'4e8b63800c71040b9390b150e2a5d4c4',
'cfg_path':
'conf/transformer.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_20',
}
}
@ -121,8 +133,7 @@ class ASRExecutor(BaseExecutor):
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None
):
ckpt_path: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
"""
@ -130,10 +141,11 @@ class ASRExecutor(BaseExecutor):
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]['ckpt_path'] + ".pdparams")
self.ckpt_path = os.path.join(
res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
@ -147,10 +159,8 @@ class ASRExecutor(BaseExecutor):
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
self.config.decoding.decoding_method = "attention_rescoring"
model_conf = self.config.model
logger.info(model_conf)
with UpdateConfig(model_conf):
with UpdateConfig(self.config):
if model_type == "ds2_online" or model_type == "ds2_offline":
from paddlespeech.s2t.io.collator import SpeechCollator
self.config.collator.vocab_filepath = os.path.join(
@ -162,24 +172,29 @@ class ASRExecutor(BaseExecutor):
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.collate_fn_test.feature_size
model_conf.output_dim = text_feature.vocab_size
self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join(
res_path, self.config.collator.augmentation_config)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
model_conf.input_dim = self.config.collator.feat_dim
model_conf.output_dim = text_feature.vocab_size
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = text_feature.vocab_size
else:
raise Exception("wrong type")
self.config.freeze()
# Enter the path of model root
os.chdir(res_path)
model_class = dynamic_import(model_type, model_alias)
model_conf = self.config.model
logger.info(model_conf)
model = model_class.from_config(model_conf)
self.model = model
self.model.eval()
@ -212,10 +227,17 @@ class ASRExecutor(BaseExecutor):
elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
logger.info("get the preprocess conf")
preprocess_conf = os.path.join(
os.path.dirname(os.path.abspath(self.cfg_path)),
"preprocess.yaml")
preprocess_conf_file = self.config.collator.augmentation_config
# redirect the cmvn path
with io.open(preprocess_conf_file, encoding="utf-8") as f:
preprocess_conf = yaml.safe_load(f)
for idx, process in enumerate(preprocess_conf["process"]):
if process['type'] == "cmvn_json":
preprocess_conf["process"][idx][
"cmvn_path"] = os.path.join(
self.res_path,
preprocess_conf["process"][idx]["cmvn_path"])
break
logger.info(preprocess_conf)
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
@ -310,14 +332,14 @@ class ASRExecutor(BaseExecutor):
return self._outputs["result"]
def _pcm16to32(self, audio):
assert(audio.dtype == np.int16)
assert (audio.dtype == np.int16)
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
return audio
def _pcm32to16(self, audio):
assert(audio.dtype == np.float32)
assert (audio.dtype == np.float32)
bits = np.iinfo(np.int16).bits
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
@ -326,9 +348,7 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"please input --sr 8000 or --sr 16000"
)
logger.error("please input --sr 8000 or --sr 16000")
raise Exception("invalid sample rate")
sys.exit(-1)
@ -354,13 +374,11 @@ class ASRExecutor(BaseExecutor):
sys.exit(-1)
logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning(
"The sample rate of the input file is not {}.\n \
logger.warning("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \
"
.format(self.sample_rate, self.sample_rate))
".format(self.sample_rate, self.sample_rate))
while (True):
logger.info(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@ -398,16 +416,16 @@ class ASRExecutor(BaseExecutor):
device = parser_args.device
try:
res = self(model, lang, sample_rate, config, ckpt_path,
audio_file, device)
res = self(model, lang, sample_rate, config, ckpt_path, audio_file,
device)
logger.info('ASR Result: {}'.format(res))
return True
except Exception as e:
print(e)
return False
def __call__(self, model, lang, sample_rate, config, ckpt_path,
audio_file, device):
def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
device):
"""
Python API to call an executor.
"""

Loading…
Cancel
Save