Update asr inference in paddlespeech.cli.

pull/1048/head
KP 4 years ago
parent d28888972f
commit e9798498d6

@ -14,7 +14,6 @@
import os
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Union
import paddle
@ -30,16 +29,16 @@ class BaseExecutor(ABC):
self.output = None
@abstractmethod
def _get_default_cfg_path(self):
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Returns a default config file path of current task.
Download and returns pretrained resources path of current task.
"""
pass
@abstractmethod
def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None):
def _init_from_path(self, *args, **kwargs):
"""
Init model from a specific config file.
Init model and other resources from a specific path.
"""
pass

@ -21,9 +21,21 @@ import paddle
from ..executor import BaseExecutor
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
__all__ = ['S2TExecutor']
# Registry of downloadable pretrained resources for the s2t task.
# Keys are tags of the form "<model>_<lang>" (e.g. "wenetspeech_zh"); each
# entry carries the archive 'url' and the 'md5' checksum of that archive.
pretrained_models = {
"wenetspeech_zh": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
'md5':
'54e7a558a6e020c2f5fb224874943f97',
}
}
@cli_register(
name='paddlespeech.s2t', description='Speech to text infer command.')
@ -33,11 +45,23 @@ class S2TExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.s2t', add_help=True)
self.parser.add_argument(
'--model',
type=str,
default='wenetspeech',
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang', type=str, default='zh', help='Choose model language.')
self.parser.add_argument(
'--config',
type=str,
default=None,
help='Config of s2t task. Use deault config when it is None.')
self.parser.add_argument(
'--ckpt_path',
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument(
'--input', type=str, help='Audio file to recognize.')
self.parser.add_argument(
@ -46,16 +70,39 @@ class S2TExecutor(BaseExecutor):
default='cpu',
help='Choose device to execute model inference.')
def _get_default_cfg_path(self):
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Returns a default config file path of current task.
Download and returns pretrained resources path of current task.
"""
pass
assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
tag)
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None):
def _init_from_path(self,
model_type: str='wenetspeech',
lang: str='zh',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
"""
Init model from a specific config file.
Init model and other resources from a specific path.
"""
if cfg_path is None or ckpt_path is None:
res_path = self._get_pretrained_path(
model_type + '_' + lang) # wenetspeech_zh
cfg_path = os.path.join(res_path, 'conf/conformer.yaml')
ckpt_path = os.path.join(
res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams')
logger.info(res_path)
logger.info(cfg_path)
logger.info(ckpt_path)
# Init body.
pass
def preprocess(self, input: Union[str, os.PathLike]):
@ -82,17 +129,15 @@ class S2TExecutor(BaseExecutor):
parser_args = self.parser.parse_args(argv)
print(parser_args)
model = parser_args.model
lang = parser_args.lang
config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input
device = parser_args.device
if config is not None:
assert os.path.isfile(config), 'Config file is not valid.'
else:
config = self._get_default_cfg_path()
try:
self._init_from_cfg(config)
self._init_from_path(model, lang, config, ckpt_path)
self.preprocess(audio_file)
self.infer()
res = self.postprocess() # Retrieve result of s2t.

@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os
from typing import Any
from typing import Dict
from typing import List
from paddle.framework import load
from paddle.utils import download
@ -26,6 +27,7 @@ __all__ = [
'get_command',
'download_and_decompress',
'load_state_dict_from_url',
'logger',
]
@ -53,29 +55,27 @@ def get_command(name: str) -> Any:
return com['_entry']
def decompress(file: str):
def decompress(file: str) -> os.PathLike:
"""
Extracts all files from a compressed file.
"""
assert os.path.isfile(file), "File: {} not exists.".format(file)
download._decompress(file)
return download._decompress(file)
def download_and_decompress(archives: List[Dict[str, str]], path: str):
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
"""
Download archieves and decompress to specific path.
"""
if not os.path.isdir(path):
os.makedirs(path)
for archive in archives:
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
return download.get_path_from_url(archive['url'], path, archive['md5'])
download.get_path_from_url(archive['url'], path, archive['md5'])
def load_state_dict_from_url(url: str, path: str, md5: str=None):
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
"""
Download and load a state dict from url
"""
@ -84,3 +84,69 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None):
download.get_path_from_url(url, path, md5)
return load(os.path.join(path, os.path.basename(url)))
def _get_user_home():
    """
    Return the current user's home directory (``~`` expanded).
    """
    home_dir = os.path.expanduser('~')
    return home_dir
def _get_paddlespcceh_home():
    """
    Resolve the root directory used for PaddleSpeech resources.

    The ``PPSPEECH_HOME`` environment variable takes precedence when it is
    set: its value is returned as-is, whether or not the path exists yet.
    When the variable is not set, ``~/.paddlespeech`` is returned.

    Raises:
        RuntimeError: ``PPSPEECH_HOME`` points at an existing path that is
            not a directory.
    """
    configured = os.environ.get('PPSPEECH_HOME')
    if configured is None:
        # No override: fall back to the per-user default location.
        return os.path.join(_get_user_home(), '.paddlespeech')

    # An existing non-directory (e.g. a regular file) cannot serve as a home.
    if os.path.exists(configured) and not os.path.isdir(configured):
        raise RuntimeError(
            'The environment variable PPSPEECH_HOME {} is not a directory.'.
            format(configured))
    return configured
def _get_sub_home(directory):
    """
    Return the path of a sub-directory of the PaddleSpeech home, creating it
    (and any missing parents) when it does not exist yet.

    Args:
        directory: Name of the sub-directory, relative to the PaddleSpeech home.

    Returns:
        The joined path of the sub-directory.

    Raises:
        FileExistsError: The target path already exists but is not a directory.
    """
    home = os.path.join(_get_paddlespcceh_home(), directory)
    # exist_ok=True replaces the former exists()-then-makedirs() sequence, which
    # could fail when two processes tried to create the directory concurrently.
    os.makedirs(home, exist_ok=True)
    return home
# Root directory for PaddleSpeech resources, resolved once at import time
# (PPSPEECH_HOME environment variable if set, otherwise ~/.paddlespeech).
PPSPEECH_HOME = _get_paddlespcceh_home()
# 'models' sub-directory of the home; _get_sub_home creates it on import if it
# is missing.
MODEL_HOME = _get_sub_home('models')
class Logger(object):
    """
    Small wrapper around a named ``logging`` logger.

    Besides the standard levels it registers two extra level names, ``TRAIN``
    (21) and ``EVAL`` (22), and exposes one callable per level on the instance
    (``debug``, ``info``, ``train``, ``eval``, ``warning``, ``error``,
    ``critical``), each taking the message as its only argument.
    """

    def __init__(self, name: str=None):
        """
        Args:
            name: Name of the underlying logger. Defaults to 'PaddleSpeech'
                when empty or None.
        """
        name = 'PaddleSpeech' if not name else name
        self.logger = logging.getLogger(name)

        log_config = {
            'DEBUG': 10,
            'INFO': 20,
            'TRAIN': 21,
            'EVAL': 22,
            'WARNING': 30,
            'ERROR': 40,
            'CRITICAL': 50
        }
        for key, level in log_config.items():
            # Registers the name process-wide in the logging module.
            logging.addLevelName(level, key)
            # Bind the level so callers can write logger.info(msg), etc.
            self.__dict__[key.lower()] = functools.partial(self.__call__, level)

        # NOTE(review): records are emitted from __call__ below, so
        # %(filename)s / %(lineno)d describe this wrapper, not the call site.
        self.format = logging.Formatter(
            fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s'
        )

        # logging.getLogger(name) returns the same object for the same name, so
        # unconditionally adding a handler here would print every message once
        # per Logger instance. Reuse an already attached stream handler instead.
        self.handler = next(
            (attached for attached in self.logger.handlers
             if isinstance(attached, logging.StreamHandler)), None)
        if self.handler is None:
            self.handler = logging.StreamHandler()
            self.handler.setFormatter(self.format)
            self.logger.addHandler(self.handler)

        self.logger.setLevel(logging.DEBUG)
        # Keep messages out of the root logger to avoid double printing.
        self.logger.propagate = False

    def __call__(self, log_level: int, msg: str):
        """
        Emit ``msg`` at the numeric level ``log_level``.
        """
        self.logger.log(log_level, msg)
# Module-level logger instance (default name 'PaddleSpeech'), exported via
# __all__ for use by the CLI executors.
logger = Logger()

Loading…
Cancel
Save